In machine learning, methods that use the second derivative, or *Hessian*, of
some function approximator are used far less commonly than those that use only
first derivatives, or the *Jacobian*.

This curvature information (and related quantities such as the Generalised
Gauss-Newton or Fisher Information Matrix) has many useful applications, ranging
from faster or more ‘stable’ optimisation to uncertainty quantification, along
with several others that we will touch on below.

Despite these appealing properties, the additional cost of calculating the
curvature (or, indeed, the complexity incurred while implementing less costly
schemes) has largely kept second-order methods from being adopted as broadly as
first-order methods.

Nonetheless, there are several good approximations to the Hessian that can
drastically reduce the computational and memory cost of working with these
methods. In this article, we will focus on Kronecker-Factored Approximate
Curvature (K-FAC) in particular
(Martens *et al.*, 2015), and I hope to convince you that it
isn’t a particularly complicated method to implement (despite an apparent
reputation to the contrary).

We will begin by motivating the use of the Hessian from the point of view of a simple gradient-based optimiser in Section 1, as well as proximal optimisation in Section 2. We then look at K-FAC in Section 3.

As a final introductory note, I must mention that much of what follows draws on material in Roger Grosse’s excellent set of notes on neural net training dynamics, which you should certainly refer to if you wish to go into more depth on any of the topics below.

## Second-Order Methods and Optimisation

Consider the following, fairly standard machine learning setup: we have some data distribution of interest \(\P_{\text{data}}\), and a training dataset that we have sampled from it \(\gD = \{(\rvx_{i}, \rvt_{i})\}_{i=1}^{N}\), consisting of inputs \(\rvx\) and targets \(\rvt\) (where perhaps both the inputs and targets have been corrupted by observation noise).

We also pick a scalar loss function \(\ell(\rvy, \rvt)\) which tells us how
unhappy we are with the prediction \(\rvy = f(\rvx; \vtheta)\) made by our
parametric function approximator \(f\), which has \(n\) parameters
\(\vtheta \in \R^{n}\). In other words, the loss measures the difference between
the prediction and the target, under some appropriate distance metric (or
likelihood, in the probabilistic view).

Our true objective is to minimise the generalisation loss, or *risk*, with
respect to the network parameters, on data sampled from our distribution of
interest (which may be non-stationary or drift over time):

\[\begin{equation} \label{eq:risk} \gR(\vtheta) = \E_{(\rvx, \rvt) \sim \P_{\text{data}}}\left[\ell\big(f(\rvx; \vtheta), \rvt\big)\right]. \end{equation} \]

However, given that we only have our finite data sample, we instead content
ourselves with minimising the *empirical* risk \(\gJ\) over the training
dataset:

\[\begin{equation} \label{eq:empirical_risk} \gJ(\vtheta; \gD) = \frac{1}{N}\sum^{N}_{i=1} \ell\big(f(\rvx_{i}; \vtheta), \rvt_{i}\big). \end{equation} \]

In practice, the vast majority of procedures used in machine learning to
minimise objectives of the above form are based on *first-order* optimisation
methods, such as stochastic gradient descent (SGD). These make use of the
Jacobian (i.e. first derivative) of the objective, \(\nabla\gJ(\vtheta)\), to
perform updates to the weights \(\vtheta\). (Or, more realistically, a sequence
of vector-Jacobian products such that we never realise the full Jacobian matrix;
see my previous article on the topic for more on this.)

The most common (and better-performing) variants of SGD, such as Adam
(Kingma *et al.*, 2015), AdamW
(Loshchilov *et al.*, 2018) and numerous others, have been
responsible for most of the breakthrough advances in deep learning over the
past decade (Krizhevsky *et al.*, 2012; Silver *et al.*, 2016), yet all of them
remain first-order methods. (Note that these may be viewed within a framework
of approximate second-order optimisation, as elucidated by the *Bayesian
learning rule* of Khan *et al.*, 2023.)

The popularity of these first-order optimisation algorithms in ML is no doubt due to their simple implementation, and low computational and memory cost. However, they are not without their shortcomings, which a move to second-order methods goes some way towards rectifying. To see this, we will begin by examining Taylor approximations of the objective function.

### Taylor Approximations

One way of distinguishing between first- and second-order optimisation algorithms in machine learning might be to say that they are respectively making first- and second-order Taylor series approximations to our objective function.

Suppose that the current (initial) model parameters are \(\vtheta_{0}\). (On
notation: we have dropped the \(\gD\) argument from \(\gJ\) since the dataset
should be clear from the context, and \(\nabla_{\vtheta}\gJ(\vtheta_{0})\) is
shorthand for \(\frac{\partial \gJ(\vtheta)}{\partial \vtheta}\big\vert_{\vtheta = \vtheta_{0}}\).)
Making a first-order Taylor series expansion around this point gives:

\[\begin{equation} \label{eq:first_order} \gJ^{(1)}(\vtheta) = \gJ(\vtheta_{0}) + \nabla_{\vtheta}\gJ(\vtheta_{0})^{\top}(\vtheta - \vtheta_{0}). \end{equation} \]

By stopping at the first-order term, this approximation assumes that the
objective is linear around the point at which we have expanded, \(\vtheta_{0}\).
Clearly, for most non-trivial functions (which certainly include neural network
objective functions), we don’t need to move very far away from \(\vtheta_{0}\)
for this approximation to become inaccurate.

If, however, we treat this approximation as locally correct within some small radius \(\eta\), we can optimise the first-order approximation in Equation \(\ref{eq:first_order}\) by moving \(\vtheta\) in the direction in which the loss decreases most rapidly. Taking derivatives of Equation \(\ref{eq:first_order}\) wrt. \(\vtheta\), this is simply the direction of the negative gradient, yielding the familiar first-order update used in SGD:

\[\begin{equation} \label{eq:sgd_update} \vtheta_{t+1} = \vtheta_{t} - \eta \nabla_{\vtheta}\gJ(\vtheta_{t}). \end{equation} \]

The biggest problem with this is that from Equation \(\ref{eq:first_order}\) alone, we have no good way of telling how big the step size \(\eta\) should be—this depends on the curvature of \(\gJ(\vtheta)\) around the point at which we have expanded it. If the objective function is ‘flat’ (low curvature) in a given direction, then we ought to take a large step in that direction, while if the objective is ‘sharp’ (high curvature) in a particular direction then we ought to take a small step.

In one dimension, the question of how large a step to take along the gradient direction when updating \(\vtheta_{0}\) to its new value \(\vtheta_{1}\) might be illustrated as follows:

This is particularly problematic when the optimal value of \(\eta\) changes
depending on which direction we move in (i.e. if one dimension is very steep
and another is very shallow).

As an aside, you might also wonder what would happen if the gradient pointed
the ‘wrong way’, that is, if we got trapped in a local minimum. For first-order
methods, *‘momentum’* is a simple modification to vanilla SGD which helps to
mitigate this issue. This involves initialising a *velocity* vector \(v_{0} = 0\)
and some momentum coefficient \(\mu \in [0, 1]\) in addition to the learning
rate \(\eta\). Now, at each iteration, we first update the velocity
\(v_{t+1} = \mu v_{t} + \eta \nabla_{\vtheta}\gJ(\vtheta_{t})\) before updating
the parameters \(\vtheta_{t+1} = \vtheta_{t} - v_{t+1}\).
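A minimal NumPy sketch of both effects, on an illustrative two-dimensional quadratic of our own choosing (curvature 100 along one axis and 1 along the other; not an example from the text). Gradient descent only remains stable along the sharp direction when \(\eta < 2/100\), which makes progress along the flat direction slow. The momentum update described above, with the same \(\eta\) and an assumed \(\mu = 0.9\), makes much faster progress on this ill-conditioned problem.

```python
import numpy as np

# Illustrative quadratic J(θ) = ½ θᵀ H θ with one sharp and one flat direction.
H = np.diag([100.0, 1.0])


def grad(theta):
    """Gradient of J(θ) = ½ θᵀ H θ."""
    return H @ theta


def run_sgd(eta, steps=100, theta0=(1.0, 1.0)):
    """Plain (full-batch) gradient descent: θ ← θ − η ∇J(θ)."""
    theta = np.array(theta0)
    for _ in range(steps):
        theta = theta - eta * grad(theta)
    return theta


def run_momentum(eta, mu, steps=100, theta0=(1.0, 1.0)):
    """Momentum as in the aside above: v ← μ v + η ∇J(θ), then θ ← θ − v."""
    theta = np.array(theta0)
    v = np.zeros_like(theta)
    for _ in range(steps):
        v = mu * v + eta * grad(theta)
        theta = theta - v
    return theta


# Gradient descent is only stable along the sharp direction if η < 2 / 100 ...
stable = run_sgd(eta=0.019)
diverged = run_sgd(eta=0.021)
# ... but such a small η makes progress along the flat direction slow.
with_momentum = run_momentum(eta=0.019, mu=0.9)

print(np.linalg.norm(stable), np.linalg.norm(diverged), np.linalg.norm(with_momentum))
```

Running this, the stable gradient-descent run still sits well away from the minimum at the origin after 100 steps, the slightly larger step size blows up, and the momentum run ends up much closer to the origin.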

To make progress, suppose that we now extend our Taylor approximation to include the second-order term, resulting in a quadratic approximation to the objective function:

\[\begin{equation} \label{eq:second_order} \gJ^{(2)}(\vtheta) = \gJ(\vtheta_{0}) + \nabla_{\vtheta}\gJ(\vtheta_{0})^{\top}(\vtheta - \vtheta_{0}) + \frac{1}{2}(\vtheta - \vtheta_{0})^{\top}\rmH(\vtheta - \vtheta_{0}), \end{equation} \]

where \(\rmH \doteq \nabla_{\vtheta}^{2}\gJ(\vtheta_{0})\) is the
*Hessian* matrix containing the second derivatives of our scalar objective
function. Equation \(\ref{eq:second_order}\) now gives us a quadratic proxy
objective (convex whenever \(\rmH\) is positive semi-definite) in place of our
true objective function. By minimising this proxy objective, we now know how
large a step to take to reach its minimum.

Note that optimising this quadratic proxy objective won’t necessarily get us to the minimum of our true objective, unless the true objective is itself quadratic. It does, however, resolve the question of how big a step size to take: namely, we just go straight to the minimum of our quadratic proxy by solving for it analytically.

Assuming for now that this proxy objective is good enough (and that it is
indeed convex, i.e. \(\rmH \succeq \mathbf{0}\)), we proceed to optimise it by
differentiating it. Letting \(\rmJ_{0} \doteq \nabla_{\vtheta}\gJ(\vtheta_{0})\)
for brevity, we get

\[\begin{align} \nabla_{\vtheta}\gJ^{(2)}(\vtheta) &= \nabla_{\vtheta}\left[\rmJ_{0}^{\top}(\vtheta - \vtheta_{0})\right] + \nabla_{\vtheta}\left[\frac{1}{2}(\vtheta - \vtheta_{0})^{\top}\rmH(\vtheta - \vtheta_{0})\right] \\ &= \nabla_{\vtheta}\rmJ_{0}^{\top}\vtheta + \nabla_{\vtheta}\frac{1}{2}\vtheta^{\top}\rmH\vtheta - \nabla_{\vtheta}\frac{1}{2}\vtheta^{\top}\rmH\vtheta_{0} - \nabla_{\vtheta}\frac{1}{2}\vtheta_{0}^{\top}\rmH\vtheta \\ &= \rmJ_{0} + \rmH\vtheta - \rmH\vtheta_{0} \\ &= \nabla_{\vtheta}\gJ(\vtheta_{0}) + \rmH(\vtheta - \vtheta_{0}), \label{eq:proxy_derivative} \end{align} \]

where the constant term \(\gJ(\vtheta_{0})\) vanishes on differentiation, in the second line we have expanded the brackets and dropped all terms that do not depend on \(\vtheta\), and in the third line we have used the symmetry of \(\rmH\).

Now, setting Equation \(\ref{eq:proxy_derivative}\) above to zero, and solving
for \(\vtheta\) gives \(\vtheta = \vtheta_{0} - \rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta_{0})\). Re-writing the variables to give an
iterative update rule, and including a step size parameter \(\eta\) for
flexibility, we obtain the *Newton-Raphson* update rule.

### Newton's Method

Newton’s iterative method of updating the weights proceeds as:

\[\begin{equation} \label{eq:newton_raphson} \vtheta_{t+1} = \vtheta_{t} - \eta\rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}) \end{equation} \]

The above is similar to the first-order update rule, except that the inverse
Hessian \(\rmH^{-1}\) now acts as a *pre-conditioner*, re-scaling the gradient
differently in each direction and thereby providing an appropriate step size in
each.
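As a quick sanity check of the pre-conditioning view, here is a minimal sketch on an illustrative strictly convex quadratic (our own choice of \(\rmA\) and \(\rvb\)). The second-order Taylor expansion of a quadratic is exact, so a single Newton step with \(\eta = 1\) lands exactly on the minimiser, whereas a gradient step does not. Following common practice, the sketch solves the linear system \(\rmH\Delta = \nabla_{\vtheta}\gJ\) rather than forming \(\rmH^{-1}\) explicitly.

```python
import numpy as np

# A strictly convex quadratic J(θ) = ½ θᵀ A θ − bᵀ θ (illustrative values),
# whose exact minimiser is θ* = A⁻¹ b and whose Hessian is A everywhere.
A = np.array([[100.0, 2.0],
              [2.0, 1.0]])
b = np.array([1.0, -1.0])


def grad(theta):
    return A @ theta - b


theta0 = np.array([3.0, -2.0])
hessian = A  # for a quadratic the Hessian is constant

# Newton step with η = 1: solve H Δ = ∇J rather than forming H⁻¹ explicitly.
theta_newton = theta0 - np.linalg.solve(hessian, grad(theta0))

# Gradient step with a step size small enough to be stable along the sharp direction.
theta_gd = theta0 - 0.01 * grad(theta0)

theta_star = np.linalg.solve(A, b)
print(theta_newton, theta_star)
```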

You might be wondering what happens if the assumption \(\rmH \succeq \mathbf{0}\)
that we made above does not hold: namely, if the Hessian or *curvature* of our
objective is negative in one or more directions at the point around which we
expand it (alternatively, if the Hessian has some negative eigenvalues).
Pictured (again, only in 1D), the quadratic proxy then points in the wrong
direction and becomes rather less useful:

Hence, the Newton update rule, as presented above, is not guaranteed to
decrease the cost function at each iteration, nor is it guaranteed to converge
efficiently, if at all. When the function is locally concave, we might update
the weights in the wrong direction. Further, when we Taylor-expand around a
point with low curvature, the inverse Hessian becomes very large, which may
result in a very large step and cause training instabilities. In the example
below, as a happy coincidence, we end up close to the optimum, although this
isn’t always the case: the Hessian estimate may also be inaccurate, causing us
to take a big step in the wrong direction, and we may completely overshoot the
global optimum.
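Here is a 1D numerical sketch of the negative-curvature failure mode, using an illustrative objective of our own choosing, \(\theta^{4} - 2\theta^{2}\), which has minima at \(\pm 1\) and a local maximum at \(0\). At \(\theta_{0} = 0.3\) the curvature is negative, so the undamped Newton step moves towards the maximum and increases the objective, while a small gradient step (with an assumed step size of 0.1) decreases it.

```python
def f(theta):
    """Illustrative non-convex objective with minima at ±1 and a maximum at 0."""
    return theta**4 - 2.0 * theta**2


def df(theta):
    return 4.0 * theta**3 - 4.0 * theta


def d2f(theta):
    return 12.0 * theta**2 - 4.0


theta0 = 0.3  # here d2f(theta0) < 0: the local quadratic model is concave
newton = theta0 - df(theta0) / d2f(theta0)  # undamped Newton step, η = 1
gradient = theta0 - 0.1 * df(theta0)        # plain gradient step

print(f"f(theta0)={f(theta0):.4f}, after Newton={f(newton):.4f}, after GD={f(gradient):.4f}")
```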

One simple way around these problems is to assume that, while at an
unfortunate iteration the objective might be non-convex or have only very low
positive curvature, if we average the update over a sufficiently large number
of past updates, it will tend to point in the right direction and not be too
big. Hence, by *damping* the updates, we can prevent them from moving too far
from the current position, and so ‘minimise the damage’ from a bad update at
any given iteration.

This can be done by adding to our usual second-order Taylor series
approximation of the objective a quadratic penalty on the distance from the
current parameters (or, alternatively, by placing a Gaussian prior with
precision \(\eta\) centred on the current weights):

\[\begin{align} \vtheta_{t+1} &= \argmin_{\vtheta} \gJ^{(2)}(\vtheta) + \frac{\eta}{2}\Vert \vtheta - \vtheta_{t}\Vert^{2} \\ &=\argmin_{\vtheta}\nabla_{\vtheta}\gJ(\vtheta_{t})^{\top}\vtheta + \frac{1}{2}(\vtheta - \vtheta_{t})^{\top}(\rmH + \eta \rmI)(\vtheta - \vtheta_{t}) \\ &= \vtheta_{t} - (\rmH + \eta\rmI)^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}). \end{align} \]
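Here is a minimal sketch of the damped update on an illustrative indefinite Hessian (our own example values). Without damping, the Newton step is an ascent direction (\(\nabla\gJ^{\top}\Delta\vtheta > 0\)). Once \(\eta\) exceeds the magnitude of the most negative eigenvalue, \(\rmH + \eta\rmI\) is positive definite and the step becomes a descent direction, and larger \(\eta\) shrinks the step further.

```python
import numpy as np

# Illustrative indefinite Hessian (one negative eigenvalue) and gradient.
H = np.diag([-1.0, 4.0])
g = np.array([1.0, 1.0])


def damped_newton_step(H, g, eta):
    """Return Δθ = −(H + ηI)⁻¹ ∇J, the damped update from the equation above."""
    return -np.linalg.solve(H + eta * np.eye(len(g)), g)


undamped = damped_newton_step(H, g, eta=0.0)
damped = damped_newton_step(H, g, eta=2.0)  # η larger than |most negative eigenvalue|

# A step Δ is a descent direction when ∇Jᵀ Δ < 0.
print(g @ undamped, g @ damped)
```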

While the above makes some progress towards fixing our woes, it still isn’t quite good enough.
In areas where the approximate objective is non-convex (i.e. the Hessian has
negative eigenvalues), it still makes little sense to proceed; often our
second-order method can perform *worse* than a simple first-order method with
momentum at getting out of local optima. Further, calculating the Hessian
requires our network function to be twice-differentiable, which is not the case
for many common architectures (e.g. ReLU networks). This motivates an
alternative to the Hessian for calculating the curvature of our objective.

### The Gauss-Newton Hessian

Denoting an individual network output as \(\rvy = f(\rvx; \vtheta)\) and the corresponding loss as \(\ell(\rvy, \rvt)\), we will use, without proof, the fact that the Hessian of this per-example loss decomposes as:

\[\begin{equation} \label{eq:hessian_decomp} \nabla^{2}_{\vtheta}\ell\big(f(\rvx; \vtheta), \rvt\big) = \rmJ^{\top}_{\rvy\vtheta}\rmH_{\rvy}\rmJ_{\rvy\vtheta} + \sum_{i}\frac{\partial \ell}{\partial \rvy_{i}}\nabla^{2}_{\vtheta}[f(\rvx; \vtheta)]_{i}, \end{equation} \]

where \(\rmH_{\rvy} = \nabla^{2}_{\rvy}\ell(\rvy, \rvt)\) is the *output Hessian*
(i.e. the second derivative of just the loss function wrt. the network outputs \(\rvy\))
and \(\rmJ_{\rvy\vtheta}\) is the Jacobian \(\partial \rvy/\partial \vtheta\) or
\(\nabla_{\vtheta}f(\rvx; \vtheta)\) using our previous notation.
The first term in Equation \(\ref{eq:hessian_decomp}\) consists of a quadratic approximation to
\(\ell\) and a linear approximation to \(f\), whereas the second term contains a
linear approximation to \(\ell\), and a quadratic approximation to \(f\).

By simply dropping the second term, we get the *Gauss-Newton Hessian*, which we denote as \(\rmG\):

\[\begin{align} \rmG &= \rmJ_{\rvy\vtheta}^{\top}\rmH_{\rvy}\rmJ_{\rvy\vtheta} \label{eq:gnh} \\ &= \nabla_{\vtheta}f(\rvx; \vtheta)^{\top}\rmH_{\rvy}\nabla_{\vtheta}f(\rvx; \vtheta). \end{align} \]

One way to interpret the above is that we have first linearised the network
around the current weights (i.e. using a first-order Taylor approximation)

\[f_{\text{lin}}(\rvx; \vtheta) = f(\rvx; \vtheta_{0}) + \nabla_{\vtheta}f(\rvx; \vtheta_{0})\big(\vtheta - \vtheta_{0}\big), \]

which gives the following objective function \(\gJ_{\text{lin}}(\vtheta) = \frac{1}{N}\sum_{i=1}^{N}\ell(f_{\text{lin}}(\rvx_{i}; \vtheta), \rvt_{i})\). The Gauss-Newton Hessian can then be recovered as the Hessian of this new objective:

\[\begin{equation} \nabla_{\vtheta}^{2}\gJ_{\mathrm{lin}}(\vtheta) = \rmJ_{\rvy \vtheta}^{\top}\rmH_{\rvy}\rmJ_{\rvy \vtheta} = \rmG. \end{equation} \]

Swapping out \(\rmH\) for \(\rmG\) in our (damped) second-order weight update gives a modified update rule:

\[\begin{equation} \label{eq:damped_ggn_update} \vtheta_{t+1} = \vtheta_{t} - (\rmG + \eta\rmI)^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}). \end{equation} \]

This has now rectified the two problems that motivated the use of the Gauss-Newton Hessian:

- Whether \(\rmG\) is positive semi-definite now depends only on the convexity of our loss function \(\ell\) in the network outputs, and no longer on the network function. Hence, as long as our loss / negative log-likelihood is convex (so that \(\rmH_{\rvy} \succeq \mathbf{0}\)), we have \(\rmG \succeq \mathbf{0}\), our proxy objective is convex, and the damped matrix \(\rmG + \eta\rmI\) is positive definite for any \(\eta > 0\).
- Unlike the full Hessian, \(\rmG\) requires only first-order derivatives of the network function, together with second-order derivatives of the loss function alone (usually the MSE or cross-entropy, or their likelihood equivalents, which are generally twice differentiable). We can therefore apply it to ReLU networks and similar architectures without issue.
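The sketch below illustrates both points on a tiny model of our own choosing, \(f(x; \vtheta) = \theta_{2}\tanh(\theta_{1}x)\), with a squared error loss, so that \(\rmH_{\rvy} = 1\). The model, the data and \(\eta\) are illustrative assumptions, not part of the text. The sketch:

- forms \(\rmG\) from first derivatives of the network only;
- takes one damped Gauss-Newton step;
- checks that \(\rmG\) is positive semi-definite;
- confirms numerically that when the residuals \(\partial\ell/\partial\rvy\) are zero, the second term of the Hessian decomposition vanishes, so the (finite-difference) Hessian coincides with \(\rmG\).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=20)   # illustrative scalar inputs
targets = np.sin(x)       # illustrative regression targets


def model(theta, x):
    """Tiny illustrative network f(x; θ) = θ₂ tanh(θ₁ x) with a scalar output."""
    return theta[1] * np.tanh(theta[0] * x)


def model_jacobian(theta, x):
    """Per-example Jacobian ∂f/∂θ, shape (num_examples, 2)."""
    t = np.tanh(theta[0] * x)
    return np.stack([theta[1] * x * (1.0 - t**2), t], axis=1)


def risk_grad(theta, x, targets):
    """Gradient of the empirical risk with squared error ℓ(y, t) = ½ (y − t)²."""
    residual = model(theta, x) - targets  # ∂ℓ/∂y for each example
    return model_jacobian(theta, x).T @ residual / len(x)


def gauss_newton(theta, x):
    """G = mean_i J_iᵀ H_y J_i, where H_y = 1 for the squared error."""
    J = model_jacobian(theta, x)
    return J.T @ J / len(x)


def finite_difference_hessian(theta, x, targets, h=1e-5):
    """Full Hessian of the empirical risk via central differences of its gradient."""
    cols = []
    for k in range(len(theta)):
        e = np.zeros_like(theta)
        e[k] = h
        cols.append((risk_grad(theta + e, x, targets) - risk_grad(theta - e, x, targets)) / (2 * h))
    return np.stack(cols, axis=1)


theta = np.array([0.7, -1.3])
G = gauss_newton(theta, x)

# Damped Gauss-Newton update θ ← θ − (G + ηI)⁻¹ ∇J(θ), with an illustrative η.
eta = 1e-2
theta_next = theta - np.linalg.solve(G + eta * np.eye(2), risk_grad(theta, x, targets))

# G is positive semi-definite by construction, whatever the targets are.
print(np.linalg.eigvalsh(G))

# With zero residuals (targets equal to the current predictions) the second
# term of the Hessian decomposition vanishes, so the Hessian coincides with G.
H_zero_residual = finite_difference_hessian(theta, x, model(theta, x))
print(np.allclose(H_zero_residual, G, atol=1e-6))
```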

Note that the Gauss-Newton Hessian has its limitations; in particular, it
disregards the *nonlinear modelling error* (NME) matrix, which may have
implications for effective feature learning in neural networks. See the
preliminary work of Dauphin *et al.* (2024) for more on this point.

## Fisher Information and Proximal Optimisation

The introduction of the Gauss-Newton Hessian above via the linearised network may have seemed a little ad hoc, relying heavily on an identity stated without proof. Let’s consider an alternative derivation of a closely related quantity, the Fisher information matrix, which is hopefully less reliant on definitions and more intuitive.

It turns out that for likelihoods belonging to the exponential family of distributions (which includes common likelihoods such as the Gaussian and Categorical, corresponding to MSE and cross-entropy losses, respectively), the Gauss-Newton Hessian is equivalent to the Fisher information matrix, provided that the network outputs the distribution’s natural parameters (e.g. the mean of a fixed-variance Gaussian, or the logits of a Categorical). The Fisher, in turn, arises naturally in proximal optimisation when the KL divergence is used as the regulariser.

**Definition** (Proximal Optimisation)**.** *Proximal optimisation* refers to a
general class of optimisation algorithms which minimise a cost function plus
some *proximity term* \(\rho\), that penalises the distance from the current
iterate. That is, for some cost function \(\gJ(\vtheta)\), the proximal update rule is

\[\begin{equation} \label{eq:general_proximal_update} \vtheta^{(t+1)} = \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \argmin_{\hat{\vtheta}}\left[\gJ(\hat{\vtheta}) + \lambda \rho(\hat{\vtheta}, \vtheta^{(t)})\right]. \end{equation} \]

\(\blacksquare\)

Different choices of dissimilarity function \(\rho\) lead to different update rules. For instance, using the squared Euclidean distance \(\rho(\hat{\vtheta}, \vtheta^{(t)}) = \frac{1}{2}\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{2}\) and setting the derivative of the bracketed objective to zero results in the following (implicit) update rule:

\[\begin{equation} \label{eq:optimal_proximal_update} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \vtheta_{\star} = \vtheta^{(t)} - \lambda^{-1}\nabla_{\vtheta}\gJ(\vtheta_{\star}). \end{equation} \]

This looks a lot like the usual gradient descent update, however the gradient is
computed at the *new iterate*, \(\vtheta_{\star}\). Clearly this update rule
can’t be used directly, since \(\vtheta_{\star}\) appears on both sides.
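For a quadratic cost, however, the implicit equation can be solved in closed form, which makes for a useful check. Here is a minimal sketch assuming an illustrative \(\gJ(\vtheta) = \frac{1}{2}\vtheta^{\top}\rmA\vtheta - \rvb^{\top}\vtheta\), with example values of our own. Setting the gradient of the proximal objective to zero gives the linear system \((\rmA + \lambda\rmI)\vtheta_{\star} = \rvb + \lambda\vtheta^{(t)}\), and its solution satisfies the implicit update above.

```python
import numpy as np

# Illustrative quadratic cost J(θ) = ½ θᵀ A θ − bᵀ θ, for which the proximal
# step with ρ = ½‖θ̂ − θ‖² has a closed form.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
b = np.array([1.0, 0.0])


def grad_J(theta):
    return A @ theta - b


def prox_quadratic(theta_t, lam):
    """Minimise J(θ̂) + (λ/2)‖θ̂ − θ_t‖²: setting the gradient to zero gives
    (A + λI) θ̂ = b + λ θ_t."""
    return np.linalg.solve(A + lam * np.eye(len(theta_t)), b + lam * theta_t)


theta_t = np.array([2.0, -1.0])
lam = 5.0
theta_star = prox_quadratic(theta_t, lam)

# The implicit update θ⋆ = θ_t − λ⁻¹ ∇J(θ⋆) holds: the gradient is taken at the new point.
print(np.allclose(theta_star, theta_t - grad_J(theta_star) / lam))
```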

### Approximating the Proximal Update

To proceed, we will need some approximation. We will consider two below:

#### Linearisation

Suppose that we let \(\lambda \to \infty\), weighting the proximity term very heavily. This means that \(\vtheta_{\star}\) will remain very close to the current \(\vtheta^{(t)}\), so that \(\gJ\) will be well approximated by a first-order Taylor approximation, leading to the following linearised objective:

\[\begin{align} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) &= \argmin_{\hat\vtheta}\left[\gJ(\vtheta^{(t)}) + \nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}(\hat{\vtheta} - \vtheta^{(t)}) + \lambda \rho(\hat{\vtheta}, \vtheta^{(t)})\right] \\ &= \argmin_{\hat{\vtheta}}\left[\nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}\hat{\vtheta} + \lambda\rho(\hat{\vtheta}, \vtheta^{(t)})\right]. \end{align} \]

If we also approximate \(\rho\), this time using a second-order Taylor approximation around \(\hat{\vtheta} = \vtheta^{(t)}\), then both the zeroth- and first-order terms vanish: \(\rho(\vtheta^{(t)}, \vtheta^{(t)}) = 0\), and since \(\rho\) is minimised when its arguments coincide, \(\nabla_{\rvu}\rho(\rvu, \rvv)\vert_{\rvu = \rvv} = \mathbf{0}\). Then

\[\begin{equation} \label{eq:1} \rho(\hat{\vtheta}, \vtheta^{(t)}) = \frac{1}{2}(\hat{\vtheta} - \vtheta^{(t)})^{\top}\rmG(\hat{\vtheta} - \vtheta^{(t)}) + \gO(\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{3}), \end{equation} \]

where
\(\rmG \doteq \nabla^{2}_{\hat{\vtheta}}\rho(\hat{\vtheta}, \vtheta^{(t)})\vert_{\hat{\vtheta} = \vtheta^{(t)}}\),
the Hessian of the distance metric, is referred to as the *metric matrix*.

Interestingly, this looks rather similar to a Mahalanobis distance, which also bears similarity to the quadratic form in a log Gaussian density:

\[\begin{align} \rho(\hat{\vtheta}, \vtheta^{(t)}) &= \frac{1}{2}\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{2}_{\rmG} + \gO(\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{3}), \end{align} \]

where \(\Vert \rvv \Vert_{\rmG} \doteq \sqrt{\rvv^{\top}\rmG\rvv}\).

Putting the linearised cost and second-order approximation of the distance metric into the general proximal update rule in Equation \(\ref{eq:general_proximal_update}\), we get

\[\begin{equation} \label{eq:3} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \argmin_{\hat{\vtheta}}\left[\nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}\hat{\vtheta} + \frac{\lambda}{2}(\hat{\vtheta} - \vtheta^{(t)})^{\top}\rmG(\hat{\vtheta} - \vtheta^{(t)})\right], \end{equation} \]

and optimising suggests the following point-wise update rule:

\[\begin{equation} \label{eq:proximal_approx_1} \vtheta_{\star} = \vtheta^{(t)} - \lambda^{-1}\rmG^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)}). \end{equation} \]

This update rule resembles the second-order Newton update \(\vtheta^{(t)} - \eta\rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)})\), except that instead of using the Hessian of \(\gJ\), we use the Hessian of \(\rho\). When the distance metric is the Euclidean distance (i.e. a Gaussian prior in a log-joint objective), then \(\rmG = \rmI\), so this reduces to the ordinary (first-order) gradient descent update.

#### Second-Order Approximation

Had we chosen to not just linearise the cost but to take a second-order Taylor approximation, then the update rule, derived similarly to the above, would be

\[\begin{equation} \label{eq:proximal_approx_2} \vtheta^{(t+1)} = \vtheta^{(t)} - (\rmH + \lambda \rmG)^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)}). \end{equation} \]

Once again, if \(\rho\) is the standard Euclidean distance, then we recover a damped Newton’s update with \(\rmG = \rmI\).
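Here is a small numerical check of how the two approximations relate, using illustrative matrices of our own (any symmetric positive definite \(\rmH\) and \(\rmG\) would do). For large \(\lambda\), the second-order update approaches the linearised one, since \((\rmH + \lambda\rmG)^{-1} \approx \lambda^{-1}\rmG^{-1}\). With \(\rmG = \rmI\), it coincides with a damped Newton step.

```python
import numpy as np

# Illustrative Hessian of the cost, metric matrix of ρ, and gradient ∇J(θ⁽ᵗ⁾).
H = np.array([[4.0, 1.0],
              [1.0, 3.0]])
G = np.array([[2.0, 0.5],
              [0.5, 1.0]])
g = np.array([1.0, -2.0])


def second_order_prox_step(lam, metric):
    """Δθ = −(H + λG)⁻¹ ∇J: the update from the second-order approximation."""
    return -np.linalg.solve(H + lam * metric, g)


def linearised_prox_step(lam, metric):
    """Δθ = −λ⁻¹ G⁻¹ ∇J: the update from linearising the cost."""
    return -np.linalg.solve(metric, g) / lam


# For large λ the two approximations agree, since (H + λG)⁻¹ ≈ λ⁻¹ G⁻¹.
lam = 1e6
second = second_order_prox_step(lam, G)
linearised = linearised_prox_step(lam, G)
rel_err = np.linalg.norm(second - linearised) / np.linalg.norm(linearised)

# With the squared Euclidean distance (G = I), the second-order step is a damped Newton step.
eta = 0.1
damped_newton = -np.linalg.solve(H + eta * np.eye(2), g)
euclidean_step = second_order_prox_step(eta, np.eye(2))
print(rel_err, np.allclose(euclidean_step, damped_newton))
```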

### KL Dissimilarity Metric

Looking further afield, we can get more interesting results by replacing \(\rho\) with another dissimilarity function. A common choice is the KL divergence (which, strictly speaking, is not a metric, since it is not symmetric):

\[\begin{equation} \label{eq:6} \KL(\Q \Vert \P) = \E_{\rvx \sim \Q}\left[\log \Q(\rvx) - \log \P(\rvx)\right]. \end{equation} \]

In our two approximations (Equations \(\ref{eq:proximal_approx_1}\) and \(\ref{eq:proximal_approx_2}\)) of the idealised update in Equation \(\ref{eq:general_proximal_update}\), we made use of the Hessian of the distance metric. The Hessian of the KL divergence is the Fisher information matrix:

\[\begin{align} \nabla_{\rvu}^{2}\KL[\P_{\rvu} \Vert \P_{\vtheta}]\vert_{\rvu = \vtheta} &= \rmF_{\vtheta} \\ &= \Cov_{\rvx \sim \P_{\vtheta}}\big(\nabla_{\vtheta}\log \P_{\vtheta}(\rvx)\big) \\ &= \E_{\rvx \sim \P_{\vtheta}}\left[\big(\nabla_{\vtheta}\log\P_{\vtheta}(\rvx)\big)\big(\nabla_{\vtheta}\log\P_{\vtheta}(\rvx)\big)^{\top}\right]. \end{align} \]

Note that the last line holds since \(\Cov(\rvv) = \E[\rvv\rvv^{\top}] - \E[\rvv]\E[\rvv]^{\top}\) and \(\E_{\rvx \sim \P_{\vtheta}}\left[\nabla_{\vtheta}\log \P_{\vtheta}(\rvx)\right] = \mathbf{0}\). Intuitively, for data drawn from \(\P_{\vtheta}\), the *expected* log-likelihood of this data under \(\P_{\vtheta'}\) is maximised when \(\vtheta' = \vtheta\), so its gradient (the expected score) should be \(\mathbf{0}\) there.
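As a minimal sketch, take an illustrative one-parameter family: a Gaussian \(\P_{\mu} = \mathcal{N}(\mu, \sigma^{2})\) with known \(\sigma\) (our own example). Its score is \((x - \mu)/\sigma^{2}\). A Monte Carlo estimate confirms that the score has zero mean. It also confirms that the score’s uncentred covariance matches the second derivative of the KL divergence, both equal to \(1/\sigma^{2}\).

```python
import numpy as np

# Illustrative model P_μ = N(μ, σ²) with known σ; the parameter is the mean μ.
mu, sigma = 1.5, 2.0
rng = np.random.default_rng(0)
x = rng.normal(mu, sigma, size=200_000)  # samples from P_μ itself

score = (x - mu) / sigma**2      # ∇_μ log P_μ(x)
fisher_mc = np.mean(score**2)    # uncentred covariance of the score


def kl(u, m):
    """KL[N(u, σ²) ‖ N(m, σ²)] = (u − m)² / (2σ²)."""
    return (u - m) ** 2 / (2 * sigma**2)


# Second derivative of the KL in its first argument at u = μ (central differences).
h = 1e-3
kl_hessian = (kl(mu + h, mu) - 2 * kl(mu, mu) + kl(mu - h, mu)) / h**2

print(np.mean(score), fisher_mc, kl_hessian, 1 / sigma**2)
```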

Using the KL in the proximal update gives:

\[\begin{equation} \label{eq:2} \mathrm{prox}_{\gJ, \lambda}(\vtheta) = \argmin_{\hat{\vtheta}}\left[\gJ(\hat{\vtheta}) + \lambda \KL[\P_{\hat{\vtheta}} \Vert \P_{\vtheta}]\right]. \end{equation} \]

In the limit of \(\lambda \to \infty\), this gives the following point-wise parameter update:

\[\begin{equation} \label{eq:kl_proximal_update} \vtheta^{(t+1)} = \vtheta^{(t)} - \eta \rmF^{-1}_{\vtheta^{(t)}}\nabla_{\vtheta}\gJ(\vtheta^{(t)}), \end{equation} \]

where \(\eta = \lambda^{-1}\), and pre-conditioning the gradient by the inverse Fisher gives a natural gradient update.

The key beneficial property of using the KL as our dissimilarity metric, rather than the squared Euclidean distance, is that the KL, which operates over distributions, not weights, doesn’t care about how the distributions are parametrised.
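To see this invariance concretely, here is a minimal sketch using an illustrative Gaussian model \(\mathcal{N}(\mu, \sigma^{2})\), objective \(\gJ(\mu) = \frac{1}{2}(\mu - 3)^{2}\), and a reparametrisation \(\mu = c\phi\) (all our own choices). Mapped back to a change in \(\mu\), the natural gradient step is the same in either parametrisation. The plain gradient step, by contrast, changes by a factor of \(c^{2}\).

```python
# Illustrative setup: model N(μ, σ²) with objective J(μ) = ½ (μ − 3)², and a
# reparametrisation μ = c φ of the same family of distributions.
sigma, c, eta = 0.5, 10.0, 0.1
mu0 = 1.0
phi0 = mu0 / c


def dJ_dmu(mu):
    return mu - 3.0


fisher_mu = 1.0 / sigma**2        # Fisher for the mean of a Gaussian
fisher_phi = c**2 / sigma**2      # chain rule: F_φ = (∂μ/∂φ)² F_μ
grad_phi = c * dJ_dmu(c * phi0)   # chain rule: ∂J/∂φ = c ∂J/∂μ

# Natural-gradient steps, each mapped back to a change in μ.
natural_mu = -eta * dJ_dmu(mu0) / fisher_mu
natural_phi = c * (-eta * grad_phi / fisher_phi)

# Plain gradient steps, each mapped back to a change in μ.
plain_mu = -eta * dJ_dmu(mu0)
plain_phi = c * (-eta * grad_phi)

print(natural_mu, natural_phi)  # equal, up to floating-point rounding
print(plain_mu, plain_phi)      # differ by a factor of c²
```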

## Kronecker-Factored Approximate Curvature

Having equations for the Gauss-Newton or Fisher information matrix is all well and good, but we still have to contend with the computational and memory costs that beset second-order methods, as mentioned in the introduction. The full Hessian is an \(n \times n\) matrix where, recall, \(n\) is the total number of parameters \(\vtheta\) in our neural network. For even modestly sized networks, this quickly becomes prohibitive: a network with \(10^{6}\) parameters already has a Hessian with \(10^{12}\) entries. We now turn to K-FAC to resolve this issue.

Before proceeding, let’s give some more detail about our hypothetical function approximator \(f_{\vtheta}\) and define some notation.

Suppose that \(f\) is a neural network with \(L\) layers, \(\ell = 1, \ldots, L\). The \(\ell\)th layer has input activations \(\rva_{\ell-1} \in \R^{N}\), weights \(\rmW_{\ell} \in \R^{M \times N}\), bias \(\rvb_{\ell} \in \R^{M}\), (pre-activation) outputs \(\rvs_{\ell} \in \R^{M}\) and activation function \(\phi_{\ell}\). That is, the computation of layer \(\ell\) proceeds as follows:

\[\begin{align} \rvs_{\ell} &= \overline{\rmW}_{\ell}\bar{\rva}_{\ell - 1} \\ \rva_{\ell} &= \phi_{\ell}(\rvs_{\ell}). \end{align} \]

The notation used above is *homogeneous vector notation*, where we concatenate
the weights and biases into the same matrix by padding the activations with a
constant unit:

\[\begin{align} \overline{\rmW}_{\ell} &= \begin{bmatrix}\rmW_{\ell} & \rvb_{\ell}\end{bmatrix} \\ \bar{\rva}_{\ell - 1} &= \begin{bmatrix}\rva_{\ell - 1}^{\top} & 1\end{bmatrix}^{\top}. \end{align} \]
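Here is a minimal NumPy sketch of the homogeneous notation for a single layer, with illustrative sizes and a \(\tanh\) activation chosen purely for the example. Appending a constant 1 to the input activations lets the bias be absorbed into \(\overline{\rmW}_{\ell}\), so the same pre-activations come out of a single matrix-vector product.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 3, 4                    # output and input sizes of an illustrative layer
W = rng.normal(size=(M, N))
b = rng.normal(size=M)
a_prev = rng.normal(size=N)    # input activations a_{ℓ-1}

W_bar = np.concatenate([W, b[:, None]], axis=1)  # [W  b], shape (M, N + 1)
a_bar = np.append(a_prev, 1.0)                   # [a; 1], shape (N + 1,)

s = W_bar @ a_bar   # pre-activations s_ℓ, identical to W a + b
a = np.tanh(s)      # an illustrative choice of activation φ_ℓ
```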

We will also use the following *pseudo-gradient* notation,

\[\begin{equation} \label{eq:pseudo_grad_notation} \gD \rvv = \nabla_{\rvv}\log \P(\rvy \vert \rvx; \vtheta). \end{equation} \]

In other words, \(\gD \rvv\) is a shorthand for the derivative of the log likelihood with respect to some vector \(\rvv\). For example, the pseudo-gradient for the homogeneous weight matrix \(\overline{\rmW}_{\ell}\) is

\[\gD \overline{\rmW}_{\ell} = \gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}. \]
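We can check this outer-product form numerically. For illustration only, the sketch below assumes that layer \(\ell\) is a softmax output layer, so that \(\log \P(\rvy \vert \rvx; \vtheta) = s_{y} - \log\sum_{j}\exp(s_{j})\) and \(\gD\rvs_{\ell} = \mathbf{e}_{y} - \mathrm{softmax}(\rvs_{\ell})\). It then compares \(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\) against central finite differences of the log-likelihood.

```python
import numpy as np

rng = np.random.default_rng(1)
M, N = 3, 4
W_bar = rng.normal(size=(M, N + 1))
a_bar = np.append(rng.normal(size=N), 1.0)
y = 2  # an illustrative class label


def log_lik(W_bar):
    """log P(y | x) for a single softmax output layer with logits s = W̄ ā."""
    s = W_bar @ a_bar
    return s[y] - np.log(np.sum(np.exp(s)))


s = W_bar @ a_bar
p = np.exp(s - s.max())
p /= p.sum()
Ds = np.eye(M)[y] - p          # pseudo-gradient with respect to the pre-activations
DW_bar = np.outer(Ds, a_bar)   # D W̄ = D s āᵀ

# Central finite-difference check, entry by entry.
h = 1e-6
fd = np.zeros_like(W_bar)
for i in range(M):
    for j in range(N + 1):
        E = np.zeros_like(W_bar)
        E[i, j] = h
        fd[i, j] = (log_lik(W_bar + E) - log_lik(W_bar - E)) / (2 * h)
```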

Finally, while we introduced the Gauss-Newton Hessian from the point of view of
optimisation above, it will be simpler to work with the Fisher information
matrix below. Fortunately, for losses such as MSE and cross entropy, which
correspond to likelihoods from the exponential family (Gaussian and
Categorical, respectively) with the network outputting their natural
parameters, these two are equivalent. (This does not extend to every loss: MAE,
the *mean absolute error*, which returns an estimate of the median, corresponds
to a Laplace likelihood, and that is not an exponential family in its location
parameter. See this section of an earlier article for more on MAE.) Starting
from our definition of \(\rmG\) from the previous section:

\[\begin{align} \rmG &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\rmH_{\rvy}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\rmF_{\rvy}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\E_{\rvy \sim f_{\vtheta}(\rvx)}\left[\gD\rvy\gD\rvy^{\top}\right]\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\substack{\rvx \sim \gD\\ \rvy \sim f_{\vtheta}(\rvx)}}\left[\rmJ^{\top}_{\rvy\vtheta}\gD\rvy\gD\rvy^{\top}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\substack{\rvx \sim \gD\\ \rvy \sim f_{\vtheta}(\rvx)}}\left[\gD\vtheta\gD\vtheta^{\top}\right] \label{eq:fisher} \\ &\doteq \rmF. \end{align} \]

The second line holds because the expected output Hessian of the negative log-likelihood under the model’s output distribution is the *output Fisher*, \(\E[\rmH_{\rvy}] = \rmF_{\rvy}\) (the uncentred covariance of the log-likelihood gradients), and, for these exponential-family likelihoods, \(\rmH_{\rvy}\) does not depend on the target, so that \(\rmH_{\rvy} = \rmF_{\rvy}\).
With slight abuse of notation, we use \(\rvy \sim f_{\vtheta}(\rvx)\) to denote
sampling from the model’s output distribution (e.g. for classification tasks,
sampling from the categorical distribution parametrised by the logits
\(f_{\vtheta}(\rvx)\) output by the network at \(\rvx\)). Recall that \(\gD\vtheta = \nabla_{\vtheta}\log \P(\rvy \vert \rvx; \vtheta)\).
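For the categorical case, this can be checked directly. Here is a minimal sketch with illustrative logits. The Hessian of the cross-entropy with respect to the logits is \(\mathrm{diag}(\rvp) - \rvp\rvp^{\top}\) for every target. The output Fisher gives the same matrix, whether computed exactly by summing over classes or estimated by sampling labels from the model as described above.

```python
import numpy as np

rng = np.random.default_rng(2)
logits = rng.normal(size=4)  # illustrative network output f_θ(x)
p = np.exp(logits - logits.max())
p /= p.sum()

# Hessian of the cross-entropy −log softmax(z)_t with respect to the logits z;
# it is the same for every target t.
H_y = np.diag(p) - np.outer(p, p)

# Output Fisher E_{y ~ Cat(p)}[D y D yᵀ] with D y = e_y − p, computed exactly
# by summing over all classes ...
scores = np.eye(len(p)) - p  # row k is the score when the sampled label is k
F_y = (scores.T * p) @ scores

# ... and by Monte Carlo, sampling labels from the model's output distribution.
samples = rng.choice(len(p), size=100_000, p=p)
F_mc = scores[samples].T @ scores[samples] / len(samples)
```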

### K-FAC Assumptions

At a high level, K-FAC efficiently approximates \(\rmG\) by making two key approximations:

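- We assume independence between the weights in different layers of the network, giving the approximated \(\rmG\) a block-diagonal structure, \(\rmG \approx \mathrm{diag}(\rmG_{11}, \ldots, \rmG_{LL})\), with one block \(\rmG_{\ell\ell}\) per layer.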

- We further approximate each block \(\rmG_{\ell \ell}\) for \(\ell \in \{1, \ldots, L\}\) as a Kronecker product between two factors, \(\rmG_{\ell \ell} \approx \rmA_{\ell-1} \otimes \rmS_{\ell}\).

This factorisation is where we introduce the second approximation, which is an independence assumption between the activations \(\bar{\rva}_{\ell-1}\) and output gradients \(\gD\rvs_{\ell}\) for each \(\ell\).

The first, block-diagonal approximation is easy enough to understand. Before looking at the second, however, and at how we find the factors \(\rmA_{\ell-1}\) and \(\rmS_{\ell}\), let’s look at some properties of the Kronecker product that will be useful.

### Kronecker Product Identities

A Kronecker product between two matrices \(\rmA \in \R^{m \times n}\) and \(\rmB \in \R^{o \times p}\) is the following \((mo \times np)\) sized matrix

\[\rmA \otimes \rmB = \begin{bmatrix}a_{11}\rmB & \cdots & a_{1n}\rmB \\ \vdots & \ddots & \vdots \\ a_{m1}\rmB & \cdots & a_{mn}\rmB\end{bmatrix}, \]

where \(a_{ij}\) is the \((i, j)\)th entry of \(\rmA\).

The first of two identities we will need relates the Kronecker product to the vec operator, which stacks the columns of a matrix into a single vector. Stated generally, this is

\[\begin{equation} \label{eq:kron_vec_ident} \text{vec}\big(\rmA\rmX\rmB\big) = \big(\rmB^{\top}\otimes \rmA\big)\text{vec}(\rmX). \end{equation} \]

The special case we will be using arises when \(\rmA = \rvu\) and \(\rmB = \rvv^{\top}\) are vectors and \(\rmX = 1\) is a scalar (the \(1 \times 1\) identity), in which case it becomes

\[\begin{equation} \label{eq:kron_vec_ident_specialcase} \text{vec}(\rvu \rvv^{\top}) = \rvv \otimes \rvu. \end{equation} \]

The second, *mixed product* identity shows how a product of Kronecker products
factorises:

\[\begin{equation} \label{eq:kron_prod_fact} \rmA\rmC \otimes \rmB\rmD = (\rmA \otimes \rmB)(\rmC \otimes \rmD). \end{equation} \]

Also note that \((\rmA \otimes \rmB)^{\top} = \rmA^{\top}\otimes \rmB^{\top}\), which is easy to check.
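These identities are easy to check numerically. Here is a minimal NumPy sketch with random matrices of compatible shapes. Note that the identities assume the column-stacking (column-major) convention for vec, hence `order="F"` in the reshape.

```python
import numpy as np

rng = np.random.default_rng(3)


def vec(X):
    """Stack the columns of X (column-major order), as the identities assume."""
    return X.reshape(-1, order="F")


A = rng.normal(size=(2, 3))
X = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))
C = rng.normal(size=(3, 2))  # compatible with A for the product AC
D = rng.normal(size=(5, 3))  # compatible with B for the product BD
u, v = rng.normal(size=3), rng.normal(size=4)

vec_ok = np.allclose(vec(A @ X @ B), np.kron(B.T, A) @ vec(X))      # vec(AXB) = (Bᵀ ⊗ A) vec(X)
outer_ok = np.allclose(vec(np.outer(u, v)), np.kron(v, u))          # vec(u vᵀ) = v ⊗ u
mixed_ok = np.allclose(np.kron(A @ C, B @ D), np.kron(A, B) @ np.kron(C, D))
transpose_ok = np.allclose(np.kron(A, B).T, np.kron(A.T, B.T))
```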

Now, returning to our block-diagonal structure, we will focus on a single block
of \(\rmG\) corresponding to layer \(\ell\), \(\rmG_{\ell\ell}\). By Equation
\(\ref{eq:fisher}\) (the definition of the GGN for exponential-family
likelihoods, i.e. the Fisher Information Matrix), we start from the uncentred
covariance of the vectorised log-likelihood gradient:

\[\begin{align} \rmG_{\ell\ell} &= \E\left[\text{vec}\big(\gD\overline{\rmW}_{\ell}\big)\text{vec}\big(\gD\overline{\rmW}_{\ell}\big)^{\top}\right] \\ &= \E\left[\text{vec}\big(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\big)\text{vec}\big(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\big)^{\top}\right] \\ &= \E\left[\big(\bar{\rva}_{\ell-1}\otimes \gD\rvs_{\ell}\big)\big(\bar{\rva}_{\ell-1}\otimes \gD\rvs_{\ell}\big)^{\top}\right] \\ &= \E\left[\bar{\rva}_{\ell - 1}\bar{\rva}_{\ell-1}^{\top} \otimes \gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right], \end{align} \]

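where we have used our shorthand \(\gD\overline{\rmW}_{\ell} = \nabla_{\overline{\rmW}_{\ell}}\log \P(\rvy\vert\rvx; \vtheta)\), and \(\rvs_{\ell}\)
are the pre-activations (we refer to \(\gD\rvs_{\ell}\) as the *output
gradients*). The second line uses the fact that, since \(\rvs_{\ell} = \overline{\rmW}_{\ell}\bar{\rva}_{\ell-1}\), the weight gradient is the outer product \(\gD\overline{\rmW}_{\ell} = \gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\). On the penultimate and final lines, we have applied the two
Kronecker identities in Equation \(\ref{eq:kron_vec_ident_specialcase}\) and
Equation \(\ref{eq:kron_prod_fact}\) (together with the transpose rule), respectively.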
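To make the derivation concrete, the sketch below (assuming NumPy) replaces the expectation with an average over samples and checks that the uncentred covariance of \(\text{vec}(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1})\) equals the averaged Kronecker product. The random arrays are hypothetical stand-ins for real activations and output gradients. The two sides agree sample by sample, so no independence assumption is needed yet.

```python
import numpy as np

def vec(M):
    """Column-major vec, as in the identities above."""
    return M.reshape(-1, order="F")

rng = np.random.default_rng(2)
n, N, M = 500, 3, 2  # samples, input width, output width (illustrative)

# Hypothetical stand-ins; real values would come from a forward pass and
# a backward pass with y sampled from the model.
a = rng.standard_normal((n, N))
a_bar = np.hstack([a, np.ones((n, 1))])  # append the constant 1 for the bias
ds = rng.standard_normal((n, M))

# Uncentred covariance of vec(Ds a_bar^T), averaged over samples.
grads = np.stack([vec(np.outer(ds[i], a_bar[i])) for i in range(n)])
G_block = grads.T @ grads / n

# Average of (a_bar a_bar^T) kron (Ds Ds^T).
G_kron = np.mean(
    [np.kron(np.outer(a_bar[i], a_bar[i]), np.outer(ds[i], ds[i])) for i in range(n)],
    axis=0,
)

print(G_block.shape, np.allclose(G_block, G_kron))  # (8, 8) True
```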

Up to this point the expression for the block is exact. To bring the Kronecker product outside the expectation, we now apply our second approximation: we assume that the activations are independent of the output gradients. We denote the resulting approximate GGN by \(\hat{\rmG}\):

\[\begin{align} \hat{\rmG}_{\ell \ell} &= \E\left[\bar{\rva}_{\ell-1}\bar{\rva}^{\top}_{\ell-1}\right] \otimes \E\left[\gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right] \\ &\doteq \rmA_{\ell-1}\otimes \rmS_{\ell}, \end{align} \]

where the two *Kronecker factors* are the uncentred covariance matrices of the
activations and pre-activation pseudo-gradients:

\[\begin{align} \rmA_{\ell} &= \E\left[\bar{\rva}_{\ell}\bar{\rva}_{\ell}^{\top}\right] \\ &= \begin{bmatrix}\E[\rva_{\ell}\rva_{\ell}^{\top}] & \E[\rva_{\ell}] \\ \E[\rva_{\ell}]^{\top} & 1\end{bmatrix}, \\ \rmS_{\ell} &= \E\left[\gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right], \end{align} \]

with the expectation taken over empirical input samples from our dataset \(\rvx \sim \gD\) and outputs sampled from the model’s output distribution \(\rvy \sim f_{\vtheta}(\rvx)\).
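In practice, both factors are estimated by averaging over a batch. The sketch below assumes NumPy; `kfac_factors` is a hypothetical helper, and the random `a` and `ds` stand in for activations recorded on the forward pass and output gradients from a backward pass with \(\rvy\) sampled from the model. It also checks the block structure of \(\rmA\), with the constant 1 appended as the last entry of \(\bar{\rva}\).

```python
import numpy as np

def kfac_factors(a, ds):
    """Hypothetical helper: Monte Carlo estimates of the K-FAC factors.

    a:  (n, N) input activations for n samples
    ds: (n, M) output gradients, with y sampled from the model
    Returns A of shape (N+1, N+1) and S of shape (M, M).
    """
    n = a.shape[0]
    a_bar = np.hstack([a, np.ones((n, 1))])  # homogeneous coordinate for the bias
    A = a_bar.T @ a_bar / n                  # E[a_bar a_bar^T]
    S = ds.T @ ds / n                        # E[Ds Ds^T]
    return A, S

rng = np.random.default_rng(3)
n, N, M = 1000, 3, 2
a = rng.standard_normal((n, N)) + 0.5
ds = rng.standard_normal((n, M))

A, S = kfac_factors(a, ds)
G_hat = np.kron(A, S)  # K-FAC approximation of the layer's GGN block

# Block structure of A: top-left E[a a^T], last row/column E[a], corner 1.
print(np.allclose(A[:N, :N], a.T @ a / n))    # True
print(np.allclose(A[N, :N], a.mean(axis=0)))  # True
print(G_hat.shape)                            # (8, 8)
```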

If layer \(\ell\) has \(N\)-dimensional input activations and returns \(M\)-dimensional outputs, then \(\hat{\rmG}_{\ell\ell}\) is of size \((N+1)M \times (N+1)M\), with \(\rmA_{\ell-1} \in \R^{(N+1) \times (N+1)}\) and \(\rmS_{\ell} \in \R^{M\times M}\).
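To see why these sizes matter, we can count stored entries: the dense block needs \((N+1)^{2}M^{2}\) numbers, while the two factors need only \((N+1)^{2} + M^{2}\). A small sketch of that arithmetic, for a hypothetical layer with \(N = M = 1024\):

```python
def kfac_storage(N, M):
    """Stored entries for one layer with N inputs and M outputs."""
    P = (N + 1) * M                  # parameters, including biases
    full_block = P * P               # dense G_ll
    factors = (N + 1) ** 2 + M ** 2  # A_{l-1} and S_l
    return full_block, factors

full, kfac = kfac_storage(1024, 1024)
print(full)  # 1101660160000
print(kfac)  # 2099201
```

That is roughly \(1.1 \times 10^{12}\) entries (about 4.4 TB in 32-bit floats) for the dense block, against about \(2.1 \times 10^{6}\) for the factors.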

### Eigenvalue-Corrected K-FAC

George *et al.* (2018) introduce a modification to
K-FAC that is provably at least as good an approximation of the GGN (as measured by the Frobenius norm) while
remaining cheap to compute.

The modification exploits the fact that eigendecompositions
distribute over Kronecker products. Consider two factors \(\rmA\)
and \(\rmS\) (we drop the \(\ell\) subscripts to reduce clutter), which have
eigendecompositions \(\rmQ_{\rmA}\mLambda_{\rmA}\rmQ_{\rmA}^{\top}\) and
\(\rmQ_{\rmS}\mLambda_{\rmS}\rmQ_{\rmS}^{\top}\). Recall that our factors are
uncentred covariance matrices, and hence real, symmetric matrices, which is what
gives rise to this form of the eigendecomposition with orthogonal \(\rmQ_{\rmA}\) and \(\rmQ_{\rmS}\). Applying the mixed product
identity twice (along with the transpose rule), the eigendecomposition of their Kronecker product distributes
as follows:

\[\begin{align} \label{eq:kron_eigen} \rmA \otimes \rmS &= \rmQ_{\rmA}\mLambda_{\rmA}\rmQ_{\rmA}^{\top} \otimes \rmQ_{\rmS}\mLambda_{\rmS}\rmQ_{\rmS}^{\top} \\ &= \big(\rmQ_{\rmA}\otimes \rmQ_{\rmS}\big)\big(\mLambda_{\rmA}\otimes\mLambda_{\rmS}\big)\big(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\big)^{\top}. \end{align} \]

Each of the three factors in Equation \(\ref{eq:kron_eigen}\) is an \((N+1)M \times (N+1)M\) matrix;
however, the middle one is diagonal, and hence has at most \((N+1)M\) nonzero
elements. We can thus afford to store this matrix directly, without a Kronecker
approximation. This is the observation behind the *Eigenvalue-Corrected
K-FAC* of George *et al.* (2018), who approximate the
GGN for a layer \(\ell\) by replacing the diagonal matrix \(\mLambda_{\rmA}\otimes\mLambda_{\rmS}\) with a general diagonal matrix
\(\mLambda\),

\[\begin{equation} \label{eq:ekfac} \rmG \approx \hat{\rmG} \doteq \big(\rmQ_{\rmA}\otimes \rmQ_{\rmS}\big)\mLambda\big(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\big)^{\top}, \end{equation} \]

where \(\mLambda\) is chosen to minimise the approximation error between \(\rmG\) and our eigenvalue-corrected approximation \(\hat{\rmG}\). Its entries are defined as:

\[\begin{equation} \label{eq:ekfac_diag} \mLambda_{ii} \doteq \E\left[\big((\rmQ_{\rmA}\otimes\rmQ_{\rmS})^{\top}\gD\vtheta\big)^{2}_{i}\right] \end{equation} \]

This captures the (uncentred) second moments of the pseudo-gradient \(\gD\vtheta\) of layer \(\ell\)'s parameters projected onto each eigenvector of the K-FAC approximation, which improves the quality of the approximation.
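The projection in Equation \(\ref{eq:ekfac_diag}\) never requires forming \(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\) explicitly: for a single sample, the mixed product identity gives \((\rmQ_{\rmA}\otimes\rmQ_{\rmS})^{\top}(\bar{\rva}\otimes\gD\rvs) = (\rmQ_{\rmA}^{\top}\bar{\rva})\otimes(\rmQ_{\rmS}^{\top}\gD\rvs)\). The sketch below (assuming NumPy, with deliberately correlated random stand-ins for \(\bar{\rva}\) and \(\gD\rvs\) so that the independence assumption fails) computes \(\mLambda\) this way, checks the Kronecker eigendecomposition numerically, and compares the Frobenius error of both approximations against the sample-average GGN block.

```python
import numpy as np

rng = np.random.default_rng(4)
n, N, M = 2000, 3, 2

# Hypothetical stand-ins: output gradients are scaled by |a_1|, so they are
# NOT independent of the activations and K-FAC is not exact here.
a = rng.standard_normal((n, N))
a_bar = np.hstack([a, np.ones((n, 1))])
ds = rng.standard_normal((n, M)) * (1.0 + np.abs(a[:, :1]))

# K-FAC factors and their eigendecompositions (eigh: symmetric input).
A = a_bar.T @ a_bar / n
S = ds.T @ ds / n
lam_A, Q_A = np.linalg.eigh(A)
lam_S, Q_S = np.linalg.eigh(S)

# The eigendecomposition distributes over the Kronecker product.
Q = np.kron(Q_A, Q_S)
assert np.allclose(np.kron(A, S), Q @ np.diag(np.kron(lam_A, lam_S)) @ Q.T)

# Eigenvalue-corrected diagonal, computed without forming Q:
# (Q_A kron Q_S)^T (a_bar kron ds) = (Q_A^T a_bar) kron (Q_S^T ds).
pa = a_bar @ Q_A  # row i holds (Q_A^T a_bar_i)^T
ps = ds @ Q_S     # row i holds (Q_S^T ds_i)^T
lam_ekfac = np.einsum("ni,nj->ij", pa ** 2, ps ** 2).reshape(-1) / n

# Sample-average GGN block: rows of `grads` are a_bar_i kron ds_i.
grads = np.einsum("ni,nj->nij", a_bar, ds).reshape(n, -1)
G = grads.T @ grads / n

err_kfac = np.linalg.norm(G - np.kron(A, S))  # Frobenius norm
err_ekfac = np.linalg.norm(G - Q @ np.diag(lam_ekfac) @ Q.T)
print(err_ekfac <= err_kfac)  # True
```

Since the diagonal entries of \(\mLambda_{\rmA}\otimes\mLambda_{\rmS}\) are \(\E[(\rmQ_{\rmA}^{\top}\bar{\rva})_{i}^{2}]\,\E[(\rmQ_{\rmS}^{\top}\gD\rvs)_{j}^{2}]\), the only change is that the correction takes the expectation of the product rather than the product of expectations; the two coincide under the independence assumption.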

#### Derivation

The diagonal entries in Equation \(\ref{eq:ekfac_diag}\) above are derived by minimising the Frobenius norm of the difference between the GGN \(\rmG\) and our approximation, with the eigenbasis held fixed. Let the eigendecomposition of our approximation be \(\hat{\rmG} = \rmQ\mLambda\rmQ^{\top}\), with \(\rmQ\) orthogonal. Then the error is \(e = \Vert \rmG - \rmQ\mLambda\rmQ^{\top}\Vert_{F}\).

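Working from the squared error, and recalling that \(\rmQ^{\top}\rmQ = \rmQ\rmQ^{\top} = \rmI\), so that the Frobenius norm is unchanged by multiplying on either side by an orthogonal matrix: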

\[\begin{align} e^{2} &= \Vert \rmG - \rmQ\mLambda\rmQ^{\top}\Vert^{2}_{F} \\ &= \Vert \rmQ^{\top}\big(\rmG - \rmQ\mLambda\rmQ^{\top}\big)\rmQ\Vert_{F}^{2} \\ &= \Vert \rmQ^{\top}\rmG\rmQ - \mLambda\Vert_{F}^{2} \\ &= \underbrace{\sum_{i}\big(\rmQ^{\top}\rmG\rmQ - \mLambda\big)^{2}_{ii}}_{\text{diagonal}} + \underbrace{\sum_{i}\sum_{j\ne i}\big(\rmQ^{\top}\rmG\rmQ\big)^{2}_{ij}}_{\text{off-diagonal}}, \end{align} \]

where in the last line we omit \(\mLambda\) from the off-diagonal term since, being diagonal, it contributes nothing to that sum.

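The off-diagonal term does not depend on \(\mLambda\), and the diagonal term is a sum of squares that vanishes exactly when each \(\mLambda_{ii}\) matches the corresponding diagonal entry of \(\rmQ^{\top}\rmG\rmQ\). The error is therefore minimised by setting \(\mLambda_{ii} = (\rmQ^{\top}\rmG\rmQ)_{ii}\). Expanding definitions: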

\[\begin{align} \mLambda_{ii} &= (\rmQ^{\top}\rmG\rmQ)_{ii} \\ &= \big(\rmQ^{\top}\E\left[\gD\vtheta\gD\vtheta^{\top}\right]\rmQ\big)_{ii} \\ &= \big(\E\left[\rmQ^{\top}\gD\vtheta(\rmQ^{\top}\gD\vtheta)^{\top}\right]\big)_{ii} \\ &= \E\left[\big(\rmQ^{\top}\gD\vtheta\big)^{2}_{i}\right] \end{align} \]

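In the Kronecker-factored setting, we simply substitute \((\rmQ_{\rmA} \otimes \rmQ_{\rmS})\) for \(\rmQ\), which recovers Equation \(\ref{eq:ekfac_diag}\). This is valid because the Kronecker product of two orthogonal matrices is itself orthogonal: by the mixed product and transpose identities, \((\rmQ_{\rmA}\otimes\rmQ_{\rmS})^{\top}(\rmQ_{\rmA}\otimes\rmQ_{\rmS}) = \rmQ_{\rmA}^{\top}\rmQ_{\rmA}\otimes\rmQ_{\rmS}^{\top}\rmQ_{\rmS} = \rmI\).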

\(\square\)
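The derivation can also be checked numerically for an arbitrary orthogonal basis. In this sketch (assuming NumPy), `G` is a random positive semi-definite stand-in for the GGN and `Q` a random orthogonal matrix from a QR factorisation. The squared error at the claimed minimiser equals the off-diagonal term, and perturbing the diagonal only makes it worse.

```python
import numpy as np

rng = np.random.default_rng(5)
P = 6
X = rng.standard_normal((50, P))
G = X.T @ X / 50                                  # stand-in PSD "GGN"
Q, _ = np.linalg.qr(rng.standard_normal((P, P)))  # arbitrary orthogonal basis

def sq_error(lam):
    """Squared Frobenius error ||G - Q diag(lam) Q^T||_F^2."""
    return np.linalg.norm(G - Q @ np.diag(lam) @ Q.T) ** 2

C = Q.T @ G @ Q
lam_star = np.diag(C).copy()                      # the claimed minimiser
off_diag = np.sum(C ** 2) - np.sum(np.diag(C) ** 2)

# At the optimum only the off-diagonal term remains.
print(np.isclose(sq_error(lam_star), off_diag))   # True

# Any perturbation of the diagonal strictly increases the error.
delta = rng.standard_normal(P)
print(sq_error(lam_star + 0.1 * delta) > sq_error(lam_star))  # True
```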