It’s often useful to revisit the basics, if only to make sure you haven’t forgotten anything. In this article, I want to collect basic (and not-so-basic) results for Bayesian linear regression in one place as a reference.
To start, consider the humble linear model for an unknown function of interest \(f: \R^{d} \to \R\), with parameters \(\rvw \in \R^{d}\):
\[f(\rvx) = \rvw^{\top}\rvx. \]
We might generalise this with a basis function expansion \(\phi: \R^{d} \to \R^{m}\) so as to incorporate a bias term and more general non-linear features, where we now have \(\rvw \in \R^{m}\). (While this isn’t the topic of this article, and we only treat the single-output, i.e. univariate regression, case here, I will leave these feature expansions in throughout as a nod to Bayesian neural networks and the straightforward extension of Bayesian linear regression to BNNs.) The model becomes
\[f(\rvx) = \rvw^{\top}\phi(\rvx). \]
In forming a dataset, we might additionally assume that our observations \(y\) are corrupted by some noise
\[y_{i} = \rvw^{\top}\phi(\rvx_{i}) + \epsilon_{i}, \]
with \(\epsilon_{i}\) drawn from some noise process.
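For concreteness, here is one way such a dataset might be generated; the sinusoidal latent function, the noise level, and the polynomial feature map phi are illustrative assumptions of mine, chosen only to make the snippets throughout this article self-contained:

import math
import torch

torch.manual_seed(0)

N, sigma = 30, 0.2                             # dataset size and noise scale (assumed)

def f(x):
    # An illustrative latent function; the article's true f is not specified here
    return torch.sin(2 * math.pi * x)

def phi(X, m: int = 8):
    # Polynomial features [1, x, ..., x^{m-1}] per input, so phi(X) is (N, m)
    return torch.cat([X**j for j in range(m)], dim=-1)

X = torch.rand(N, 1)                           # covariates, uniform on [0, 1]
y = f(X).squeeze(-1) + sigma * torch.randn(N)  # noisy observations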
The latent function and our noisy observations thereof might look like the following:
Preliminaries: Maximum Likelihood and MAP
Before looking at the Bayesian methods, let’s briefly review the MLE and MAP approaches.
Maximum Likelihood
Let’s make the rather simple assumption that our observations are corrupted by some zero-mean Gaussian noise, with a fixed variance \(\sigma^{2}\). That is, \(\epsilon_{i} \sim \gN(0, \sigma^{2})\) or equivalently we have the following likelihood for each individual datapoint:
\[p(y_{i} \vert \rvx_{i}, \vtheta) = \gN(y_{i} \vert \rvw^{\top}\phi(\rvx_{i}), \sigma^{2}), \]
where we use \(\vtheta = \{\rvw, \sigma^{2}\}\) to denote all the model parameters. The joint likelihood over our dataset of \(N\) training data points \(\gD = \{(y_{i}, \rvx_{i})\}_{i=1}^{N}\) is given by (recall that \(\rvy \in \R^{N}\), \(\rmX \in \R^{N \times d}\), \(\phi(\rmX) \in \R^{N \times m}\), \(\rvw \in \R^{m}\), and that the identity is \(\rmI \in \R^{N\times N}\))
\[\begin{align*} p(\rvy \vert \rmX, \vtheta) &= \gN(\rvy; \phi(\rmX) \rvw, \sigma^{2}\rmI) \\ &= (2\pi)^{-N/2} \cdot \sigma^{-N} \exp\left(-\frac{1}{2\sigma^{2}}\big(\rvy - \phi(\rmX)\rvw\big)^{\top}\big(\rvy - \phi(\rmX)\rvw\big)\right), \end{align*} \]
where the \(\sigma^{-N}\) term comes from the simplification of \(\vert \sigma^{2}\rmI\vert^{-1/2}\) (the determinant of a diagonal matrix is the product of the elements on the diagonal), and the \(\sigma^{-2}\) coefficient in front of the quadratic form comes from the inverse of the diagonal covariance \((\sigma^{2}\rmI)^{-1}\).
Focusing only on the weights \(\rvw\) for now, and recalling the rules of differentiation \(\nabla_{\rvx} \rvb^{\top}\rvx = \rvb\) and \(\nabla_{\rvx}\rvx^{\top}\rmA\rvx = 2\rmA\rvx\) for symmetric \(\rmA\) (Petersen et al., 2012), taking the derivative of the log-likelihood wrt \(\rvw\) (up to the constant factor \(-\frac{1}{2\sigma^{2}}\), which does not affect the location of the optimum) yields
\[\begin{align} \nabla_{\rvw}\log p(\rvy\vert \rmX, \vtheta) &\propto \nabla_{\rvw} \big(\rvy - \phi(\rmX)\rvw\big)^{\top}\big(\rvy - \phi(\rmX)\rvw\big) \\ &= \nabla_{\rvw}\left(\rvy^{\top}\rvy - 2\rvy^{\top}\phi(\rmX)\rvw + \rvw^{\top}\phi(\rmX)^{\top}\phi(\rmX)\rvw\right) \\ &= -2 \phi(\rmX)^{\top}\rvy + 2 \phi(\rmX)^{\top}\phi(\rmX)\rvw. \end{align} \]
Setting the above to \(0\) and solving for \(\rvw\) gives us the familiar result
\[\begin{equation} \label{eq:ols} \rvw_{\text{ML}} = \big(\phi(\rmX)^{\top}\phi(\rmX)\big)^{-1}\phi(\rmX)^{\top}\rvy. \end{equation} \]
Note that the maximum likelihood estimate of \(\sigma^{2}\) is just the mean squared residual on the training data, and this can be used to parametrise a Gaussian used to evaluate the likelihood of a test point:
\[\begin{equation} \sigma^{2}_{\text{ML}} = \frac{1}{N}\sum_{i=1}^{N}\left(y_{i} - \rvw_{\text{ML}}^{\top}\phi(\rvx_{i})\right)^{2}. \end{equation} \]
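As a quick sanity check, here is a naive implementation of Equation \(\ref{eq:ols}\) and the noise estimate above, reusing the hypothetical phi and data from the earlier sketch (the explicit inverse is for illustration only; see the next section for stabler alternatives):

# Naive ML solution via the normal equations (numerically unstable; illustration only)
phi_X = phi(X, m=8)
w_ml = torch.linalg.inv(phi_X.T @ phi_X) @ phi_X.T @ y
sigma2_ml = ((y - phi_X @ w_ml) ** 2).mean()   # mean squared residual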
Numerically Stable Computation
Computing \(\rvw\) exactly as set out in Equation \(\ref{eq:ols}\), by explicitly inverting \(\phi(\rmX)^{\top}\phi(\rmX)\), proves to be numerically unstable: forming the Gram matrix squares the condition number of \(\phi(\rmX)\), and the explicit inverse compounds the rounding error.
A more numerically stable approach is to use a Cholesky decomposition. Here, we consider the above to be in the form \(\rvx = \rmA^{-1}\rvb\), or equivalently \(\rmA\rvx = \rvb\). Recall that the Cholesky decomposition of a positive definite matrix \(\rmA\) is the product of a lower-triangular matrix \(\rmL\) and its transpose, \(\rmA = \rmL\rmL^{\top}\), which is found in \(O(m^{3})\) time for our \(m \times m\) Gram matrix. Hence, to obtain \(\rvx\), we can first solve \(\rmL\rvz = \rvb\) by forward substitution, followed by \(\rmL^{\top}\rvx = \rvz\) by back substitution.
# Compute w using a Cholesky decomposition of the Gram matrix
phi_X = phi(X, m=8)
L = torch.linalg.cholesky(phi_X.T @ phi_X)   # phi_X^T phi_X = L L^T
# Solve L z = phi_X^T y, then L^T w = z
w = torch.linalg.solve(L.T, torch.linalg.solve(L, phi_X.T @ y))
We can also use a QR decomposition: here we write \(\phi(\rmX) = \rmQ\rmR\), where \(\rmQ \in \R^{N \times m}\) has orthonormal columns (i.e. \(\rmQ^{\top}\rmQ = \rmI\)) and \(\rmR\) is an \(m \times m\) upper triangular matrix. Substituting this into Equation \(\ref{eq:ols}\) gives \(\rvw = \rmR^{-1}\rmQ^{\top}\rvy\), which we compute by back substitution rather than explicitly inverting \(\rmR\).
# Compute w using a (reduced) QR decomposition
phi_X = phi(X, m=8)
Q, R = torch.linalg.qr(phi_X)
# Solve R w = Q^T y rather than forming R^{-1} explicitly
w = torch.linalg.solve(R, Q.T @ y)
Alternatively, and this is particularly useful when the number of features exceeds the number of training data points, \(m \gg N\) (in which case \(\phi(\rmX)^{\top}\phi(\rmX)\) is singular and the SVD yields the minimum-norm solution), we can work with the SVD of \(\phi(\rmX) \in \R^{N\times m}\). In particular, since the matrix is rectangular, we can work with the reduced SVD:
\[\begin{equation} \label{eq:reduced_svd} \phi(\rmX) = \rmU\rmS\rmV^{\top}, \end{equation} \]
where \(\rmU\) is an \(N \times m\) matrix with orthonormal columns (i.e. \(\rmU^{\top}\rmU = \rmI_{m}\)), \(\rmS\) is a diagonal matrix containing the \(m\) singular values, and \(\rmV\) is an \(m\times m\) orthogonal matrix, which also has orthonormal columns \(\rmV^{\top}\rmV = \rmI_{m}\) and the property that the transpose serves as the inverse, hence \(\rmV^{-\top} = (\rmV^{\top})^{-1} = \rmV\). (These shapes assume \(N \ge m\); in general the inner dimension of the factorisation is \(\min(N, m)\).)
Substituting the SVD of \(\phi(\rmX)\) into Equation \(\ref{eq:ols}\), we can find the weights as:
\[\begin{align} \rvw &= \big(\phi(\rmX)^{\top}\phi(\rmX)\big)^{-1}\phi(\rmX)^{\top}\rvy \\ &= \big(\rmV\rmS^{\top}\rmU^{\top}\rmU\rmS\rmV^{\top}\big)^{-1}\rmV\rmS^{\top}\rmU^{\top}\rvy \\ &=\rmV^{-\top}\rmS^{-1}\rmS^{-\top}\rmV^{-1}\rmV\rmS^{\top}\rmU^{\top} \rvy \\ &=\rmV\rmS^{-1}\rmU^{\top}\rvy. \end{align} \]
# Compute w using a (reduced) singular value decomposition
phi_X = phi(X, m=8)
U, S, VT = torch.linalg.svd(phi_X, full_matrices=False)
# Divide by the singular values elementwise instead of inverting diag(S)
w = VT.T @ ((U.T @ y) / S)
Results
Using simple polynomial basis functions \(\phi(\rvx) = [1, \rvx, \rvx^{2}, \ldots, \rvx^{m}]\), with \(m=8\), we get the following
While this is just a toy problem, and I did cherry-pick \(m = 8\) out of a few tries, this doesn’t look too bad.
It’s easy to poke holes in this solution however; let’s reduce the number of training points to \(N=10\) and increase the model complexity to \(m=10\):
This isn’t looking so good anymore. We’re just interpolating through each of the datapoints, producing wildly inaccurate predictions everywhere else. Introducing a prior over the model parameters lets us make more sensible predictions in regions where there isn’t much data.
Maximum A Posteriori
Let’s place a simple Gaussian prior on the model weights;
\[\begin{equation} \label{eq:map_prior} p(\rvw) = \gN(\rvw; \mathbf{0}, \tau^{2}\rmI). \end{equation} \]
Now, rather than maximising the likelihood alone, we maximise a quantity that is proportional to the posterior over the weights;
\[\begin{align} p(\rvw \vert \rvy, \rmX) &\propto p(\rvy \vert \rmX, \rvw)p(\rvw) \\ &= \gN(\rvy; \phi(\rmX)\rvw, \sigma^{2}\rmI_{N})\gN(\rvw; \mathbf{0}, \tau^{2}\rmI_{m}) \end{align} \]
Once again we differentiate to get the optimal weights. Focusing on the terms inside the quadratic forms which depend on \(\rvw\), and keeping the \(\sigma^{2}\) term around, the negative log-posterior is, up to additive constants and an overall factor of \(\frac{1}{2}\),
\[\begin{align} -2\log p(\rvw \vert \gD) &= \frac{1}{\sigma^{2}}\left(\rvy - \phi(\rmX)\rvw\right)^{\top}\left(\rvy - \phi(\rmX)\rvw\right) + \frac{1}{\tau^{2}} \rvw^\top \rvw + \text{const.} \label{eq:map_quadratic_forms} \\ &= \frac{1}{\sigma^{2}}\left(-2\rvy^{\top}\phi(\rmX)\rvw + \rvw^{\top}\phi(\rmX)^{\top}\phi(\rmX)\rvw + \frac{\sigma^{2}}{\tau^{2}}\rvw^{\top}\rvw\right) + \text{const}. \end{align} \]
We now take the derivative of this objective, scaled by \(\sigma^{2}\):
\[\begin{align} \nabla_{\rvw}\big(-2\sigma^{2}\log p(\rvw \vert \gD)\big) = -2 \phi(\rmX)^{\top}\rvy + 2 \phi(\rmX)^{\top}\phi(\rmX)\rvw + \frac{2\sigma^{2}}{\tau^{2}}\rvw, \end{align} \]
which we set to zero and solve to get
\[\begin{align} \phi(\rmX)^\top\rvy &= \left(\phi(\rmX)^{\top}\phi(\rmX) + \frac{\sigma^{2}}{\tau^{2}}\rmI_{m}\right)\rvw \\ \rvw_{\text{MAP}} &= \left(\phi(\rmX)^{\top}\phi(\rmX) + \frac{\sigma^{2}}{\tau^{2}}\rmI_{m}\right)^{-1}\phi(\rmX)^\top\rvy \label{eq:map_weights}. \end{align} \]
It is common to denote the strength of the weight regularisation the MAP scheme affords us as \(\lambda \doteq \sigma^{2} /\tau^{2}\); this recovers exactly the ridge regression estimator.
Numerically Stable Computation
As before, we can compute \(\rvw_{\text{MAP}}\) in Equation \(\ref{eq:map_weights}\) with a Cholesky decomposition, now of the regularised Gram matrix \(\phi(\rmX)^{\top}\phi(\rmX) + \lambda \rmI_{m}\).
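As a minimal sketch, assuming the noise scale \(\sigma\) is known and picking \(\tau^{2} = 0.1\) (the value used in the results below) for the prior scale:

# MAP weights via a Cholesky decomposition of the regularised Gram matrix
sigma, tau = 0.2, 0.1 ** 0.5          # noise scale (assumed) and tau^2 = 0.1
lam = sigma**2 / tau**2
phi_X = phi(X, m=10)
A = phi_X.T @ phi_X + lam * torch.eye(phi_X.shape[-1])
L = torch.linalg.cholesky(A)
w_map = torch.linalg.solve(L.T, torch.linalg.solve(L, phi_X.T @ y))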
We could also proceed using a QR decomposition, where there is a neat way to use exactly the same routines as in the maximum likelihood case with some simple data augmentation. Define \(\mathbf{0}\) as an \(m \times 1\) matrix of zeroes, and \(\mLambda = \frac{1}{\tau^{2}}\rmI\) as an \(m \times m\) diagonal matrix. Picking up from the quadratic forms in Equation \(\ref{eq:map_quadratic_forms}\), we can write the MAP objective as the usual ML objective on some augmented data \(\tilde{\rmX}\), \(\tilde{\rvy}\):
\[\begin{align} &\frac{1}{\sigma^{2}}(\rvy - \phi(\rmX)\rvw)^{\top}(\rvy - \phi(\rmX)\rvw) + \rvw^{\top}\mLambda\rvw \\ &=\frac{1}{\sigma^{2}}(\rvy - \phi(\rmX)\rvw)^{\top}(\rvy - \phi(\rmX)\rvw) + (\sqrt{\mLambda}\rvw)^{\top}(\sqrt{\mLambda}\rvw) \\ &=\begin{pmatrix}\frac{1}{\sigma}(\rvy - \phi(\rmX)\rvw) \\ -\sqrt{\mLambda}\rvw\end{pmatrix}^{\top}\begin{pmatrix}\frac{1}{\sigma}(\rvy - \phi(\rmX)\rvw) \\ -\sqrt{\mLambda}\rvw\end{pmatrix} \\ &= \left(\begin{pmatrix}\rvy / \sigma \\ \mathbf{0}\end{pmatrix} - \begin{pmatrix}\phi(\rmX)/\sigma \\ \sqrt{\mLambda}\end{pmatrix}\rvw\right)^{\top} \left(\begin{pmatrix}\rvy / \sigma \\ \mathbf{0}\end{pmatrix} - \begin{pmatrix}\phi(\rmX)/\sigma \\ \sqrt{\mLambda}\end{pmatrix}\rvw\right) \\ &\doteq (\tilde{\rvy} - \tilde{\rmX}\rvw)^{\top}(\tilde{\rvy} - \tilde{\rmX}\rvw). \end{align} \]
Hence, we can make the following data augmentations
\[\begin{equation} \label{eq:map_data_augmentation} \tilde{\rmX} \doteq \begin{pmatrix}\phi(\rmX) / \sigma \\ \sqrt{\mLambda}\end{pmatrix} \in \R^{(N+m) \times m},\hspace{2em}\tilde{\rvy} \doteq \begin{pmatrix}\rvy / \sigma \\ \mathbf{0}\end{pmatrix} \in \R^{N+m}, \end{equation} \]
where \(\mLambda = \sqrt{\mLambda}\sqrt{\mLambda}^{\top}\) is the Cholesky decomposition of \(\mLambda\), and find the MAP weights exactly as we would the ML weights, using this augmented data
\[\begin{equation} \label{eq:augmented_map} \rvw_{\text{MAP}} = (\tilde{\rmX}^{\top}\tilde{\rmX})^{-1}\tilde{\rmX}^{\top}\tilde{\rvy}. \end{equation} \]
If we now compute the QR decomposition of the augmented data \(\tilde{\rmX} = \rmQ\rmR\), then
\[(\tilde{\rmX}^{\top}\tilde{\rmX})^{-1} = (\rmR^{\top}\rmQ^{\top}\rmQ\rmR)^{-1} = (\rmR^{\top}\rmR)^{-1} = \rmR^{-1}\rmR^{-\top}, \]
which follows from the orthogonality of \(\rmQ\). Substituting the above into Equation \(\ref{eq:augmented_map}\) gives us
\[\begin{equation} \label{eq:map_qr} \rvw_{\text{MAP}} = \rmR^{-1}\rmR^{-\top}\rmR^{\top}\rmQ^{\top}\tilde{\rvy} = \rmR^{-1}\rmQ^{\top}\tilde{\rvy}. \end{equation} \]
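A sketch of this augmented-data trick, reusing phi_X, sigma and tau from the Cholesky sketch above:

# MAP weights via QR on the augmented data, with sqrt(Lambda) = I / tau
m = phi_X.shape[-1]
X_tilde = torch.cat([phi_X / sigma, torch.eye(m) / tau], dim=0)
y_tilde = torch.cat([y / sigma, torch.zeros(m)])
Q, R = torch.linalg.qr(X_tilde)
w_map = torch.linalg.solve(R, Q.T @ y_tilde)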
Finally, we can also use the reduced SVD approach with our MAP objective, which proceeds similarly to the ML case: using the same decomposition as in Equation \(\ref{eq:reduced_svd}\), we form \(\rmZ = \rmU\rmS\), which is an \(N \times m\) matrix. We can then calculate the MAP parameters as (we can interpret this as replacing the \(m\)-dimensional vectors with \(N\)-dimensional vectors, performing the fit as before, and then transforming the \(N\)-dimensional solution back to \(m\) dimensions by multiplying by \(\rmV\)):
\[\rvw_{\text{MAP}} = \rmV(\rmZ^{\top}\rmZ + \lambda \rmI_{m})^{-1}\rmZ^{\top}\rvy. \]
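In code, we can exploit the diagonal structure to avoid any explicit inversion (same assumptions as the sketches above):

# MAP weights via the reduced SVD, with Z = U S so Z^T Z = diag(S^2)
U, S, VT = torch.linalg.svd(phi_X, full_matrices=False)
w_map = VT.T @ ((S * (U.T @ y)) / (S**2 + lam))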
Results
Repeating the problematic regression from above, with \(N=10\) datapoints, an order \(m=10\) polynomial basis expansion, and \(\tau^{2} = 0.1\), gives far more reasonable results:
The Conjugate Case
After that slightly longer-than-intended introduction, we’re ready to tackle our toy regression problem in a fully Bayesian manner.
To do so, we will require the full posterior over the model’s parameters, found as
\[\begin{equation} \label{eq:weight_posterior} p(\vtheta \vert \rvy, \rmX) = \frac{p(\rvy \vert \rmX, \vtheta)p(\vtheta)}{p(\rvy \vert \rmX)} \end{equation} \]
Posterior Calculation 1: Fixed Likelihood Noise
For now, we will assume that we either know the noise variance or we have a way of finding it—we need only perform inference over the model weights.
Rather than the simple zero-mean form for the prior over the weights used in Equation \(\ref{eq:map_prior}\) previously, let us allow the prior to take on a slightly more general form (we use subscripts to indicate the value of these distribution parameters after having observed \(n\) datapoints; for the prior we use a \(0\) subscript, while for the posterior after having observed \(\gD\) we use an \(N\) subscript):
\[\begin{equation} p(\rvw) = \gN(\rvw; \rvm_{0}, \rmS_{0}), \end{equation} \]
where, for simplicity, we often resort to setting \(\rvm_{0} = \mathbf{0}\) and \(\rmS_{0} = \tau^{2}\rmI\).
The likelihood / observation noise model remains i.i.d.
\[p(\rvy \vert \rmX, \rvw) = \gN(\rvy; \phi(\rmX)\rvw, \sigma^{2}\rmI_{N}). \]
To find the posterior, we can use Bayes rule for linear Gaussian systems, a standard result which I derive in this section of my previous article on Gaussians. This gives us
\[\label{eq:posterior_lgs} \begin{align} p(\rvw \vert \rvy, \rmX) &= \gN(\rvw; \rvm_{N}, \rmS_{N}) \\ \rmS_{N} &= \big(\rmS_{0}^{-1} + \frac{1}{\sigma^{2}}\phi(\rmX)^{\top}\phi(\rmX)\big)^{-1} \\ \rvm_{N} &= \rmS_{N}\big(\rmS_{0}^{-1}\rvm_{0} + \frac{1}{\sigma^{2}}\phi(\rmX)^{\top}\rvy\big) \end{align} \]
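As a sketch, with the common choices \(\rvm_{0} = \mathbf{0}\) and \(\rmS_{0} = \tau^{2}\rmI\), and reusing the illustrative sigma and tau from before:

# Weight posterior for fixed noise variance, with m0 = 0 and S0 = tau^2 I
phi_X = phi(X, m=10)
S0_inv = torch.eye(phi_X.shape[-1]) / tau**2
SN = torch.linalg.inv(S0_inv + phi_X.T @ phi_X / sigma**2)
mN = SN @ (phi_X.T @ y / sigma**2)       # the S0^{-1} m0 term vanishes since m0 = 0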
To find the posterior predictive distribution for some test points \(\rmX_{\star}\), we marginalise out the model weights using the posterior. This can be found by applying the identity for the marginal when applying Bayes rule to a linear Gaussian system (or merely by noting that (a) the linear transformation of a Gaussian, \(\rvy = \rmA\rvx + \rvb\) for \(\rvx \sim \gN(\vmu, \mSigma)\), is also Gaussian, \(\rvy \sim \gN(\rmA\vmu + \rvb, \rmA\mSigma\rmA^{\top})\), and (b) Gaussians are closed under convolution: \(\rvx_{1} + \rvx_{2} \sim \gN(\vmu_{1} + \vmu_{2}, \mSigma_{1} + \mSigma_{2})\) for independent \(\rvx_{1}, \rvx_{2}\))
\[\begin{align} p(\rvy_{\star} \vert \rmX_{\star}, \rvy, \rmX) &= \int p(\rvy_{\star} \vert \rmX_{\star}, \rvw) p(\rvw \vert \rvy, \rmX)d\rvw \\ &= \int \gN(\rvy_{\star}; \phi(\rmX_{\star})\rvw, \sigma^{2}\rmI) \gN(\rvw; \rvm_{N}, \rmS_{N})d\rvw \\ &= \gN\big(\rvy_{\star}; \phi(\rmX_{\star})\rvm_{N}, \sigma^{2}\rmI + \phi(\rmX_{\star})\rmS_{N}\phi(\rmX_{\star})^{\top}\big) \label{eq:posterior_predictive_1} \end{align} \]
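Continuing the sketch, on a hypothetical grid of test inputs X_star:

# Posterior predictive mean and covariance at the test inputs
X_star = torch.linspace(0, 1, 100)[:, None]
phi_star = phi(X_star, m=10)
pred_mean = phi_star @ mN
pred_cov = sigma**2 * torch.eye(len(X_star)) + phi_star @ SN @ phi_star.T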
Results
Since the mean of the posterior predictive is exactly the same as the MAP prediction, the mean prediction line is no different. However, we now get uncertainty estimates, and since in this Gaussian case the predictive distribution is available in closed form, we can evaluate the likelihood of a test point.
Since we now have a posterior distribution over the weights, we can sample sets of weights and plot the resulting function lines:
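For instance, a sketch using torch.distributions:

# Draw weight samples from the posterior and evaluate the induced functions
weight_posterior = torch.distributions.MultivariateNormal(mN, covariance_matrix=SN)
ws = weight_posterior.sample((50,))      # (50, m): one weight vector per function
fs = ws @ phi_star.T                     # (50, 100): sampled function values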
Posterior Calculation 2: Unknown Noise Variance
When the variance \(\sigma^{2}\) of our Gaussian observation noise is unknown, we require a slightly more involved prior over both \(\rvw\) and \(\sigma^{2}\)—in this case the conjugate prior is the normal inverse gamma prior:
\[\begin{align} p(\vtheta) = p(\rvw, \sigma^{2}) &= \text{NIG}(\rvw, \sigma^{2}; \rvm_{0}, \rmS_{0}, a_{0}, b_{0}) \\ &\doteq \gN(\rvw; \rvm_{0}, \sigma^{2}\rmS_{0})\text{IG}(\sigma^{2}; a_{0}, b_{0}) \\ &= \frac{b_{0}^{a_{0}}}{(2\pi)^{m/2}\vert \rmS_{0}\vert^{1/2} \Gamma(a_{0})}(\sigma^{2})^{-(a_{0} + (m/2)+1)} \\ &\hspace{1em}\times \exp\left(-\frac{(\rvw - \rvm_{0})^{\top}\rmS_{0}^{-1}(\rvw - \rvm_{0}) + 2b_{0}}{2\sigma^{2}}\right). \end{align} \]
Recall that \(m\) is just the feature dimension resulting from our basis expansion. With this prior, and the likelihood as it was before: \(p(\rvy \vert \rmX, \vtheta) = \gN(\rvy; \phi(\rmX)\rvw, \sigma^{2}\rmI_{N})\), it can now be shown (Murphy, 2012) that the posterior takes the following form
\[\begin{align} p(\vtheta \vert \gD) = p(\rvw, \sigma^{2} \vert \rvy, \rmX) &= \text{NIG}(\rvw, \sigma^{2}; \rvm_{N}, \rmS_{N}, a_{N}, b_{N}) \\ \rmS_{N} &= \big(\rmS_{0}^{-1} + \phi(\rmX)^{\top}\phi(\rmX)\big)^{-1} \\ \rvm_{N} &= \rmS_{N}(\rmS_{0}^{-1}\rvm_{0} + \phi(\rmX)^{\top}\rvy) \\ a_{N} &= a_{0} + N/2 \\ b_{N} &= b_{0} + \frac{1}{2}(\rvm_{0}^{\top}\rmS_{0}^{-1}\rvm_{0} + \rvy^{\top}\rvy - \rvm_{N}^{\top}\rmS_{N}^{-1}\rvm_{N}). \end{align} \]
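A sketch of these updates, assuming hypothetical prior hyperparameters \(a_{0} = b_{0} = 1\), \(\rvm_{0} = \mathbf{0}\) and \(\rmS_{0} = \tau^{2}\rmI\):

# Normal inverse-gamma posterior updates, with m0 = 0 and S0 = tau^2 I
a0, b0 = 1.0, 1.0                        # hypothetical prior hyperparameters
S0_inv = torch.eye(phi_X.shape[-1]) / tau**2
SN_inv = S0_inv + phi_X.T @ phi_X
SN = torch.linalg.inv(SN_inv)
mN = SN @ (phi_X.T @ y)                  # the S0^{-1} m0 term vanishes for m0 = 0
aN = a0 + len(y) / 2
bN = b0 + 0.5 * (y @ y - mN @ (SN_inv @ mN))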
The posterior marginals are an inverse-gamma distribution (for \(\sigma^{2}\)) and a multivariate Student-T distribution (for \(\rvw\)):
\[\begin{align} p(\sigma^{2} \vert \gD) &= \text{IG}(a_{N}, b_{N}) \\ p(\rvw \vert \gD) &= \gT(\rvm_{N}, \frac{b_{N}}{a_{N}}\rmS_{N}, 2a_{N}). \end{align} \]
Finally the posterior predictive distribution for some test data \(\rmX_{\star}\) is a Student-T distribution:
\[\begin{align} p(\rvy_{\star} \vert \rmX_{\star}, \gD) = \gT(\rvy_{\star}; \phi(\rmX_{\star})\rvm_{N}, \frac{b_{N}}{a_{N}}(\rmI + \phi(\rmX_{\star})\rmS_{N}\phi(\rmX_{\star})^{\top}), 2a_{N}). \end{align} \]
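Since a multivariate Student-T with \(\nu\) degrees of freedom and scale matrix \(\mSigma\) has covariance \(\frac{\nu}{\nu-2}\mSigma\) (for \(\nu > 2\)), a sketch of the predictive moments, reusing the names from above, is:

# Student-T posterior predictive moments at the test inputs
nu = 2 * aN                              # degrees of freedom
pred_mean = phi_star @ mN
scale = (bN / aN) * (torch.eye(len(X_star)) + phi_star @ SN @ phi_star.T)
pred_cov = scale * nu / (nu - 2)         # covariance is only defined for nu > 2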
Results
Calculating the predictive distribution, with a noise prior parametrised by \(a_{0} = 1\), \(b_{0} = 1\), we can see that it now has slightly heavier tails owing to the Student-T posterior predictive, and this results in slightly higher variation in the \(\pm 2\sigma\) confidence intervals in comparison to the previous plot.
Running the regression with more datapoints (\(N=50\)), we can see that the predictive uncertainty quickly drops to a level commensurate with the observed noise level:
Non-Conjugate Priors
Sometimes our modelling problem will bring us to a likelihood that is no longer conjugate to our prior. This might be because a Gaussian observation noise assumption is too unrealistic, or we might be regressing the covariates to discrete observations (i.e. solving a classification task), in which case we maximise a categorical likelihood.
In order to find the posterior in cases when the likelihood is not conjugate to the prior, we must often resort to approximation methods. We give a short overview of some of these here.
Laplace Approximation
Suppose that we write down the posterior over the model parameters \(\vtheta\), given in Equation \(\ref{eq:weight_posterior}\), as follows, making reference to an energy-based model:
\[\begin{equation} p(\vtheta \vert \gD) = \frac{1}{Z}p(\gD \vert \vtheta) p(\vtheta) = \frac{1}{Z}\exp(-\gL(\gD; \vtheta)), \end{equation} \]
where the partition function is the marginal likelihood \(Z \doteq \int p(\gD \vert \vtheta)p(\vtheta)d\vtheta\)—a usually intractable quantity in the non-conjugate setting.
The Laplace approximation, due in this context to MacKay (1992), proceeds by performing a second-order Taylor expansion of the log-joint \(\gL\) around the MAP parameters, \(\vtheta_{\text{MAP}}\):
\[\begin{align} \gL(\gD; \vtheta) \approx &\gL(\gD; \vtheta_{\text{MAP}}) + (\vtheta - \vtheta_{\text{MAP}})^{\top}\underbrace{\nabla_{\vtheta}\gL(\gD; \vtheta)\vert_{\vtheta_{\text{MAP}}}}_{\approx 0} + \nonumber \\ &\frac{1}{2}(\vtheta - \vtheta_{\text{MAP}})^{\top}\big(\nabla_{\vtheta}^{2}\gL(\gD; \vtheta)\vert_{\vtheta_{\text{MAP}}}\big)(\vtheta - \vtheta_{\text{MAP}}). \end{align} \]
In the above, the first-order term disappears since the gradient should be zero at \(\vtheta_{\text{MAP}}\). What is left looks conveniently like the quadratic form of a Gaussian (up to some constant term that doesn’t depend on \(\vtheta\)):
\[\begin{equation} \gL(\gD; \vtheta) \approx \gL(\gD; \vtheta_{\text{MAP}}) + \frac{1}{2}(\vtheta - \vtheta_{\text{MAP}})^{\top}\big(\nabla_{\vtheta}^{2}\gL(\gD; \vtheta)\vert_{\vtheta_{\text{MAP}}}\big)(\vtheta - \vtheta_{\text{MAP}}), \end{equation} \]
and this allows us to make the Laplace approximation to the posterior, writing it as a Gaussian:
\[\begin{align} p(\vtheta \vert \gD) &\approx \gN(\vtheta; \vtheta_{\text{MAP}}, \mSigma) \label{eq:laplace_approx}, \\ \mSigma &\doteq \big(\nabla_{\vtheta}^{2}\gL(\gD; \vtheta)\vert_{\vtheta_{\text{MAP}}}\big)^{-1} \label{eq:laplace_approx_cov}, \end{align} \]
where \(\vtheta_{\text{MAP}} = \rvw_{\text{MAP}}\) as found in Equation \(\ref{eq:map_weights}\) previously, and the covariance of the Gaussian in Equation \(\ref{eq:laplace_approx}\) above is the inverse of the Hessian.
We can also recover an approximation to the marginal likelihood / model evidence, which is useful for model selection (such as the value of \(m\) we should use), as
\[\begin{equation} \label{eq:laplace_model_evidence} p(\gD) = Z \approx \exp\big(-\gL(\gD; \vtheta_{\text{MAP}})\big)(2\pi)^{m/2}(\det \mSigma)^{1/2}. \end{equation} \]
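Here is a minimal sketch for our regression model, reusing w_map, phi_X, sigma and tau from the earlier sketches; note that since our log-joint is exactly quadratic in the weights, the Laplace approximation happens to be exact in this case:

import math

# Negative log-joint L(D; theta); here theta is just the weights w
def neg_log_joint(w):
    log_lik = torch.distributions.Normal(phi_X @ w, sigma).log_prob(y).sum()
    log_prior = torch.distributions.Normal(0.0, tau).log_prob(w).sum()
    return -(log_lik + log_prior)

H = torch.autograd.functional.hessian(neg_log_joint, w_map)  # (m, m) Hessian at the MAP
Sigma = torch.linalg.inv(H)                                  # Laplace posterior covariance
log_Z = (-neg_log_joint(w_map)                               # log evidence approximation
         + 0.5 * m * math.log(2 * math.pi)
         + 0.5 * torch.logdet(Sigma))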
Despite the pleasing conceptual simplicity of this approximation, computing the inverse of the Hessian remains a complicated task. First, the Hessian is not guaranteed to be positive definite, as required for its inverse to serve as a covariance matrix, and it may be ill-conditioned and hence tricky to invert. Further, for models with even modest parameter counts, computing the full Hessian quickly becomes an expensive operation.
Many approaches have been proposed to circumvent this, ranging from diagonal approximations, to more advanced ones such as Kronecker-factored approximate curvature (Martens et al., 2015). See Daxberger et al. (2022), and the references therein, for a more comprehensive overview.
To obtain the predictive distribution using the Laplace approximation to the posterior, \(p(\vtheta \vert \rvy, \rmX)\), we must compute
\[p(\rvy_{\star} \vert \rmX_{\star}, \rvy, \rmX) = \int p(\rvy_{\star} \vert \rmX_{\star}, \vtheta) p(\vtheta \vert \rvy, \rmX)d\vtheta. \]
If the likelihood is Gaussian (or conjugate to the Gaussian), we may proceed as in Equation \(\ref{eq:posterior_predictive_1}\) to compute the integral analytically. However, the point of using the Laplace approximation in the first place is in these non-conjugate settings. For these cases, the simplest resolution is to draw \(S\) samples from our Gaussian posterior \(\{\vtheta^{(s)}\}_{s=1}^{S} \sim p(\vtheta \vert \gD)\) followed by an empirical average of the model predictions.
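A sketch of this Monte Carlo predictive, reusing the Laplace quantities from above:

# Monte Carlo predictive using samples from the Laplace posterior
laplace_post = torch.distributions.MultivariateNormal(w_map, covariance_matrix=Sigma)
ws = laplace_post.sample((100,))          # (100, m) parameter samples
preds = ws @ phi_star.T                   # per-sample function values at X_star
pred_mean = preds.mean(0)
pred_var = preds.var(0) + sigma**2        # add the observation noise back in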
Variational Inference
In light of the intractable posterior distribution \(p(\vtheta \vert \gD)\), as given in Equation \(\ref{eq:weight_posterior}\), variational Bayes is a flexible, optimisation-based approximation scheme.
In particular, we introduce an approximate posterior \(q(\vtheta)\) that is easier to work with (i.e. quick to sample from, with simple likelihood evaluations). We then seek to minimise the difference (usually in KL, though other f-divergences are available) between the approximate and true posterior;
\[q^{\star} = \argmin_{q\in \gQ}\text{D}\big[q \Vert p\big]. \]
Often, rather than optimising over the space of (density) functions \(q\), we work with a parametric function with variational parameters \(\vphi\). As a trivial example, we could choose \(q_{\vphi}(\vtheta)\) to be a Gaussian \(\gN(\vtheta; \vmu, \mSigma)\), in which case the variational parameters would be the mean and covariance of this Gaussian, \(\vphi = \{\vmu, \mSigma\}\).
The objective we optimise to find \(\vphi\) is a lower bound on the model evidence. For a maximisation, we begin by writing the negative of the KL divergence (defined as \(\KL[q\Vert p] = \E_{q}\log(q/p)\) or, when negated, \(-\KL[q\Vert p] = \E_{q}\log(p/q)\)) between the approximate posterior and the true posterior:
\[\begin{align} - \KL\big[q_{\vphi}(\vtheta) \Vert p(\vtheta \vert \rvy, \rmX)\big] &= \int q_{\vphi}(\vtheta) \log \frac{p(\vtheta \vert \rvy, \rmX)}{q_{\vphi}(\vtheta)} d\vtheta \\ &= \int q_{\vphi}(\vtheta) \log \frac{p(\rvy \vert \vtheta, \rmX)p(\vtheta)}{p(\rvy \vert \rmX)q_{\vphi}(\vtheta)}d\vtheta \\ &= \int q_{\vphi}(\vtheta)\left(\log p(\rvy \vert \vtheta, \rmX) + \log \frac{p(\vtheta)}{q_{\vphi}(\vtheta)} - \log p(\rvy \vert \rmX)\right) d\vtheta \\ &= \int q_{\vphi}(\vtheta) \log p(\rvy \vert \vtheta, \rmX)d\vtheta - \KL[q_{\vphi}(\vtheta)\Vert p(\vtheta)] - \log p(\rvy \vert \rmX). \end{align} \]
Rearranging, and owing to the non-negative nature of the KL divergence, we can see how the above brings about a lower bound on the (log) model evidence:
\[\begin{equation} \underbrace{\log p(\rvy \vert \rmX)}_{\text{log evidence}} - \underbrace{\KL[q_{\vphi}(\vtheta) \Vert p(\vtheta \vert \rvy, \rmX)]}_{\ge 0} = \underbrace{\E_{q_{\vphi}(\vtheta)}[\log p(\rvy \vert \vtheta, \rmX)] - \KL[q_{\vphi}(\vtheta) \Vert p(\vtheta)]}_{\doteq \text{ELBO}}, \end{equation} \]
which we refer to as the evidence lower bound (ELBO) and denote \(\gL(\vphi)\).
\[\log p(\gD) = \log p(\rvy \vert \rmX) \ge \E_{q_{\vphi}(\vtheta)}\left[\log p(\rvy \vert \vtheta, \rmX) + \log p(\vtheta) - \log q_{\vphi}(\vtheta) \right] = \gL(\vphi) \]
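As a minimal sketch of optimising this bound for our model, with a mean-field Gaussian \(q_{\vphi}\) and a single-sample reparameterised estimate of the expected log-likelihood (all names reused from the earlier sketches):

# Black-box VI: maximise a single-sample estimate of the ELBO
m = phi_X.shape[-1]
mu = torch.zeros(m, requires_grad=True)        # variational mean
log_std = torch.zeros(m, requires_grad=True)   # variational log standard deviation
opt = torch.optim.Adam([mu, log_std], lr=1e-2)

for step in range(2000):
    opt.zero_grad()
    w = mu + log_std.exp() * torch.randn(m)    # reparameterised sample w ~ q_phi
    log_lik = torch.distributions.Normal(phi_X @ w, sigma).log_prob(y).sum()
    q = torch.distributions.Normal(mu, log_std.exp())
    p = torch.distributions.Normal(0.0, tau)   # prior p(w) = N(0, tau^2 I)
    kl = torch.distributions.kl_divergence(q, p).sum()
    (kl - log_lik).backward()                  # negative (single-sample) ELBO
    opt.step()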
To make predictions, we obtain the predictive distribution by marginalising over the parameters using the approximate posterior distribution with optimised variational parameters, \(q_{\vphi}(\vtheta)\). As above, if our choice of approximate posterior is conjugate to our likelihood, then we may proceed analytically; otherwise we fall back to a Monte Carlo approximation of this integral.
MCMC
While variational inference is generally a fast and computationally efficient inference method, our choice of the approximating posterior distribution \(q_{\vphi}(\vtheta)\) may be too simple and hence introduce bias. Markov chain Monte Carlo (MCMC) methods provide a flexible alternative method for making predictions with uncertainty, where we sample weights from the posterior.
The main idea behind a Monte Carlo simulation is to draw a set of i.i.d. samples \(\{\vtheta^{(s)}\}_{s=1}^{S} \sim p(\vtheta \vert \gD)\) from an intractable target density, which may then be used to approximate the target density with an empirical point-mass function
\[p(\vtheta \vert \gD) \approx p_{S}(\vtheta \vert \gD) \doteq \frac{1}{S} \sum_{s=1}^{S}\delta(\vtheta - \vtheta^{(s)}), \]
where \(\delta(\cdot)\) is the Dirac delta. For Bayesian linear regression, we use these samples to calculate the empirical moments (mean, variance, etc.) of the resulting function predictions \(\phi(\rmX)\rvw^{(s)}\), allowing us to make predictions without needing the intractable normalisation constant / marginal likelihood \(p(\gD) = \int p(\vtheta, \gD)d\vtheta\) that appears in Equation \(\ref{eq:weight_posterior}\).
The Markov chain component allows an MCMC algorithm to better explore the distribution of interest, spending more time in the most important regions. Without it, samples would be drawn blindly, wasting effort in regions of negligible probability mass.
Metropolis Hastings
The Metropolis-Hastings (MH) algorithm is a popular, if simple, MCMC algorithm, which many modern algorithms build on top of.
We begin by defining a proposal distribution \(q(\vtheta' \vert \vtheta)\), which proposes a new parameter value \(\vtheta'\) to visit given the current estimate \(\vtheta\). We select \(q\) to be easy to sample from and such that it covers the support of the target.
The MH algorithm then proceeds by:
- Starting \(\vtheta\) at a random point in the parameter space
- Sampling new states from the proposal \(q(\vtheta' \vert \vtheta)\)
- Accepting the move to the newly proposed point \(\vtheta'\) with probability given by the acceptance probability, i.e. the unnormalised density ratio:
\[\gA(\vtheta, \vtheta') = \min\left(1, \frac{p(\vtheta' \vert \gD)}{p(\vtheta \vert \gD)}\right) \]
if the proposal is symmetric (i.e. \(q(\vtheta' \vert \vtheta) = q(\vtheta \vert \vtheta')\)), or applying a Hastings correction to the acceptance probability if the proposal is asymmetric:
\[\begin{align*} \gA(\vtheta, \vtheta') &= \min(1, \alpha) \\ \alpha &= \frac{p(\vtheta' \vert \gD)q(\vtheta \vert \vtheta')}{p(\vtheta \vert \gD)q(\vtheta' \vert \vtheta)}. \end{align*} \]
This works since the posteriors in the acceptance probability need only be known up to proportionality (i.e. omitting the normalising constant): the constant cancels out in the ratio.
Here is a very basic yet functional Metropolis-Hastings sampler in PyTorch:
from __future__ import annotations  # keep the Tensor["B", "m"] shape annotations unevaluated

from typing import Callable, Generator

import torch as t
from torch import Tensor


def RWMH(
    logpdf: Callable[[Tensor["B", "m"]], Tensor["B"]],
    initial_pos: Tensor["B", "m"],
    sigma: float = 0.1,
) -> Generator[Tensor["B", "m"], None, None]:
    # initialise position and log-prob
    pos, log_prob = initial_pos, logpdf(initial_pos)
    while True:
        # Calculate the proposal (using a symmetric Gaussian q)
        prop = pos + t.randn_like(pos) * sigma
        prop_log_prob = logpdf(prop)
        # Accept with probability min(1, p(prop) / p(pos)), computed in log space
        log_unif = t.rand_like(prop[..., 0]).log()
        accept = log_unif < prop_log_prob - log_prob
        # Update the accepted chains; keep the previous state elsewhere
        pos = t.where(accept[..., None], prop, pos)
        log_prob = t.where(accept, prop_log_prob, log_prob)
        yield pos
Note that in the above, the elusively named B (batch) dimension may be multi-dimensional, allowing us to implement highly parallel sampling procedures on modern GPUs for a batch of points by expanding along the 0th dimension. A simple piece of code to use this is the following (for real implementations we would also include a burn-in period, as well as track metrics to monitor how well the chains are mixing):
# target_dist, B, m, num_samples, device and dtype are assumed defined elsewhere
def logpdf(pos: Tensor["B", "m"]) -> Tensor["B"]:
    # Unnormalised log-density of the target, summed over the parameter dimension
    return target_dist.log_prob(pos).sum(-1)

pos = t.randn((B, m))                              # random initial chain positions
samples = t.empty((num_samples, B, m)).to(device, dtype)
sampler = RWMH(logpdf, pos, sigma=0.1)
for i, pos in zip(range(num_samples), sampler):
    samples[i] = pos
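Finally, for Bayesian linear regression specifically, if the chains target the weight posterior then the stored samples give us Monte Carlo predictions; a sketch reusing the hypothetical phi_star from earlier:

# Monte Carlo predictive moments from the stored MCMC samples
ws = samples.reshape(-1, samples.shape[-1])    # flatten the sample and batch dims
fs = ws @ phi_star.T                           # function values per sampled w
pred_mean, pred_var = fs.mean(0), fs.var(0)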