
I recently had a personal epiphany about machine learning, in which some of my beliefs were shown to be wrong and were replaced by something more right and more beautiful. I think many people share the wrong beliefs that I previously held, so I've felt compelled to tell people about this bit of insight ever since I learned it from David MacKay's course Information Theory, Pattern Recognition, and Neural Networks (lecture 15 of the video lectures, which are also available as downloads, and chapter 41 of his book). I highly recommend the rest of the course/book as well.

Setup

Let's say that I have some model for how some system works, and a bunch of "training data" collected from that system. Machine learning is supposed to take the model and data as input and produce predictions about future data as output. Here are some examples:

| Model | Data | Prediction |
|---|---|---|
| Human heights are normally distributed with some mean and standard deviation. | A bunch of people's heights. | What's the probability a random person is between 180 cm and 190 cm? |
| Given a bunch of features of a fruit (diameter, mass, color, etc.), one can determine whether it is an apple or a banana by seeing which side of some hyperplane it's on. | Values of all those features for a bunch of fruits which have already been classified as apples or bananas. | Given the values for a new fruit, what is the probability it is a banana? |
| A neural network with two hidden layers with 20 nodes each can be used to determine whether a given x-ray is of somebody with bone cancer. | A bunch of digitized x-rays, classified by human experts as either having bone cancer or not. | Given an unclassified x-ray, how likely is the person to have bone cancer? |

What I used to think

Each model has some free parameters (let's call them $\theta$ collectively, or $\{\theta_i\}$ if we want to discuss them individually). For any given values of these parameters, define some error function $E(\theta, D)$, where $D$ is the data, that measures how far off the model (with parameters set to $\theta$) is from explaining $D$. We should find the $\theta$ which minimizes this error function (by gradient descent, say), and then use this optimal version of the model to make our predictions.

To avoid overfitting any noise in our data, we typically add a regularization term to the error function. A typical regularization term is $\sum_i \theta_i^2$, which penalizes large parameter values: it effectively says that the $\theta_i$ shouldn't be too large.
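To make the "old" recipe concrete, here is a minimal Python sketch under illustrative assumptions: a hypothetical straight-line model $y \approx \theta_0 + \theta_1 x$, a squared-error $E$, made-up data, and arbitrary values for the regularization weight `lam`, the learning rate `lr`, and the number of steps. It minimizes $E(\theta, D) + \lambda\sum_i \theta_i^2$ by plain gradient descent and returns the single $\theta$ that the old approach would then use for every prediction.

```python
def regularized_error(theta, xs, ys, lam):
    """Sum of squared residuals plus the L2 penalty lam * sum(theta_i^2)."""
    t0, t1 = theta
    sse = sum((y - (t0 + t1 * x)) ** 2 for x, y in zip(xs, ys))
    return sse + lam * (t0 ** 2 + t1 ** 2)

def gradient(theta, xs, ys, lam):
    """Gradient of regularized_error with respect to (theta0, theta1)."""
    t0, t1 = theta
    g0 = 2 * lam * t0
    g1 = 2 * lam * t1
    for x, y in zip(xs, ys):
        r = y - (t0 + t1 * x)  # residual of this data point
        g0 -= 2 * r
        g1 -= 2 * r * x
    return g0, g1

def fit(xs, ys, lam=0.1, lr=0.01, steps=5000):
    """Plain gradient descent starting from theta = (0, 0)."""
    theta = (0.0, 0.0)
    for _ in range(steps):
        g0, g1 = gradient(theta, xs, ys, lam)
        theta = (theta[0] - lr * g0, theta[1] - lr * g1)
    return theta

# Hypothetical training data.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.1, 0.9, 2.1, 2.9]
theta_opt = fit(xs, ys)
print(theta_opt, regularized_error(theta_opt, xs, ys, 0.1))
```

With `lam=0.1` the fitted slope comes out slightly below the unregularized least-squares slope for this data, which is the "don't let $\theta_i$ get too large" effect.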

The right thing

We believe that some value of $\theta$ is correct. We can use Bayes's theorem to say exactly what the probability of a given value of $\theta$ is:

$$P(\theta | D) = \frac{P(D|\theta)P(\theta)}{P(D)}.$$

Here $P(D)$ is some normalizing constant, which we'll ignore. $P(\theta)$ is our prior belief about what $\theta$ might be, which looks a bit arbitrary. $P(D|\theta)$ is the likelihood function, which is presumably easily computable (otherwise even if we knew the true value of $\theta$, we'd have trouble predicting anything).

Awesome thing 1: there's an objectively correct error function. If we pretend that we have no prior (i.e. $P(\theta)$ is uniform, or approximately uniform, whatever that means), then the best choice of $\theta$ is the one that maximizes the likelihood function $P(D|\theta)$. This is pretty sweet, since the error function in the previous section may have involved arbitrary decisions (e.g. "let's take the error to be the $L^2$ distance of the data from what the model predicts"). We can now say that the correct error function is $E(\theta, D) = -\log P(D|\theta)$ (we could just use $-P(D|\theta)$, but the error function is traditionally non-negative and unbounded, so it's natural to introduce the log; the log also turns the product of the likelihoods of independent data points into a sum).
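One way to see that the $L^2$ choice isn't so arbitrary after all: if we assume (as an illustration, not something specified above) that each observation is the model's prediction plus independent Gaussian noise with known standard deviation $\sigma$, then $-\log P(D|\theta)$ equals the squared error divided by $2\sigma^2$, plus a constant that doesn't depend on $\theta$. The sketch below uses a hypothetical straight-line model and made-up data, and checks numerically that the gap between the two is the same for several values of $\theta$, so both error functions pick the same optimum.

```python
import math

def neg_log_likelihood(theta, xs, ys, sigma):
    """-log P(D|theta), assuming y = theta0 + theta1 * x plus independent N(0, sigma^2) noise."""
    t0, t1 = theta
    total = 0.0
    for x, y in zip(xs, ys):
        r = y - (t0 + t1 * x)
        # -log of the Gaussian density exp(-r^2 / (2 sigma^2)) / (sigma * sqrt(2 pi))
        total += r * r / (2 * sigma ** 2) + math.log(sigma * math.sqrt(2 * math.pi))
    return total

def squared_error(theta, xs, ys):
    """The familiar L2 error: sum of squared residuals."""
    t0, t1 = theta
    return sum((y - (t0 + t1 * x)) ** 2 for x, y in zip(xs, ys))

xs = [0.0, 1.0, 2.0, 3.0]  # hypothetical data
ys = [0.1, 0.9, 2.1, 2.9]
sigma = 0.5                # assumed noise level

# The gap should be the same for every theta: n * log(sigma * sqrt(2 pi)).
offsets = [
    neg_log_likelihood(theta, xs, ys, sigma) - squared_error(theta, xs, ys) / (2 * sigma ** 2)
    for theta in [(0.0, 1.0), (0.5, 0.7), (-1.0, 2.0)]
]
expected_offset = len(xs) * math.log(sigma * math.sqrt(2 * math.pi))
print(offsets, expected_offset)
```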

Awesome thing 2: the prior "is" regularization. From this point of view, the prior $P(\theta)$ corresponds exactly to the regularization term. If we used the regularization term $\sum_i \theta_i^2$, then $P(\theta|D)\propto P(D|\theta)P(\theta) \propto \exp(-E(\theta, D))\exp(-\sum_i\theta_i^2)$, so minimizing the regularized error function is the same as maximizing $P(\theta|D)$, and this choice of regularization corresponds exactly to the prior hypothesis that the $\theta_i$ are independent and normally distributed with mean zero (here with variance $1/2$; multiplying the regularization term by $\lambda$ changes the variance to $1/(2\lambda)$).
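A one-parameter check of this correspondence, assuming (for illustration only) that each data point is drawn from $N(\theta, 1)$ and that $P(\theta) \propto \exp(-\theta^2)$, i.e. $N(0, 1/2)$. On a grid, the $\theta$ that minimizes $E(\theta, D) + \theta^2$ and the $\theta$ that maximizes the unnormalized log posterior coincide, and both match the closed-form value $\sum_j x_j/(n+2)$ obtained by setting the derivative of $E + \theta^2$ to zero. The data values are made up.

```python
import math

data = [1.2, 0.7, 1.9, 1.4, 0.8]  # hypothetical observations

def error(theta):
    """E(theta, D) = -log P(D|theta) for x_j ~ N(theta, 1), up to a theta-independent constant."""
    return sum((x - theta) ** 2 / 2 for x in data)

def regularized_error(theta):
    """E(theta, D) plus the regularization term theta^2."""
    return error(theta) + theta ** 2

def log_unnormalized_posterior(theta):
    """log P(D|theta) + log P(theta), with prior P(theta) = exp(-theta^2) / sqrt(pi), i.e. N(0, 1/2)."""
    log_likelihood = sum(-0.5 * (x - theta) ** 2 - 0.5 * math.log(2 * math.pi) for x in data)
    log_prior = -theta ** 2 - 0.5 * math.log(math.pi)
    return log_likelihood + log_prior

grid = [i / 10_000 for i in range(-20_000, 20_001)]
theta_min_error = min(grid, key=regularized_error)
theta_map = max(grid, key=log_unnormalized_posterior)
closed_form = sum(data) / (len(data) + 2)  # solves d/dtheta [E + theta^2] = 0
print(theta_min_error, theta_map, closed_form)
```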

Awesome thing 3: the optimum value of $\theta$ is irrelevant. How do we actually make predictions? If we want to predict the probability of some new data $d$ (e.g. that a fruit with given diameter, mass, etc. is a banana), we should not just find the optimum value of $\theta$ and treat it as if it were the truth. After all, any particular value of $\theta$ is incredibly unlikely to be the true value, even if it is well-supported by the data. Roughly, the right thing to do is to take a vote among all theories that have not been ruled out by $D$, rather than to simply use the single theory that happens to fit $D$ best. More precisely, we don't know that $\theta_{opt}$ is the true value of $\theta$, so we should not compute $P(d|\theta_{opt})$; we know $D$, so we should make predictions by computing

$$P(d|D) = \int_\theta P(d|\theta)P(\theta|D)\,d\theta.$$

I really want to stress this, since it's the thing that blew my mind most: using just $\theta_{opt}$ is like trying to understand a probability distribution by looking only at its most likely point. Here, we're trying to compute the expected value of $P(d|\theta)$ under the probability distribution $P(\theta|D)$. Using the most likely point as a proxy for the expected value of a function sounds like madness once it's put in those terms.
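A toy example, assuming a coin model that isn't in the original: each flip lands heads with unknown probability $\theta$, the prior on $\theta$ is uniform on $[0, 1]$, and we have observed three heads in three flips. The maximum-likelihood value is $\theta_{opt} = 1$, so the plug-in prediction $P(d|\theta_{opt})$ says the next flip is certainly heads. Computing $P(d|D)$ from the integral instead (here with a simple midpoint rule) gives a much more sensible answer.

```python
def predictive_heads(k, n, grid_size=100_000):
    """P(next flip is heads | k heads in n flips), assuming a uniform prior on theta.

    Approximates the integral of P(d|theta) P(theta|D) over [0, 1] with the midpoint
    rule. P(D) is obtained from the same grid, so the binomial coefficient and the
    grid spacing cancel in the ratio.
    """
    numerator = 0.0
    denominator = 0.0
    for i in range(grid_size):
        theta = (i + 0.5) / grid_size
        unnormalized_posterior = theta ** k * (1 - theta) ** (n - k)  # P(D|theta) P(theta), up to a constant
        numerator += theta * unnormalized_posterior                  # P(heads|theta) = theta
        denominator += unnormalized_posterior
    return numerator / denominator

k, n = 3, 3                     # hypothetical data: three heads in three flips
plug_in = k / n                 # P(heads | theta_opt), theta_opt being the maximum-likelihood value
bayes = predictive_heads(k, n)  # P(heads | D)
print(plug_in, bayes)
```

For this model the integral has the closed form $(k+1)/(n+2)$ (Laplace's rule of succession), so three heads in three flips gives $4/5$ rather than certainty.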

Remark A. As always with Bayes, if you can execute this approach, then you end up exactly as confident in your predictions as you should be. In particular, predictions about things far away from the training set automatically become more moderate. For example, if you see a very large, nearly massless, blue fruit, you should probably be very unsure about whether it is an apple or a banana. With the old approach, it may have ended up far to one side of the optimal hyperplane for dividing apples and bananas, so you might have said it's an apple with very high probability. But now you're taking a (weighted) vote among all hyperplanes which explain the training data, and this weird point is likely on the apple side as often as it is on the banana side. (Of course, it could also be that your model is wrong, and apples and bananas are not in fact linearly separable, or there are things other than apples and bananas.)
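A small Monte Carlo sketch of this vote, under hypothetical assumptions: two features, three "bananas" at $x_1 = 1$ and three "apples" at $x_1 = -1$ (each with $x_2 \in \{-1, 0, 1\}$), noise-free labels so that $P(D|\theta)$ is 1 for hyperplanes that separate the training data and 0 otherwise, and a made-up prior (direction uniform on the circle, offset uniform on $[-2, 2]$). With a 0/1 likelihood, drawing hyperplanes from the prior and keeping the consistent ones (rejection sampling) yields samples from $P(\theta|D)$, and each kept hyperplane casts one vote.

```python
import math
import random

# Hypothetical training set: (features, label), label +1 = banana, -1 = apple.
TRAIN = [((1.0, s), +1) for s in (-1.0, 0.0, 1.0)] + [((-1.0, s), -1) for s in (-1.0, 0.0, 1.0)]

def side(w, b, x):
    """Signed score of point x relative to the hyperplane w . x + b = 0 (positive = banana side)."""
    return w[0] * x[0] + w[1] * x[1] + b

def consistent(w, b):
    """Noise-free likelihood: True (1) if every training point is on its correct side, else False (0)."""
    return all(label * side(w, b, x) > 0 for x, label in TRAIN)

def sample_posterior(num_draws, rng):
    """Rejection sampling: draw (w, b) from the prior and keep the draws with nonzero likelihood."""
    kept = []
    for _ in range(num_draws):
        angle = rng.uniform(0.0, 2.0 * math.pi)  # prior: direction of w uniform on the circle
        w = (math.cos(angle), math.sin(angle))
        b = rng.uniform(-2.0, 2.0)              # prior: offset uniform on [-2, 2]
        if consistent(w, b):
            kept.append((w, b))
    return kept

def prob_banana(x, hyperplanes):
    """Fraction of posterior hyperplanes that put x on the banana side (an equally weighted vote)."""
    return sum(1 for w, b in hyperplanes if side(w, b, x) > 0) / len(hyperplanes)

rng = random.Random(0)
posterior = sample_posterior(100_000, rng)
near = prob_banana((1.5, 0.0), posterior)  # just beyond the bananas
far = prob_banana((0.0, 10.0), posterior)  # far from all training data
print(len(posterior), near, far)
```

Every consistent hyperplane puts $(1.5, 0)$ on the banana side, so that vote is unanimous. For $(0, 10)$, far from all the training data, the vote comes out close to even, whereas any single separating hyperplane would put it on one definite side.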

Remark B. An interesting thing about the predictions that come out of this is that they do not correspond to any particular value of $\theta$. This seems weird, since we've assumed that there is some correct value of $\theta$ (perhaps because of physics), so we know that these predictions are not those of the "completely correct" model. This is a feature, not a bug: there is no value of $\theta$ which correctly takes into account our uncertainty about what the true value of $\theta$ is.
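A sketch illustrating this, assuming (for illustration) a model $d \sim N(\theta, 1)$ with known unit variance, a prior $\theta \sim N(0, 1)$, and made-up data. The posterior is then Gaussian, and drawing $\theta$ from $P(\theta|D)$ followed by $d$ from $P(d|\theta)$ samples the predictive distribution $P(d|D)$. Its variance comes out near $1 + 1/(n+1)$, larger than the variance $1$ of every individual model $N(\theta, 1)$, so no single value of $\theta$ reproduces it; the extra spread is the uncertainty about $\theta$.

```python
import random

data = [1.2, 0.7, 1.9, 1.4, 0.8]  # hypothetical observations
n = len(data)

# Conjugate update for prior theta ~ N(0, 1) and model d ~ N(theta, 1):
# posterior precision = 1 + n, posterior mean = sum(data) / (1 + n).
post_var = 1.0 / (1.0 + n)
post_mean = sum(data) * post_var

rng = random.Random(1)
draws = []
for _ in range(200_000):
    theta = rng.gauss(post_mean, post_var ** 0.5)  # theta ~ P(theta | D)
    draws.append(rng.gauss(theta, 1.0))            # d ~ P(d | theta)

predictive_mean = sum(draws) / len(draws)
predictive_var = sum((d - predictive_mean) ** 2 for d in draws) / len(draws)
print(predictive_mean, predictive_var)  # near post_mean and 1 + post_var respectively
```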

How to do the right thing

There is a remaining implementation problem: the integral above is usually uncomputable in practice. We can plug in our expression for $P(\theta|D)$, but its denominator $P(D)$ typically ends up being difficult or impossible to compute. Remarkably, there's a way around this problem. Suppose we could generate a bunch of samples $\{\theta^1, \theta^2, \dots, \theta^N\}$ from the distribution $P(\theta|D)$. Then $P(d|D)$, which is the expected value of $P(d|\theta)$ under $P(\theta|D)$, is well approximated by $\frac{1}{N}\sum_i P(d|\theta^i)$.
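For a hypothetical coin model (each flip heads with probability $\theta$, uniform prior, $k$ heads in $n$ flips), the posterior is $\text{Beta}(k+1, n-k+1)$, so in this conjugate toy case we can draw exact samples from $P(\theta|D)$ with Python's `random.betavariate`. The sketch below forms the average $\frac{1}{N}\sum_i P(d|\theta^i)$; realistic models don't offer such a shortcut, which is the problem Awesome thing 4 addresses.

```python
import random

def mc_predictive_heads(k, n, num_samples, rng):
    """Monte Carlo estimate of P(heads | D) = (1/N) * sum_i P(heads | theta^i).

    Assumes the coin model with a uniform prior, whose posterior is Beta(k + 1, n - k + 1),
    so here theta^i ~ P(theta | D) can be drawn exactly.
    """
    total = 0.0
    for _ in range(num_samples):
        theta = rng.betavariate(k + 1, n - k + 1)  # one posterior sample theta^i
        total += theta                             # P(heads | theta^i) = theta^i
    return total / num_samples

rng = random.Random(0)
estimate = mc_predictive_heads(3, 3, num_samples=100_000, rng=rng)
print(estimate)  # compare with the exact value (3 + 1) / (3 + 2) = 0.8
```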

Awesome thing 4: it's possible to sample from an "uncomputable" distribution. We can't compute $P(\theta|D)$, but we can compute it up to a normalizing constant: $P(\theta|D)\propto P(D|\theta)P(\theta)$. It turns out that this is enough. The basic idea is that if you know how to sample from some distribution $F$ and you want to sample from another distribution $G$, you need some "flow" which turns $F$ into $G$. Then you can sample from $F$ and apply the flow for a while to get a sample from $G$. If you can compute $G$ up to a scalar multiple, then you can produce such a flow with a random walk in which the probability of stepping from $a$ to $b$ is $G(b)/G(a)$ times the probability of stepping from $b$ to $a$. (You also have to make sure that every point is reachable.) One way to do this is the Metropolis algorithm: propose a random step from $a$ to $b$ using a symmetric proposal distribution, accept it with probability $\min(1, G(b)/G(a))$, and otherwise reject it and stay at $a$.
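A minimal random-walk Metropolis sketch for a hypothetical coin model (three heads in three flips, uniform prior on the heads probability $\theta$). It only ever evaluates the unnormalized posterior $G(\theta) = P(D|\theta)P(\theta)$, never $P(D)$. The proposal is a symmetric uniform step, so accepting with probability $\min(1, G(b)/G(a))$ gives the ratio condition described above. The starting point $\theta = 0.5$ plays the role of a trivial $F$, and the step size, number of steps, burn-in length, and seed are arbitrary illustrative choices.

```python
import random

def unnormalized_posterior(theta, k, n):
    """G(theta) = P(D|theta) P(theta) for k heads in n flips with a uniform prior; P(D) is never needed."""
    if not 0.0 < theta < 1.0:
        return 0.0  # outside the support of the prior
    return theta ** k * (1.0 - theta) ** (n - k)

def metropolis(k, n, steps, step_size, rng, start=0.5):
    """Random-walk Metropolis: returns the sequence of visited theta values."""
    theta = start
    g_theta = unnormalized_posterior(theta, k, n)
    samples = []
    for _ in range(steps):
        proposal = theta + rng.uniform(-step_size, step_size)  # symmetric proposal
        g_proposal = unnormalized_posterior(proposal, k, n)
        # Accept with probability min(1, G(proposal) / G(theta)); only the ratio matters,
        # so the unknown normalizing constant cancels.
        if rng.random() * g_theta < g_proposal:
            theta, g_theta = proposal, g_proposal
        samples.append(theta)  # after a rejection, the current point is counted again
    return samples

rng = random.Random(42)
k, n = 3, 3
samples = metropolis(k, n, steps=200_000, step_size=0.2, rng=rng)
kept = samples[1_000:]            # burn-in: drop the start of the walk, before it has "flowed" toward G
estimate = sum(kept) / len(kept)  # (1/N) * sum_i P(heads | theta^i), since P(heads | theta) = theta
print(estimate)
```

The estimate should land near the exact value $(k+1)/(n+2) = 0.8$ for this toy model. In a realistic model $\theta$ would be a vector and the proposal a step in parameter space, but the accept/reject rule is the same.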