Bayesian neural network (BNN) priors have become a contentious subject lately.

In BNNs, we first pick a prior \(\P(\rvw)\) to describe the network weights, having not yet seen any data, and proceed to compute the weight posterior as the data comes in: \(\P(\rvw \vert \gD) = \P(\gD \vert \rvw)\P(\rvw)/\P(\gD)\).

In practice, this choice is almost always an isotropic Gaussian \(\gN(\mathbf{0}, \sigma^{2}\rmI)\), centred at 0.

If you ask most Bayesians whether this zero-mean isotropic Gaussian prior accurately reflects their beliefs about the BNN weight distribution, or whether it yields the best generalisation, or the most accurate uncertainty estimates, they would probably say no. This is somewhat troubling, and keeps Bayesians up at night.

Signs that something might be going wrong abound: from the often worse empirical
performance of BNNs compared to identical networks trained with SGD
(Heek *et al.*, 2019), not to mention a ballooning of the computational cost in
return for that worse performance, to the *cold posterior effect*
(Wenzel *et al.*, 2020; Aitchison, 2020), which effectively discounts the
influence of the prior in return for improved performance, and several others
(Papamarkou *et al.*, 2024). The cold posterior effect in particular is very
troubling for the Bayesian method: if the generative model is accurate, then
Bayesian inference is optimal, and any ad-hoc changes to the posterior should
harm performance, not improve it (Gelman *et al.*, 2003).

To illustrate some of these issues slightly more concretely, consider
variational inference, a popular approach to approximate the intractable
posterior distribution of a finite-width BNN
(Blundell *et al.*, 2015). Variational approaches minimise an
objective of the form

\[\begin{equation} \gF_{\mathrm{VFE}} = \underbrace{-\E_{\Q(\rvw)}\left[\log \P(\rvy \vert \rmX, \rvw)\right]}_{\text{negative expected log-likelihood}} + \underbrace{\KL\left[\Q(\rvw) \Vert \P(\rvw)\right]}_{\text{complexity penalty}}. \end{equation} \]

However, Trippe *et al.* (2018) suggest that
minimising the KL complexity penalty above—the difficulty of which is
in large part determined by our choice of prior \(\P(\rvw)\)—can lead to a
pathological behaviour which they term *overpruning*. By setting the final
hidden-to-output weights of an output neuron to \(0\), we can effectively cut off
the influence of all the upstream weights in the network on the output
prediction, hence turning these upstream weights into parameters we are free to
move around without affecting the output. For wider overparametrised networks
(width being determined by the relative width of the hidden layers to the output),
there will be more such weights available. The network can then minimise the
KL term by ensuring that lots of the weights match the prior exactly,
while simultaneously ensuring that the resulting hidden activations aren’t used
in the output by pruning them at the last layer so as to not mess up the
predictions. The result is that the active, un-pruned network may be much
smaller than expected, while the pruned weights, whose only function is to keep
the KL small, end up costing us lots of memory and compute to maintain.
Further, Coker *et al.* (2022) show that, under certain
conditions in wide BNNs with isotropic priors, the variational posterior
predictive reverts to the prior predictive: effectively ignoring the data.
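To make the overpruning incentive concrete, here is a minimal sketch of the KL term for a mean-field Gaussian posterior against a Gaussian prior. The toy numbers and the function name `kl_diag_gaussians` are my own illustrative assumptions, not taken from Trippe *et al.* (2018); the point is only that weights whose posterior matches the prior exactly contribute nothing to the complexity penalty.

```python
import math

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL[N(mu_q, diag(var_q)) || N(mu_p, diag(var_p))], summed over dimensions."""
    total = 0.0
    for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p):
        total += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return total

# Hypothetical toy layer: 4 upstream weights under a zero-mean, unit-variance prior.
prior_mu, prior_var = [0.0] * 4, [1.0] * 4

# "Active" weights: the posterior has moved away from the prior and sharpened.
active_kl = kl_diag_gaussians([0.8, -1.2, 0.5, 1.5], [0.05] * 4, prior_mu, prior_var)

# "Pruned" weights: the posterior equals the prior, so the penalty vanishes.
pruned_kl = kl_diag_gaussians(prior_mu, prior_var, prior_mu, prior_var)

print(f"active KL: {active_kl:.3f}, pruned KL: {pruned_kl:.3f}")
```

If the output layer ignores the corresponding hidden units, the second configuration costs nothing in the likelihood term either, which is exactly the trade the optimiser is tempted to make.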

Many of the factors causing the problems listed above conspire to give a learning objective where the optimal solution is often to collapse the posterior to the prior. This is clearly bad: it results in unnecessarily expensive and wasteful networks with high bias, which underfit the data.

## Centring the Prior on the Weight Initialisation

The idea I’d like to moot in this blog post is that the weight prior should be centred at the neural network weight initialisation locations, not \(\mathbf{0}\).

While this certainly wouldn’t be a panacea to solve all the issues listed above, it might be a first step in the right direction. This idea isn’t entirely new either, having already been used to show that such an approach improves PAC-Bayes generalisation bounds (Pitas, 2020).

Why might this be a good idea? We know that the weights aren’t initialised at zero, and for the most part they don’t seem to end up around zero after training, so centring the prior at \(0\) may be ill advised.

We also know that during training, particularly as networks grow wider, the
weights may not end up very far from their initialisation, and moreover the
weight progression follows simple linear trajectories
(Jacot *et al.*, 2018). For instance, the *video* below, taken
from this excellent blog post, shows the
evolution of a single \(m\times m\) weight matrix during MLP training, for
successively wider networks \(m \in \{10, 100, 1000\}\):

As you can see, the weight changes in the visualisation are almost imperceptible for all but the smallest \(10 \times 10\) example.

This new *weight-init conditioned* prior is conceptually simple: supposing we
have freshly sampled a set of neural network weights \(\rvw_{0}\) according
to the best practices in deep learning theory
(Glorot *et al.*, 2010; He *et al.*, 2015),
we could treat these as the mean of our prior and
use \(\gN(\rvw_{0}, \mSigma_{\mathrm{prior}})\), for some prior covariance which
we might simply set to \(\mSigma_{\text{prior}} \doteq \sigma_{\text{prior}}^{2}\rmI\) to keep things cheap and simple for now. (I
defer the choice of more sensible / sophisticated covariance terms to other
work; there are certainly many costs and downsides to the mean-field /
independence assumption, but we won’t deal with them here.)

This prior is valid in the sense that it doesn’t make use of any of the data; however, it is still odd for a prior to be random. Perhaps we ought to marginalise over that source of randomness? In this case, we’d end up back at the standard zero-mean prior. We’ll investigate these issues below.

## Use in Variational Inference

Let’s drop this prior into a simple variational inference scheme. Denote the actual, vectorised neural network parameters as \(\rvw\), and the random initialisation as \(\rvw_{0}\).

Further, let \(\rmX\) be all the inputs, and \(\rvy\) all the outputs in our dataset. We will use the following model:

\[\begin{align} \P(\rvw) &= \gN(\rvw; \mathbf{0}, \sigma_{\mathrm{prior}}^2\rmI) \\ \P(\rvw_{0}) &= \gN(\rvw_{0}; \mathbf{0}, (1-\gamma^{2})\sigma^{2}_{\mathrm{prior}}\rmI) \\ \P(\rvw \vert \rvw_{0}) &= \gN(\rvw; \rvw_{0}, \gamma^{2}\sigma^{2}_{\mathrm{prior}}\rmI), \end{align} \]

where \(0 < \gamma^{2} < 1\). This gives us a prior which is conditioned on the weight initialisation, \(\P(\rvw \vert \rvw_{0})\), and we can use variational inference to learn the weight posterior \(\P(\rvw \vert \rvw_{0}, \rmX, \rvy)\).
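As a quick sanity check on this model, the sketch below samples a single scalar weight hierarchically and confirms empirically that marginalising over \(\rvw_{0}\) recovers the zero-mean prior variance, since \((1-\gamma^{2})\sigma^{2}_{\mathrm{prior}} + \gamma^{2}\sigma^{2}_{\mathrm{prior}} = \sigma^{2}_{\mathrm{prior}}\). The values of `gamma2` and `sigma2_prior`, the sample count and the helper name `sample_weight` are arbitrary choices of mine for illustration.

```python
import random
import statistics

def sample_weight(rng, gamma2, sigma2_prior):
    """Draw (w0, w) for one scalar weight from the hierarchical model."""
    w0 = rng.gauss(0.0, ((1.0 - gamma2) * sigma2_prior) ** 0.5)  # P(w0)
    w = rng.gauss(w0, (gamma2 * sigma2_prior) ** 0.5)            # P(w | w0)
    return w0, w

rng = random.Random(0)
gamma2, sigma2_prior = 0.25, 2.0  # illustrative values only
draws = [sample_weight(rng, gamma2, sigma2_prior) for _ in range(50_000)]

init_var = statistics.pvariance([w0 for w0, _ in draws])
marginal_var = statistics.pvariance([w for _, w in draws])

print(f"Var[w0] ~ {init_var:.3f}  (target {(1 - gamma2) * sigma2_prior})")
print(f"Var[w]  ~ {marginal_var:.3f}  (target {sigma2_prior})")
```

So \(\gamma^{2}\) simply controls how the total prior variance is split between the randomness of the initialisation and the spread around it.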

Is this model valid, and can we simply proceed to use it in a standard VI scheme?

**Theorem 2.1.** *Marginalising over the weight initialisation distribution,
the variational objective formed using the weight-init conditioned prior
\(\P(\rvw \vert \rvw_{0})\) still forms a lower-bound on the marginal
likelihood.*

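*Proof:* let us denote the new ELBO as \(\gL(\rvw_{0})\), to make the
initialisation-conditioning explicit. For any fixed \(\rvw_{0}\), the
conditioned prior \(\P(\rvw \vert \rvw_{0})\) is an ordinary prior, so the
standard ELBO argument gives a lower bound on the conditional log marginal likelihood: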

\[\begin{equation} \label{eq:cond-elbo} \log \P(\rvy \vert \rmX, \rvw_{0}) \ge \gL(\rvw_{0}). \end{equation} \]

By marginalising out the random weight initialisations, we can write the standard marginal likelihood as:

\[\log \P(\rvy \vert \rmX) = \log \int \P(\rvy \vert \rmX, \rvw_{0})\P(\rvw_{0})d\rvw_{0}. \]

Now, substituting \(\P(\rvy \vert \rmX, \rvw_{0}) \ge \exp\left(\gL(\rvw_{0})\right)\) from the inequality in Equation \(\ref{eq:cond-elbo}\), we can see that

\[\begin{align} \log \P(\rvy \vert \rmX) &\ge \log \int \exp(\gL(\rvw_{0}))\P(\rvw_{0}) d\rvw_{0} \\ &= \log \E_{\rvw_{0}}\left[\exp(\gL(\rvw_{0}))\right] \\ &\ge \E_{\rvw_{0}}\left[\gL(\rvw_{0})\right], \end{align} \]

where we have applied Jensen’s inequality on the final line.

\(\square\)

In other words, yes, by marginalising out the weight initialisation, the new prior still yields a lower bound on the model evidence.
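The only non-trivial step in the proof is the final Jensen inequality, \(\log \E[\exp(\gL)] \ge \E[\gL]\). Here is a small numerical illustration of that step using made-up stand-in ELBO values (one per hypothetical initialisation); these are not real ELBOs from any network, and `log_mean_exp` is just a helper I wrote for the example.

```python
import math
import random

def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

rng = random.Random(0)
# Stand-in ELBO values L(w0), one per sampled initialisation (made-up numbers).
elbos = [rng.gauss(-120.0, 5.0) for _ in range(1000)]

log_expected_exp = log_mean_exp(elbos)   # log E[exp(L(w0))]
expected_elbo = sum(elbos) / len(elbos)  # E[L(w0)]

print(f"log E[exp(L)] = {log_expected_exp:.2f} >= E[L] = {expected_elbo:.2f}")
```

The gap between the two grows with the spread of \(\gL(\rvw_{0})\) across initialisations, so the averaged bound is loosest when the ELBO is very sensitive to the particular draw of \(\rvw_{0}\).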

We can further optimise the hyperparameters using this ELBO. The two main ones are \(\gamma^{2}\) and the prior variance \(\sigma^{2}_{\mathrm{prior}}\). To do this optimisation, we can write the initial weights using the reparametrisation trick:

\[\begin{equation} \rvw_{0} = \sqrt{(1 - \gamma^{2})\sigma^{2}_{\mathrm{prior}}}\,\boldsymbol{\xi}, \qquad \boldsymbol{\xi} \sim \gN(\mathbf{0}, \rmI), \end{equation} \]

where we sample \(\boldsymbol{\xi}\) once during network initialisation. We can then learn the hyperparameters of the weight initialisation distribution, \(\gamma^{2}\) and \(\sigma^{2}_{\mathrm{prior}}\), using SGD on the normal ELBO objective, by merely including them in a parameter group of our optimiser.
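A rough sketch of this reparametrisation in plain Python follows. The sigmoid and softplus maps used to keep \(0 < \gamma^{2} < 1\) and the variance positive are an assumption on my part (one common way to handle such constraints), as are the names `constrain` and `init_weights`; in practice the same computation would live inside an autodiff framework so that gradients reach the raw hyperparameters.

```python
import math
import random

def constrain(raw_gamma, raw_sigma):
    """Map unconstrained values to valid hyperparameters (assumed parametrisation)."""
    gamma2 = 1.0 / (1.0 + math.exp(-raw_gamma))     # sigmoid -> (0, 1)
    sigma2_prior = math.log1p(math.exp(raw_sigma))  # softplus -> (0, inf)
    return gamma2, sigma2_prior

def init_weights(xi, gamma2, sigma2_prior):
    """w0 = sqrt((1 - gamma^2) * sigma^2) * xi, with the noise xi held fixed."""
    scale = math.sqrt((1.0 - gamma2) * sigma2_prior)
    return [scale * x for x in xi]

rng = random.Random(0)
xi = [rng.gauss(0.0, 1.0) for _ in range(5)]  # sampled once, then frozen

gamma2, sigma2_prior = constrain(raw_gamma=0.0, raw_sigma=1.0)
w0 = init_weights(xi, gamma2, sigma2_prior)

# Changing a hyperparameter moves w0 deterministically: 4x the variance doubles w0.
w0_wider = init_weights(xi, gamma2, 4.0 * sigma2_prior)
```

Because \(\boldsymbol{\xi}\) is frozen, \(\rvw_{0}\) is a smooth function of the two hyperparameters alone, which is what lets them be updated alongside the variational parameters.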

### Some Empirical Results

Loosely inspired by the setups of
Gal *et al.* (2016) and Trippe *et al.* (2018),
the following plots are from a 3-layer tanh fully-connected network trained
with AdamW on the Boston UCI dataset.

As a sanity check, the following RMSE results were achieved on the UCI datasets, roughly in line with what we would expect:

| method | boston | concrete | energy | kin8nm | naval | power | protein | wine | yacht |
|---|---|---|---|---|---|---|---|---|---|
| Weight-Init Prior | 2.84 | 4.72 | 0.47 | 0.06 | 0.00 | 3.95 | 3.85 | 0.68 | 0.84 |
| Dropout (Gal et al., 2016) | 2.90 | 4.82 | 0.54 | 0.08 | 0.00 | 4.01 | 4.27 | 0.62 | 0.67 |
| DGP (Salimbeni et al., 2017) | 2.92 | 5.65 | 0.47 | 0.06 | 0.00 | 3.68 | 3.72 | 0.63 | — |

Now taking just one split of the Boston dataset, we observe the following empirical results.

#### Difference in KL

One of the motivations for using this prior, both from the PAC-Bayes
theory (Pitas, 2020) and the overpruning
results of Trippe *et al.* (2018), was that it would
lead to a lower KL divergence.

Plotting the KL divergence term in the variational objective for both a standard zero-mean isotropic Gaussian prior and our weight-init conditioned prior shows that the KL divergence is indeed lower:

#### Weights Deviate Less Far From the Prior

Another property that we might expect is that, on average, the mean of the
weight posterior after training is closer to the weight-init prior than the
zero-mean prior. The following is a *histogram* (plotted with lines rather
than bars) showing that, indeed, the final weights tend to
deviate less far from the prior when using the weight-init prior as opposed to
the zero-mean prior.

#### Training and Validation MSE Improves

This plot is straightforward; the MSE loss when using the weight-init
conditioned prior is consistently lower than when using the zero-mean prior.
The difference is larger between the training MSE lines than it is for the
validation MSE lines, yet it is maintained nonetheless. (The wiggles in the
training lines are due to the minibatches; the validation points are averaged
across all minibatches.)

#### Effect on the Cold Posterior Effect

Finally, we can have a go at tempering the posterior (Aitchison, 2020),

\[\begin{equation} \log \P_{\text{tempered}}(\rvw \vert \rmX, \rvy) = \frac{1}{\lambda} \log \P(\rmX, \rvy \vert \rvw) + \log \P(\rvw) + \text{const}, \end{equation} \]

which, when used in the VI scheme and after rescaling the objective by \(\lambda\), reduces to discounting the effect of the KL

\[\begin{equation} \gL = \E_{\Q(\rvw)}\left[\log\P(\rmX, \rvy \vert \rvw)\right] - \lambda \KL\left[\Q(\rvw) \Vert \P(\rvw)\right]. \end{equation} \]

Setting different values of \(\lambda\) (denoted `t` in the legend below) yields:

Perhaps the above isn’t surprising; if the KL term is smaller to begin with (relative to the variational log-likelihood term), then scaling it up and down should have a less pronounced effect.
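A back-of-the-envelope version of this argument: for a fixed variational posterior, the tempered objective changes with \(\lambda\) at a rate equal to the KL itself. The numbers below are invented purely for illustration, and the sketch deliberately ignores that in a real experiment \(\Q(\rvw)\) is re-optimised for every \(\lambda\).

```python
def tempered_elbo(expected_log_lik, kl, lam):
    """Tempered objective: E_q[log-likelihood] - lambda * KL[q || p]."""
    return expected_log_lik - lam * kl

lams = [0.1, 0.5, 1.0]
expected_log_lik = -50.0                        # made-up value, held fixed
kls = {"zero-mean": 40.0, "weight-init": 8.0}   # made-up KL magnitudes

spread = {}
for name, kl in kls.items():
    values = [tempered_elbo(expected_log_lik, kl, lam) for lam in lams]
    spread[name] = max(values) - min(values)    # how much tempering moves the objective

print(spread)
```

With a smaller KL, the same sweep over \(\lambda\) moves the objective by proportionally less, which matches the muted effect of tempering seen with the weight-init prior.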

### Discussion

Lest the above paint an overly optimistic picture: it turns out that this prior is not in fact very effective.

To see why, consider that the effectiveness of this scheme depends on selecting just the right prior variance. If it is too broad, then any benefits from the non-zero prior disappear as the effect of the prior is dominated by its scale rather than its location. On the other hand, if we make it very narrow, then the KL term can be driven to be very small indeed; however, this comes at the cost of using a needlessly restrictive prior, which applies extremely strong regularisation towards the network initialisation, and hurts performance.
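The first half of this argument can be checked directly. For a mean-field Gaussian posterior and two isotropic priors that share a variance but differ in location, the difference in KL is \(\left(\lVert\boldsymbol{\mu}\rVert^{2} - \lVert\boldsymbol{\mu} - \rvw_{0}\rVert^{2}\right) / 2\sigma^{2}_{\mathrm{prior}}\), which shrinks as the prior broadens. The sketch below uses made-up weights and a posterior mean that has stayed near its initialisation; none of these numbers come from the experiments, and `kl_isotropic_prior` is a helper named for this example only.

```python
import math

def kl_isotropic_prior(mu_q, var_q, mu_p, var_p):
    """KL[N(mu_q, diag(var_q)) || N(mu_p, var_p * I)]."""
    return sum(
        0.5 * (math.log(var_p / vq) + (vq + (mq - mp) ** 2) / var_p - 1.0)
        for mq, vq, mp in zip(mu_q, var_q, mu_p)
    )

w0 = [0.6, -0.4, 0.9, -0.7]        # made-up initialisation
mu_q = [0.65, -0.35, 0.8, -0.75]   # posterior mean that stayed close to w0
var_q = [0.01] * 4
zeros = [0.0] * 4

gaps = {}
for prior_std in (0.1, 1.0, 10.0):
    var_p = prior_std ** 2
    kl_zero = kl_isotropic_prior(mu_q, var_q, zeros, var_p)  # zero-mean prior
    kl_init = kl_isotropic_prior(mu_q, var_q, w0, var_p)     # weight-init prior
    gaps[prior_std] = kl_zero - kl_init                      # KL saved by re-centring

print(gaps)
```

The saving from re-centring falls off as \(1/\sigma^{2}_{\mathrm{prior}}\), so it only becomes large in the narrow-prior regime where the regularisation towards \(\rvw_{0}\) is also at its strongest.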

The plot below shows the effect of varying the prior variance: lighter lines correspond to a narrower prior, while darker lines correspond to a broader / less informative prior. The red lines are for the weight-init conditioned prior, while the blue lines are for the normal zero-mean prior.

As you can see, while the absolute performance, measured in MSE, of the weight-init prior is generally better than the zero-mean prior, we can fairly directly control the size of the advantage by modulating the prior standard deviation. A smaller prior variance corresponds to a larger advantage, at the cost of worse absolute MSE performance and severely underestimated predictive uncertainty (not plotted).

With a very narrow prior centred at a given network initialisation, one might
merely view this scheme as a single element of an ensemble—it would be
vastly simpler, and probably less costly, to just use a deep ensemble
(Lakshminarayanan *et al.*, 2017), or Monte-Carlo dropout
(Gal *et al.*, 2016) which purports to implicitly create
an ensemble.

On the other hand, perhaps our heavy-handed treatment of the prior covariance—not only the independence assumption, but also the manual initialisation without much thought—is at fault. Future work may fix this by moving beyond the independence assumption, or identifying dimensions which ought to be broader or narrower as a function of the initialisation or training dynamics.

Regardless, the somewhat degenerate behaviour exhibited above leaves doubts as
to whether this prior is useful at all. I would advise against using it in its
current form, although with some work on the covariance it might be salvaged.
Moreover, the investigation above only considered VI; *other Bayesian
approximations are available*, and there is a slim chance that with MCMC,
SGLD or a Laplace approximation the results might work out differently.