In machine learning, methods that use the second derivative (the Hessian) of some function approximator are less commonly used than those that use only first derivatives (the Jacobian).
This curvature information (and related quantities such as the Generalised Gauss-Newton or Fisher Information Matrix) has many useful applications: from faster or more ‘stable’ optimisation, to uncertainty quantification, and several others which we will touch on below.
Despite these appealing properties, the additional cost of calculating the curvature (or, indeed, the complexity incurred while implementing less costly schemes) has largely held back these second-order methods from broader adoption compared to first-order methods.
Nonetheless, there are several good approximations to the Hessian that can drastically reduce the computational and memory cost of working with these methods. In this article, we will focus on Kronecker-Factored Approximate Curvature (K-FAC) in particular (Martens et al., 2015), and I hope to convince you that, despite its apparent reputation, it isn’t a particularly complicated method to implement.
We will begin by motivating the use of the Hessian from the point of view of a simple gradient-based optimiser in Section 1, as well as proximal optimisation in Section 2. We then look at K-FAC in Section 3.
As a final introductory note, I must mention that much of what follows draws on material in Roger Grosse’s excellent set of notes on neural net training dynamics, which you should certainly refer to if you wish to go into more depth on any of the topics below.
Second-Order Methods and Optimisation
Consider the following, fairly standard machine learning setup: we have some data distribution of interest \(\P_{\text{data}}\), and a training dataset that we have sampled from it \(\gD = \{(\rvx_{i}, \rvt_{i})\}_{i=1}^{N}\), consisting of inputs \(\rvx\) and targets \(\rvt\) (where perhaps both the inputs and targets have been corrupted by observation noise).
We also pick a scalar loss function \(\ell(\rvy, \rvt)\) which tells us how unhappy we are with the prediction \(\rvy = f(\rvx; \vtheta)\) made by our parametric function approximator \(f\), which has \(n\) parameters \(\vtheta \in \R^{n}\). In other words, the loss measures the difference between the prediction and the target, for some appropriate distance metric (or likelihood, in the probabilistic view).
Our true objective is to minimise the generalisation loss or risk on data sampled from our distribution of interest (which may be non-stationary or drift over time), with respect to the network parameters:
\[\begin{equation} \label{eq:risk} \gR(\vtheta) = \E_{(\rvx, \rvt) \sim \P_{\text{data}}}\left[\ell\big(f(\rvx; \vtheta), \rvt\big)\right]. \end{equation} \]
However, given that we only have a finite data sample, we instead content ourselves with minimising the empirical risk \(\gJ\) on our training dataset:
\[\begin{equation} \label{eq:empirical_risk} \gJ(\vtheta; \gD) = \frac{1}{N}\sum^{N}_{i=1} \ell\big(f(\rvx_{i}; \vtheta), \rvt_{i}\big). \end{equation} \]
In practice, the vast majority of procedures used in machine learning to minimise objectives of the above form are based on first-order optimisation methods, such as stochastic gradient descent (SGD). These make use of the Jacobian (i.e. first derivative) of the objective, \(\nabla\gJ(\vtheta)\), to perform updates to the weights \(\vtheta\). (More realistically, they use a sequence of vector-Jacobian products, so that the full Jacobian matrix is never realised. See my previous article on the topic for more on this.)
Most of the common (better-performing) variants of SGD, such as Adam (Kingma et al., 2015), AdamW (Loshchilov et al., 2018) and numerous others, which have been responsible for most of the breakthrough advances in deep learning over the past decade (Krizhevsky et al., 2012; Silver et al., 2016), nevertheless remain first-order methods. (Note that these may be viewed within a framework of approximate second-order optimisation, as elucidated by the Bayesian learning rule of Khan et al. (2023).)
The popularity of these first-order optimisation algorithms in ML is no doubt due to their simple implementation, and low computational and memory cost. However, they are not without their shortcomings, which a move to second-order methods goes some way towards rectifying. To see this, we will begin by examining Taylor approximations of the objective function.
Taylor Approximations
One way of distinguishing between first- and second-order optimisation algorithms in machine learning might be to say that they are respectively making first- and second-order Taylor series approximations to our objective function.
Suppose that the current (initial) model parameters are \(\vtheta_{0}\). (On notation: we have dropped the \(\gD\) parameter from \(\gJ\) since the dataset should be clear from the context, and \(\nabla_{\vtheta}\gJ(\vtheta_{0})\) is shorthand for \(\frac{\partial \gJ(\vtheta)}{\partial \vtheta} \vert_{\vtheta = \vtheta_{0}}\).) We make a first-order Taylor series expansion around this point:
\[\begin{equation} \label{eq:first_order} \gJ^{(1)}(\vtheta) = \gJ(\vtheta_{0}) + \nabla_{\vtheta}\gJ(\vtheta_{0})^{\top}(\vtheta - \vtheta_{0}). \end{equation} \]
By stopping at the first-order term, this approximation assumes that the objective is linear around the point at which we have expanded, \(\vtheta_{0}\). Clearly, for most non-trivial functions (which certainly includes neural network objective functions), we don’t need to move very far away from \(\vtheta_{0}\) for this approximation to become inaccurate.
If we treat this approximation as locally correct however, within some small radius \(\eta\), we can optimise the first-order approximation in Equation \(\ref{eq:first_order}\) by moving \(\vtheta\) in the direction where the loss decreases most rapidly. Taking derivatives of Equation \(\ref{eq:first_order}\) wrt. \(\vtheta\), this is simply the direction of the negative gradient, yielding the familiar first-order update used in SGD:
\[\begin{equation} \label{eq:sgd_update} \vtheta_{t+1} = \vtheta_{t} - \eta \nabla_{\vtheta}\gJ(\vtheta_{t}). \end{equation} \]
The biggest problem with this is that from Equation \(\ref{eq:first_order}\) alone, we have no good way of telling how big the step size \(\eta\) should be—this depends on the curvature of \(\gJ(\vtheta)\) around the point at which we have expanded it. If the objective function is ‘flat’ (low curvature) in a given direction, then we ought to take a large step in that direction, while if the objective is ‘sharp’ (high curvature) in a particular direction then we ought to take a small step.
In one dimension, the issue of how big of a step size along the gradient direction to update \(\vtheta_{0}\) to its new value \(\vtheta_{1}\) might be illustrated as follows:
This is particularly problematic when the optimal value of \(\eta\) changes depending on which direction we move in (i.e. if one dimension is very steep and another is very shallow). (You might also wonder what would happen if the gradient pointed the ‘wrong way’, that is, if we got trapped in a local minimum. For first-order methods, ‘momentum’ is a simple modification to vanilla SGD which helps to solve this issue. This involves initialising a velocity vector \(v_{0} = 0\) and some momentum coefficient \(\mu \in [0, 1]\) in addition to the learning rate \(\eta\). Now, at each iteration, we first update the velocity \(v_{t+1} = \mu v_{t} + \eta \nabla_{\vtheta}\gJ(\vtheta_{t})\) before updating the parameters \(\vtheta_{t+1} = \vtheta_{t} - v_{t+1}\).)
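To make the step-size problem concrete, here is a minimal NumPy sketch. The toy quadratic, its curvatures (100 and 1) and the two learning rates are my own illustrative choices rather than anything from a real training run.

```python
import numpy as np

# Toy quadratic J(theta) = 0.5 * theta^T H theta:
# sharp along the first axis, flat along the second.
H = np.diag([100.0, 1.0])


def grad(theta):
    return H @ theta


def gradient_descent(eta, steps=50):
    theta = np.array([1.0, 1.0])
    for _ in range(steps):
        theta = theta - eta * grad(theta)
    return theta


# Slightly too large for the sharp direction (stability needs eta < 2 / 100),
# so the first coordinate grows in magnitude at every step.
theta_big = gradient_descent(eta=0.021)

# Stable, but the flat direction only shrinks by a factor (1 - 0.019) per step.
theta_small = gradient_descent(eta=0.019)

print("eta = 0.021:", theta_big)
print("eta = 0.019:", theta_small)
```

No single \(\eta\) suits both directions: the largest stable value is dictated by the sharpest direction, which leaves the flat direction crawling.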
To make progress, suppose that we now extend our Taylor approximation to include the second-order term, resulting in a quadratic approximation to the objective function:
\[\begin{equation} \label{eq:second_order} \gJ^{(2)}(\vtheta) = \gJ(\vtheta_{0}) + \nabla_{\vtheta}\gJ(\vtheta_{0})^{\top}(\vtheta - \vtheta_{0}) + \frac{1}{2}(\vtheta - \vtheta_{0})^{\top}\rmH(\vtheta - \vtheta_{0}), \end{equation} \]
where \(\rmH \doteq \nabla_{\vtheta}^{2}\gJ(\vtheta_{0})\) is the Hessian matrix containing second derivatives of our scalar objective function. Equation \(\ref{eq:second_order}\) now gives us a convex proxy objective / quadratic approximation to our objective function. By solving this proxy objective, we now know how big of a step to take to reach its minimum.
Note that optimising this quadratic proxy objective won’t necessarily get us to the minimum of our true objective. It does however give us a useful heuristic to decide how big of a step size to take—namely, we just go straight to the minimum of our quadratic proxy by solving it analytically.
Assuming for now that this proxy objective is good enough (and that it is indeed convex, i.e. \(\rmH \succeq \mathbf{0}\)), we proceed to optimise it by differentiating it. Using the notation \(\rmJ_{0} \doteq \nabla_{\vtheta}\gJ(\vtheta_{0})\) for brevity, we get
\[\begin{align} \nabla_{\vtheta}\gJ^{(2)}(\vtheta) &= \nabla_{\vtheta}\left[\gJ(\vtheta_{0})\right] + \nabla_{\vtheta}\left[\rmJ_{0}^{\top}(\vtheta - \vtheta_{0})\right] + \nabla_{\vtheta}\left[\frac{1}{2}(\vtheta - \vtheta_{0})^{\top}\rmH(\vtheta - \vtheta_{0})\right] \\ &= \nabla_{\vtheta}\left[\rmJ_{0}^{\top}\vtheta\right] + \nabla_{\vtheta}\left[\frac{1}{2}\vtheta^{\top}\rmH\vtheta\right] - \nabla_{\vtheta}\left[\frac{1}{2}\vtheta^{\top}\rmH\vtheta_{0}\right] - \nabla_{\vtheta}\left[\frac{1}{2}\vtheta_{0}^{\top}\rmH\vtheta\right] \\ &= \rmJ_{0} + \rmH\vtheta - \rmH\vtheta_{0} \\ &= \nabla_{\vtheta}\gJ(\vtheta_{0}) + \rmH(\vtheta - \vtheta_{0}), \label{eq:proxy_derivative} \end{align} \]
where in the second line we have expanded the brackets and dropped all terms that do not depend on \(\vtheta\) (including the constant \(\gJ(\vtheta_{0})\)), and in the third line we have used the symmetry of \(\rmH\).
Now, setting Equation \(\ref{eq:proxy_derivative}\) above to zero, and solving for \(\vtheta\) gives \(\vtheta = \vtheta_{0} - \rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta_{0})\). Re-writing the variables to give an iterative update rule, and including a step size parameter \(\eta\) for flexibility, we obtain the Newton-Raphson update rule.
Newton's Method
Newton’s iterative method of updating the weights proceeds as:
\[\begin{equation} \label{eq:newton_raphson} \vtheta_{t+1} = \vtheta_{t} - \eta\rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}) \end{equation} \]
The above is similar to the first-order update rule; however, the inverse Hessian \(\rmH^{-1}\) now acts as a pre-conditioner which re-scales the gradient in different directions, providing an appropriate scaling of the step size in each direction.
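As a quick sanity check of the pre-conditioning claim, the sketch below takes a single Newton step with \(\eta = 1\) on an exactly quadratic objective. The matrix `H` and vector `b` are made-up values; in this idealised setting the proxy and the objective coincide, so one step lands on the minimiser.

```python
import numpy as np

# Assumed toy objective: J(theta) = 0.5 * theta^T H theta - b^T theta, with H positive definite.
H = np.array([[100.0, 2.0],
              [2.0, 1.0]])
b = np.array([1.0, -2.0])


def grad(theta):
    return H @ theta - b


theta0 = np.array([3.0, 3.0])

# Newton step with eta = 1. We solve H d = grad instead of forming H^{-1} explicitly.
theta1 = theta0 - np.linalg.solve(H, grad(theta0))

# The true minimiser satisfies H theta = b.
theta_star = np.linalg.solve(H, b)

print(theta1, theta_star)
```

For a real network the objective is not quadratic, so we would only expect this behaviour locally.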
You might be wondering what happens if the assumption that \(\rmH \succeq 0\) we made above does not hold: namely, that the curvature of our objective is negative in one or more directions at the point around which we expand it (equivalently, the Hessian has some negative eigenvalues). Pictured (again, only in 1D), the quadratic proxy points in the wrong direction and becomes rather less useful:
Hence, the Newton update rule, as presented above, is not guaranteed to decrease the cost function at each iteration, nor is it guaranteed to converge efficiently, if at all. When the function is concave we might update the weights in the wrong direction. Further, when we Taylor-expand around a point with low curvature, we multiply the gradient by a very big term which may result in a very large step size, and cause training instabilities (in the example below, as a happy coincidence we end up close to the optimum; however, this isn’t always the case: the Hessian estimate may also be inaccurate, causing us to take a big step in the wrong direction, or we may completely overshoot the global optimum):
One simple way around these constraints is to assume that, while at an unfortunate iteration the objective might be non-convex or even have low positive-curvature, if we average the update over a sufficiently large number of past updates, it will tend to point in the right direction and not be too big. Hence by damping the updates, we can prevent them from moving too far from the current position, and hence ‘minimise the damage’ from a bad update at a given iteration.
This can be done by introducing a quadratic penalty from the previous update to our usual second-order Taylor series approximation of the objective (or, alternatively, placing a Gaussian prior on the previous weights with precision \(\eta\)):
\[\begin{align} \vtheta_{t+1} &= \argmin_{\vtheta} \gJ^{(2)}(\vtheta) + \frac{\eta}{2}\Vert \vtheta - \vtheta_{t}\Vert^{2} \\ &=\argmin_{\vtheta}\nabla_{\vtheta}\gJ(\vtheta_{t})^{\top}\vtheta + \frac{1}{2}(\vtheta - \vtheta_{t})^{\top}(\rmH + \eta \rmI)(\vtheta - \vtheta_{t}) \\ &= \vtheta_{t} - (\rmH + \eta\rmI)^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}). \end{align} \]
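The effect of the damping term is easy to see in one dimension. The sketch below uses a double-well objective of my own choosing, \(J(\theta) = \theta^{4}/4 - \theta^{2}/2\), and starts at a point where the curvature is negative; the three damping values are arbitrary illustrations.

```python
def J(theta):
    return 0.25 * theta**4 - 0.5 * theta**2


def grad(theta):
    return theta**3 - theta


def hess(theta):
    return 3.0 * theta**2 - 1.0


def damped_newton_step(theta, damping):
    # theta - (H + damping)^{-1} * gradient, in one dimension
    return theta - grad(theta) / (hess(theta) + damping)


theta0 = 0.3  # the curvature here is 3 * 0.09 - 1 = -0.73 < 0

theta_newton = damped_newton_step(theta0, damping=0.0)  # undamped
theta_weak = damped_newton_step(theta0, damping=0.2)    # H + damping still negative
theta_strong = damped_newton_step(theta0, damping=1.5)  # H + damping now positive

for name, theta in [("start", theta0), ("newton", theta_newton),
                    ("weak", theta_weak), ("strong", theta_strong)]:
    print(f"{name:>7}: theta = {theta:+.4f}, J = {J(theta):+.4f}")
```

Only the damping that outweighs the negative curvature gives a step that decreases \(J\); the undamped and weakly damped steps head towards the local maximum at \(\theta = 0\).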
While the above makes some progress towards fixing our woes, it still isn’t quite good enough. In areas where the approximate objective is non-convex (i.e. the Hessian has negative eigenvalues), it still makes little sense to proceed (often our second-order method can perform worse than a simple first-order method with momentum in getting out of local optima). Further, calculating the Hessian requires our network function to be twice-differentiable, which is not the case for many common architectures (e.g. ReLU networks). This motivates an alternative to the Hessian for calculating the curvature of our objective.
The Gauss-Newton Hessian
Denoting an individual network output as \(\rvy = f(\rvx; \vtheta)\) and the corresponding loss as \(\ell(\rvy, \rvt)\), we will use, without proof, the fact that the Hessian decomposes as:
\[\begin{equation} \label{eq:hessian_decomp} \nabla^{2}_{\vtheta}\gJ(\vtheta) = \rmJ^{\top}_{\rvy\vtheta}\rmH_{\rvy}\rmJ_{\rvy\vtheta} + \sum_{i}\frac{\partial \ell}{\partial \rvy_{i}}\nabla^{2}_{\vtheta}[f(\rvx; \vtheta)]_{i}, \end{equation} \]
where \(\rmH_{\rvy} = \nabla^{2}_{\rvy}\ell(\rvy, \rvt)\) is the output Hessian (i.e. the second derivative of just the loss function wrt. the network outputs \(\rvy\)) and \(\rmJ_{\rvy\vtheta}\) is the Jacobian \(\partial \rvy/\partial \vtheta\), or \(\nabla_{\vtheta}f(\rvx; \vtheta)\) using our previous notation. The first term in Equation \(\ref{eq:hessian_decomp}\) consists of a quadratic approximation to \(\ell\) and a linear approximation to \(f\), whereas the second term contains a linear approximation to \(\ell\), and a quadratic approximation to \(f\).
By simply dropping the second term, we get the Gauss-Newton Hessian, which we denote as \(\rmG\):
\[\begin{align} \rmG &= \rmJ_{\rvy\vtheta}^{\top}\rmH_{\rvy}\rmJ_{\rvy\vtheta} \label{eq:gnh} \\ &= \nabla_{\vtheta}f(\rvx; \vtheta)^{\top}\rmH_{\rvy}\nabla_{\vtheta}f(\rvx; \vtheta). \end{align} \]
One way to interpret the above is that we have first linearised the network around the current weights (i.e. using a first-order Taylor approximation),
\[f_{\text{lin}}(\rvx; \vtheta) = f(\rvx; \vtheta_{0}) + \nabla_{\vtheta}f(\rvx; \vtheta_{0})\big(\vtheta - \vtheta_{0}\big), \]
which leaves the following objective function \(\gJ_{\text{lin}}(\vtheta) = \frac{1}{N}\sum_{i=1}^{N}\ell(f_{\text{lin}}(\rvx_{i}; \vtheta), \rvt_{i})\). The Gauss-Newton Hessian can then be recovered as the Hessian of this new objective:
\[\begin{equation} \nabla_{\vtheta}^{2}\gJ_{\mathrm{lin}}(\vtheta) = \rmJ_{\rvy \vtheta}^{\top}\rmH_{\rvy}\rmJ_{\rvy \vtheta} = \rmG. \end{equation} \]
Swapping out \(\rmH\) for \(\rmG\) in our (damped) second-order weight update gives a modified update rule:
\[\begin{equation} \label{eq:damped_ggn_update} \vtheta_{t+1} = \vtheta_{t} - (\rmG + \eta\rmI)^{-1}\nabla_{\vtheta}\gJ(\vtheta_{t}). \end{equation} \]
This has now rectified the two problems that motivated the use of the Gauss-Newton Hessian:
- The positive semi-definiteness of \(\rmG\) now depends only on the convexity of our loss function \(\ell\), and no longer on the network function. Hence, as long as our loss / negative log-likelihood is convex in the network outputs, \(\rmH_{\rvy} \succeq \mathbf{0}\), and so \(\rmG \succeq \mathbf{0}\) and our proxy objective is convex. (With damping \(\eta > 0\), \(\rmG + \eta\rmI\) is then positive definite and hence invertible.)
- Unlike the calculation of the normal Hessian, \(\rmG\) requires only first-order derivatives of the network function and only the second-order derivatives of the loss function (usually the MSE or cross-entropy / likelihood equivalents; which are generally twice differentiable). Therefore we can apply it to things like ReLU networks without issues.
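Both points can be checked numerically. The following is a minimal sketch for a single input: a tiny one-hidden-layer ReLU network (layer sizes made up for illustration), a squared-error loss \(\frac{1}{2}\Vert\rvy - \rvt\Vert^{2}\) so that \(\rmH_{\rvy} = \rmI\), and a finite-difference Jacobian standing in for what an autodiff library would give you.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 3, 4, 2
shapes = [(n_hidden, n_in), (n_out, n_hidden)]
sizes = [int(np.prod(s)) for s in shapes]

theta = rng.normal(size=sum(sizes))  # all weights, flattened into one vector
x = rng.normal(size=n_in)


def f(theta, x):
    """One-hidden-layer ReLU network (no biases, to keep the sketch short)."""
    W1 = theta[:sizes[0]].reshape(shapes[0])
    W2 = theta[sizes[0]:].reshape(shapes[1])
    return W2 @ np.maximum(W1 @ x, 0.0)


def jacobian_fd(theta, x, eps=1e-6):
    """Central finite-difference Jacobian dy/dtheta, of shape (n_out, n_params)."""
    J = np.zeros((n_out, theta.size))
    for k in range(theta.size):
        e = np.zeros_like(theta)
        e[k] = eps
        J[:, k] = (f(theta + e, x) - f(theta - e, x)) / (2 * eps)
    return J


J = jacobian_fd(theta, x)
H_y = np.eye(n_out)  # Hessian of 0.5 * ||y - t||^2 with respect to y
G = J.T @ H_y @ J    # Gauss-Newton Hessian for this single input

eigvals = np.linalg.eigvalsh(G)
print("smallest eigenvalue:", eigvals.min())
print("rank of G:", np.linalg.matrix_rank(G, tol=1e-8))
```

Only first derivatives of the network were needed, and the eigenvalues of \(\rmG\) are non-negative up to floating-point error. For one input, \(\rmG\) has rank at most the output dimension, which is one reason the damping term matters in practice.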
Note that there are limitations to the Gauss-Newton Hessian; in particular it disregards the nonlinear modelling error matrix (NME), which may have implications for effective feature learning in neural networks. See the preliminary work of (Dauphin et al., 2024) for more on this point.
Fisher Information and Proximal Optimisation
The introduction of the Gauss-Newton Hessian above with the linearised network may have seemed a little ad-hoc—relying heavily on an identity conjured without proof. Let’s instead consider an alternative derivation of a closely related quantity, the Fisher information matrix, that is hopefully less reliant on definitions and more intuitive.
It turns out that for likelihoods belonging to the exponential family of distributions (which include common likelihoods such as the Gaussian and Categorical; corresponding to MSE and cross-entropy losses, respectively), the Gauss-Newton Hessian is equivalent to the Fisher information matrix. The Fisher arises in proximal optimisation, with the KL divergence as a regulariser.
Definition (Proximal Optimisation). Proximal optimisation refers to a general class of optimisation algorithms which minimise a cost function plus some proximity term \(\rho\), that penalises the distance from the current iterate. That is, for some cost function \(\gJ(\vtheta)\), the proximal update rule is
\[\begin{equation} \label{eq:general_proximal_update} \vtheta^{(t+1)} = \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \argmin_{\hat{\vtheta}}\left[\gJ(\hat{\vtheta}) + \lambda \rho(\hat{\vtheta}, \vtheta^{(t)})\right]. \end{equation} \]
\(\blacksquare\)
Different choices of dissimilarity function \(\rho\) lead to different update rules. For instance, using the squared Euclidean distance \(\rho(\hat{\vtheta}, \vtheta^{(t)}) = \frac{1}{2}\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{2}\), results in—after differentiating and solving for zero—the following optimal update rule:
\[\begin{equation} \label{eq:optimal_proximal_update} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \vtheta_{\star} = \vtheta^{(t)} - \lambda^{-1}\nabla_{\vtheta}\gJ(\vtheta_{\star}). \end{equation} \]
This looks a lot like the usual gradient descent update, however the gradient is computed at the new iterate, \(\vtheta_{\star}\). Clearly this update rule can’t be used directly, since \(\vtheta_{\star}\) appears on both sides.
Approximating the Proximal Update
To proceed, we will need some approximation. We will consider two below:
Linearisation
Suppose that we let \(\lambda \to \infty\), which has the effect of weighting the proximity term very heavily. This means that \(\vtheta_{\star}\) will remain very close to the current \(\vtheta^{(t)}\), and that \(\gJ\) will be well approximated by a first order Taylor approximation, leading to the following linearised objective
\[\begin{align} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) &= \argmin_{\hat\vtheta}\left[\gJ(\vtheta^{(t)}) + \nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}(\hat{\vtheta} - \vtheta^{(t)}) + \lambda \rho(\hat{\vtheta}, \vtheta^{(t)})\right] \\ &= \argmin_{\hat{\vtheta}}\left[\nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}\hat{\vtheta} + \lambda\rho(\hat{\vtheta}, \vtheta^{(t)})\right]. \end{align} \]
If we also approximate \(\rho\), this time using a second-order Taylor approximation (since when \(\rvu = \rvv\) in \(\rho(\rvu, \rvv)\), we have \(\nabla_{\rvu}\rho(\rvu, \rvv)\vert_{\rvu = \rvv} = \mathbf{0}\) hence losing the first-order term), then
\[\begin{equation} \label{eq:1} \rho(\hat{\vtheta}, \vtheta^{(t)}) = \frac{1}{2}(\hat{\vtheta} - \vtheta^{(t)})^{\top}\rmG(\hat{\vtheta} - \vtheta^{(t)}) + \gO(\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{3}), \end{equation} \]
where \(\rmG \doteq \nabla^{2}_{\hat{\vtheta}}\rho(\hat{\vtheta}, \vtheta^{(t)})\vert_{\hat{\vtheta} = \vtheta^{(t)}}\), the Hessian of the distance metric, is referred to as the metric matrix.
Interestingly, this looks rather similar to a Mahalanobis distance, which also bears similarity to the quadratic form in a log Gaussian density:
\[\begin{align} \rho(\hat{\vtheta}, \vtheta^{(t)}) &= \frac{1}{2}\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{2}_{\rmG} + \gO(\Vert \hat{\vtheta} - \vtheta^{(t)}\Vert^{3}), \end{align} \]
where \(\Vert \rvv \Vert_{\rmG} \doteq \sqrt{\rvv^{\top}\rmG\rvv}\).
Putting the linearised cost and second-order approximation of the distance metric into the general proximal update rule in Equation \(\ref{eq:general_proximal_update}\), we get
\[\begin{equation} \label{eq:3} \mathrm{prox}_{\gJ, \lambda}(\vtheta^{(t)}) = \argmin_{\hat{\vtheta}}\left[\nabla_{\vtheta}\gJ(\vtheta^{(t)})^{\top}\hat{\vtheta} + \frac{\lambda}{2}(\hat{\vtheta} - \vtheta^{(t)})^{\top}\rmG(\hat{\vtheta} - \vtheta^{(t)})\right], \end{equation} \]
and optimising suggests the following point-wise update rule:
\[\begin{equation} \label{eq:proximal_approx_1} \vtheta_{\star} = \vtheta^{(t)} - \lambda^{-1}\rmG^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)}). \end{equation} \]
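To spell out the optimisation step (a short worked derivation, assuming \(\rmG\) is symmetric and invertible): differentiating the bracketed objective in Equation \(\ref{eq:3}\) with respect to \(\hat{\vtheta}\) and setting the result to zero at \(\hat{\vtheta} = \vtheta_{\star}\) gives

\[\begin{align} \mathbf{0} &= \nabla_{\vtheta}\gJ(\vtheta^{(t)}) + \lambda\rmG(\vtheta_{\star} - \vtheta^{(t)}) \\ \implies \vtheta_{\star} - \vtheta^{(t)} &= -\lambda^{-1}\rmG^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)}), \end{align} \]

which rearranges to the update rule in Equation \(\ref{eq:proximal_approx_1}\).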
This update rule resembles the second-order Newton update \(\vtheta^{(t)} - \eta\rmH^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)})\), except that instead of using the Hessian of \(\gJ\), we use the Hessian of \(\rho\). When the distance metric is the Euclidean distance (i.e. a Gaussian prior in a log-joint objective), then \(\rmG = \rmI\), so this reduces to the ordinary (first-order) gradient descent update.
Second-Order Approximation
Had we chosen to not just linearise the cost but to take a second-order Taylor approximation, then the update rule, derived similarly to the above, would be
\[\begin{equation} \label{eq:proximal_approx_2} \vtheta^{(t+1)} = \vtheta^{(t)} - (\rmH + \lambda \rmG)^{-1}\nabla_{\vtheta}\gJ(\vtheta^{(t)}). \end{equation} \]
Once again, if \(\rho\) is the standard Euclidean distance, then we recover a damped Newton’s update with \(\rmG = \rmI\).
KL dissimilarity metric
Looking further afield, we can get more interesting results by replacing \(\rho\) with another dissimilarity measure. A common one is the KL divergence (which, strictly speaking, is not a metric, since it is not symmetric):
\[\begin{equation} \label{eq:6} \KL(\Q \Vert \P) = \E_{\rvx \sim \Q}\left[\log \Q(\rvx) - \log \P(\rvx)\right]. \end{equation} \]
In our two approximations (Equations \(\ref{eq:proximal_approx_1}\) and \(\ref{eq:proximal_approx_2}\)) of the idealised update in Equation \(\ref{eq:general_proximal_update}\), we made use of the Hessian of the distance metric. The Hessian of the KL divergence is the Fisher information matrix:
\[\begin{align} \nabla_{\rvu}^{2}\KL[\P_{\rvu} \Vert \P_{\vtheta}]\vert_{\rvu = \vtheta} &= \rmF_{\vtheta} \\ &= \Cov_{\rvx \sim \P_{\vtheta}}\big(\nabla_{\vtheta}\log \P_{\vtheta}(\rvx)\big) \\ &= \E_{\rvx \sim \P_{\vtheta}}\left[\big(\nabla_{\vtheta}\log\P_{\vtheta}(\rvx)\big)\big(\nabla_{\vtheta}\log\P_{\vtheta}(\rvx)\big)^{\top}\right]. \end{align} \]
Note that the last line holds since \(\Cov(\rvv) = \E[\rvv\rvv^{\top}] - \E[\rvv]\E[\rvv]^{\top}\) and \(\E_{\rvx \sim \P_{\vtheta}}\left[\nabla_{\vtheta}\log \P_{\vtheta}(\rvx)\right] = \mathbf{0}\): intuitively, for data drawn from \(\P_{\vtheta}\), the expected log likelihood of this data under \(\P_{\vtheta'}\) is maximised when \(\vtheta' = \vtheta\); hence the expected log-likelihood gradient at that point should be \(\mathbf{0}\).
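These identities are easy to check numerically for a small categorical distribution parametrised by its logits. The sketch below is only an illustration under that assumption: the logit values are arbitrary, and the Hessian of the KL is estimated with central finite differences rather than autodiff.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def kl(u, theta):
    """KL[P_u || P_theta] for categorical distributions given by logits."""
    p, q = softmax(u), softmax(theta)
    return np.sum(p * (np.log(p) - np.log(q)))


theta = np.array([0.5, -1.0, 2.0])
p = softmax(theta)
K = theta.size

# Score of outcome k: grad_theta log P_theta(k) = e_k - p. Row k of `scores` holds it.
scores = np.eye(K) - p
mean_score = p @ scores                    # E[score], expected to vanish
fisher = (scores * p[:, None]).T @ scores  # E[score score^T]

# Hessian of KL[P_u || P_theta] with respect to u at u = theta (central differences).
eps = 1e-3
I = np.eye(K)
hess = np.zeros((K, K))
for i in range(K):
    for j in range(K):
        hess[i, j] = (
            kl(theta + eps * (I[i] + I[j]), theta)
            - kl(theta + eps * (I[i] - I[j]), theta)
            - kl(theta - eps * (I[i] - I[j]), theta)
            + kl(theta - eps * (I[i] + I[j]), theta)
        ) / (4 * eps**2)

print(np.round(fisher, 4))
print(np.round(hess, 4))
```

For this parametrisation the Fisher has the closed form \(\mathrm{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^{\top}\), which both computations should reproduce.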
Using the KL in the proximal update gives:
\[\begin{equation} \label{eq:2} \mathrm{prox}_{\gJ, \lambda}(\vtheta) = \argmin_{\hat{\vtheta}}\left[\gJ(\hat{\vtheta}) + \lambda \KL[\P_{\hat{\vtheta}} \Vert \P_{\vtheta}]\right]. \end{equation} \]
In the limit of \(\lambda \to \infty\), this gives the following point-wise parameter update:
\[\begin{equation} \label{eq:kl_proximal_update} \vtheta^{(t+1)} = \vtheta^{(t)} - \eta \rmF^{-1}_{\vtheta^{(t)}}\nabla_{\vtheta}\gJ(\vtheta^{(t)}), \end{equation} \]
where \(\eta = \lambda^{-1}\), and pre-conditioning the gradient by the inverse Fisher gives a natural gradient update.
The key beneficial property of using the KL as our dissimilarity metric, rather than the squared Euclidean distance, is that the KL, which operates over distributions, not weights, doesn’t care about how the distributions are parametrised.
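A one-parameter example shows what this buys us. The sketch below takes a Bernoulli distribution parametrised either by its probability \(p\) or by its logit \(z\), and compares the first-order change in \(p\) induced by one step in each parametrisation. The target probability, starting point and step size are made-up values.

```python
import math

target = 0.9  # probability we would like the Bernoulli to assign to x = 1
p = 0.3       # current probability; the logit is z = log(p / (1 - p))
eta = 0.1

# Cross-entropy objective J = -(target * log p + (1 - target) * log(1 - p)).
grad_p = -target / p + (1 - target) / (1 - p)
dp_dz = p * (1 - p)      # derivative of the sigmoid
grad_z = grad_p * dp_dz  # chain rule

# Fisher information of a Bernoulli in each parametrisation.
fisher_p = 1.0 / (p * (1 - p))
fisher_z = p * (1 - p)

# First-order change in p induced by one step taken in each parametrisation.
gd_via_p = -eta * grad_p
gd_via_z = dp_dz * (-eta * grad_z)
ng_via_p = -eta * grad_p / fisher_p
ng_via_z = dp_dz * (-eta * grad_z / fisher_z)

print("gradient descent:", gd_via_p, gd_via_z)
print("natural gradient:", ng_via_p, ng_via_z)
print("natural gradient steps agree:", math.isclose(ng_via_p, ng_via_z))
```

The plain gradient steps move the distribution by different amounts depending on the parametrisation, while the natural gradient steps agree to first order.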
Kronecker-Factored Approximate Curvature
Having equations for the Gauss-Newton or Fisher information matrix is all well and good, however we still have to contend with the computational and memory costs mentioned in the introduction that beset second order methods. The full Hessian is an \(n \times n\) matrix where, recall, \(n\) was the total number of parameters \(\vtheta\) in our neural network. For even modestly sized networks, this quickly becomes prohibitive. We now turn to K-FAC to resolve this issue.
Before proceeding, let’s give some more detail about our hypothetical function approximator \(f_{\vtheta}\) and define some notation.
Suppose that \(f\) is a neural network with \(L\) layers, \(\ell = 1, \ldots, L\). The \(\ell\)th layer has input activations \(\rva_{\ell-1} \in \R^{N}\), weights \(\rmW_{\ell} \in \R^{M \times N}\), bias \(\rvb_{\ell} \in \R^{M}\) and (pre-activation) outputs \(\rvs_{\ell} \in \R^{M}\) and activation function \(\phi_{\ell}\). That is, the computation of layer \(\ell\) proceeds as follows:
\[\begin{align} \rvs_{\ell} &= \overline{\rmW}_{\ell}\bar{\rva}_{\ell - 1} \\ \rva_{\ell} &= \phi_{\ell}(\rvs_{\ell}). \end{align} \]
The notation used above is homogeneous vector notation, where we concatenate the weights and biases into the same matrix by padding the activations with a constant unit:
\[\begin{align} \overline{\rmW}_{\ell} &= \begin{bmatrix}\rmW_{\ell} & \rvb_{\ell}\end{bmatrix} \\ \bar{\rva}_{\ell - 1} &= \begin{bmatrix}\rva_{\ell - 1}^{\top} & 1\end{bmatrix}^{\top}. \end{align} \]
We will also use the following pseudo-gradient notation,
\[\begin{equation} \label{eq:pseudo_grad_notation} \gD \rvv = \nabla_{\rvv}\log \P(\rvy \vert \rvx; \vtheta). \end{equation} \]
In other words, \(\gD \rvv\) is a shorthand for the derivative of the log likelihood with respect to some vector \(\rvv\). For example, the pseudo-gradient for the homogeneous weight matrix \(\overline{\rmW}_{\ell}\) is
\[\gD \overline{\rmW}_{\ell} = \gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}. \]
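The homogeneous notation and the outer-product form of the weight pseudo-gradient can be checked on a single linear layer. This is a sketch under assumptions of my own: a unit-variance Gaussian likelihood on the pre-activations, so that \(\gD\rvs = \rvy - \rvs\), and arbitrary layer sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 3, 2  # input and output dimensions of the layer

W = rng.normal(size=(M, N))
b = rng.normal(size=M)
a = rng.normal(size=N)  # input activations
y = rng.normal(size=M)  # a sampled target

W_bar = np.hstack([W, b[:, None]])  # [W  b], shape (M, N + 1)
a_bar = np.append(a, 1.0)           # [a; 1], shape (N + 1,)


def log_lik(W_bar):
    """Gaussian log likelihood of y given s = W_bar a_bar, up to a constant."""
    s = W_bar @ a_bar
    return -0.5 * np.sum((y - s) ** 2)


s = W_bar @ a_bar
Ds = y - s                    # pseudo-gradient with respect to the pre-activations
DW_bar = np.outer(Ds, a_bar)  # Ds a_bar^T, shape (M, N + 1)

# Central finite-difference gradient of the log likelihood with respect to W_bar.
eps = 1e-6
fd = np.zeros_like(W_bar)
for i in range(M):
    for j in range(N + 1):
        E = np.zeros_like(W_bar)
        E[i, j] = eps
        fd[i, j] = (log_lik(W_bar + E) - log_lik(W_bar - E)) / (2 * eps)

print(np.max(np.abs(DW_bar - fd)))
```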
Finally, while we introduced the Gauss-Newton Hessian from the point of view of optimisation above, it will be simpler to work with the Fisher information matrix below. Fortunately, for losses such as MSE and cross-entropy, which correspond to likelihoods from the exponential family (Gaussian and Categorical, respectively), these two are equivalent. (The same does not hold for every loss: the MAE, or mean absolute error, which returns an estimate of the median and corresponds to a Laplace likelihood, is not covered, since the Laplace distribution with unknown location is not an exponential family. See this section of an earlier article for more on the MAE.) Starting from our definition of \(\rmG\) from the previous section:
\[\begin{align} \rmG &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\rmH_{\rvy}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\rmF_{\rvy}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\rvx \sim \gD}\left[\rmJ^{\top}_{\rvy\vtheta}\E_{\rvy \sim f_{\vtheta}(\rvx)}\left[\gD\rvy\gD\rvy^{\top}\right]\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\substack{\rvx \sim \gD\\ \rvy \sim f_{\vtheta}(\rvx)}}\left[\rmJ^{\top}_{\rvy\vtheta}\gD\rvy\gD\rvy^{\top}\rmJ_{\rvy\vtheta}\right] \\ &= \E_{\substack{\rvx \sim \gD\\ \rvy \sim f_{\vtheta}(\rvx)}}\left[\gD\vtheta\gD\vtheta^{\top}\right] \label{eq:fisher} \\ &\doteq \rmF. \end{align} \]
The second line holds due to Fisher’s identity, \(\E[\rmH_{\rvy}] = \rmF_{\rvy}\), where \(\rmF_{\rvy}\) is the output Fisher or uncentred covariance of the log-likelihood gradients. With slight abuse of notation, we use \(\rvy \sim f_{\vtheta}(\rvx)\) to denote sampling from the model’s output distribution (e.g. for classification tasks, sampling from categorical distribution parametrised by the logits output by the network at \(\rvx\), \(f_{\vtheta}(\rvx)\)). Recall, \(\gD\vtheta = \nabla_{\vtheta}\log \P(\rvy \vert \rvx; \vtheta)\).
K-FAC Assumptions
At a high level, K-FAC efficiently approximates \(\rmG\) by making two key approximations:
- We assume independence between the weights in different layers of the network, giving the approximated \(\rmG\) a block-diagonal structure with one block per layer, i.e. \(\rmG \approx \mathrm{diag}(\rmG_{11}, \ldots, \rmG_{LL})\).
- We further approximate each block \(\rmG_{\ell \ell}\) for \(\ell \in \{1, \ldots, L\}\) as a Kronecker product between two factors, \(\rmG_{\ell \ell} \approx \rmA_{\ell-1} \otimes \rmS_{\ell}\).
This factorisation rests on an independence assumption between the activations \(\bar{\rva}_{\ell-1}\) and output gradients \(\gD\rvs_{\ell}\) for each \(\ell\).
The block-diagonal approximation is easy enough to understand. Before looking at the Kronecker factorisation, however, and how we find the factors \(\rmA_{\ell-1}\) and \(\rmS_{\ell}\), let’s look at some properties of the Kronecker product that will be useful.
Kronecker Product Identities
A Kronecker product between two matrices \(\rmA \in \R^{m \times n}\) and \(\rmB \in \R^{o \times p}\) is the following \((mo \times np)\) sized matrix
\[\rmA \otimes \rmB = \begin{bmatrix}\erva_{11}\rmB & \cdots & \erva_{1n}\rmB \\ \vdots & \ddots & \vdots \\ \erva_{m1}\rmB & \cdots & \erva_{mn}\rmB\end{bmatrix}. \]
The first of two identities we will need relates to the similarity between the Kronecker product and the vec operator. Stated generally, this is
\[\begin{equation} \label{eq:kron_vec_ident} \text{vec}\big(\rmA\rmX\rmB\big) = \big(\rmB^{\top}\otimes \rmA\big)\text{vec}(\rmX). \end{equation} \]
The special case we will be using is for two vectors \(\rvu, \rvv\): taking \(\rmA = \rvu\), \(\rmB = \rvv^{\top}\) and \(\rmX = 1\) (the \(1 \times 1\) identity), it becomes
\[\begin{equation} \label{eq:kron_vec_ident_specialcase} \text{vec}(\rvu \rvv^{\top}) = \rvv \otimes \rvu. \end{equation} \]
The second, mixed product identity shows how a product of Kronecker products factorises:
\[\begin{equation} \label{eq:kron_prod_fact} \rmA\rmC \otimes \rmB\rmD = (\rmA \otimes \rmB)(\rmC \otimes \rmD). \end{equation} \]
Also note that \((\rmA \otimes \rmB)^{\top} = \rmA^{\top}\otimes \rmB^{\top}\), which is easy to check.
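All of these identities can be verified numerically in a few lines. In the sketch below, the matrix shapes are arbitrary and `vec` is a small helper of my own; the one detail to watch is that vec stacks columns, which corresponds to Fortran (column-major) order in NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)


def vec(X):
    """Stack the columns of X into a single vector (column-major order)."""
    return X.reshape(-1, order="F")


A = rng.normal(size=(2, 3))
X = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))

# vec(A X B) = (B^T kron A) vec(X)
lhs_vec = vec(A @ X @ B)
rhs_vec = np.kron(B.T, A) @ vec(X)

# Special case: vec(u v^T) = v kron u
u = rng.normal(size=3)
v = rng.normal(size=4)
lhs_outer = vec(np.outer(u, v))
rhs_outer = np.kron(v, u)

# Mixed product: (A C) kron (B D) = (A kron B)(C kron D)
C = rng.normal(size=(3, 2))
D = rng.normal(size=(5, 3))
lhs_mixed = np.kron(A @ C, B @ D)
rhs_mixed = np.kron(A, B) @ np.kron(C, D)

# Transpose: (A kron B)^T = A^T kron B^T
lhs_T = np.kron(A, B).T
rhs_T = np.kron(A.T, B.T)

print(np.allclose(lhs_vec, rhs_vec), np.allclose(lhs_outer, rhs_outer),
      np.allclose(lhs_mixed, rhs_mixed), np.allclose(lhs_T, rhs_T))
```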
Now, returning to our block diagonal structure, we will focus on a single block of \(\rmG\) corresponding to layer \(\ell\); \(\rmG_{\ell\ell}\). By Equation \(\ref{eq:fisher}\) (the definition of the GGN for exponential family likelihoods, i.e. the Fisher Information Matrix), we start from the uncentred covariance of the vectorised log-likelihood gradient:
\[\begin{align} \rmG_{\ell\ell} &= \E\left[\text{vec}\big(\gD\overline{\rmW}_{\ell}\big)\text{vec}\big(\gD\overline{\rmW}_{\ell}\big)^{\top}\right] \\ &= \E\left[\text{vec}\big(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\big)\text{vec}\big(\gD\rvs_{\ell}\bar{\rva}^{\top}_{\ell-1}\big)^{\top}\right] \\ &= \E\left[\big(\bar{\rva}_{\ell-1}\otimes \gD\rvs_{\ell}\big)(\bar{\rva}_{\ell-1}\otimes \gD\rvs_{\ell}\big)^{\top}\right] \\ &= \E\left[\bar{\rva}_{\ell - 1}\bar{\rva}_{\ell-1}^{\top} \otimes \gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right], \end{align} \]
where we have used our shorthand of \(\gD\overline{\rmW}_{\ell} = \nabla_{\overline{\rmW}_{\ell}}\log \P(\rvy\vert\rvx; \vtheta)\), \(\rvs_{\ell}\) are the pre-activations (we refer to \(\gD\rvs_{\ell}\) as the output gradients) and on the penultimate and final lines, we have applied the two Kronecker identities in Equation \(\ref{eq:kron_vec_ident_specialcase}\) and Equation \(\ref{eq:kron_prod_fact}\), respectively.
So far, we have not yet applied our second approximation; to bring the Kronecker product outside the expectation, we will now assume that the activations are independent from the output gradients. We now denote this approximate GGN as \(\hat{\rmG}\):
\[\begin{align} \hat{\rmG}_{\ell \ell} &= \E\left[\bar{\rva}_{\ell-1}\bar{\rva}^{\top}_{\ell-1}\right] \otimes \E\left[\gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right] \\ &\doteq \rmA_{\ell-1}\otimes \rmS_{\ell}, \end{align} \]
where the two Kronecker factors are the uncentred covariance matrices of the activations and pre-activation pseudo-gradients:
\[\begin{align} \rmA_{\ell} &= \E\left[\bar{\rva}_{\ell}\bar{\rva}_{\ell}^{\top}\right] \\ &= \begin{bmatrix}\E[\rva_{\ell}\rva_{\ell}^{\top}] & \E[\rva_{\ell}] \\ \E[\rva_{\ell}]^{\top} & 1\end{bmatrix}, \\ \rmS_{\ell} &= \E\left[\gD\rvs_{\ell}\gD\rvs_{\ell}^{\top}\right], \end{align} \]
with the expectation taken over empirical input samples from our dataset \(\rvx \sim \gD\) and outputs sampled from the model’s output distribution \(\rvy \sim f_{\vtheta}(\rvx)\).
If layer \(\ell\) has \(N\)-dimensional input activations, and returns \(M\)-dimensional outputs, then the size of \(\hat{\rmG}_{\ell\ell}\) is \((N+1)M \times (N+1)M\), with \(\rmA_{\ell-1} \in \R^{(N+1) \times (N+1)}\) and \(\rmS_{\ell} \in \R^{M\times M}\).
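To make the shapes concrete, here is a minimal NumPy sketch that estimates the two Kronecker factors for a single layer by Monte Carlo and forms both the exact block and its K-FAC approximation. The activations and output gradients are synthetic stand-ins (not taken from a real network), and the sizes `n_samples`, `N`, `M` are hypothetical; the output gradients are deliberately made to depend on the activations so that the independence assumption is only approximately true.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, N, M = 512, 3, 2  # hypothetical sample count and layer sizes

# Synthetic stand-ins for the layer's input activations and output gradients.
a = rng.normal(size=(n_samples, N))
a_bar = np.concatenate([a, np.ones((n_samples, 1))], axis=1)  # homogeneous coordinate
Ds = rng.normal(size=(n_samples, M)) * (1.0 + 0.5 * np.tanh(a[:, :1]))

# Monte Carlo estimates of the two Kronecker factors.
A = a_bar.T @ a_bar / n_samples  # (N+1, N+1)
S = Ds.T @ Ds / n_samples        # (M, M)
G_kfac = np.kron(A, S)           # ((N+1)M, (N+1)M)

# Exact block: average of (a_bar (x) Ds)(a_bar (x) Ds)^T over samples.
# Row n of g equals np.kron(a_bar[n], Ds[n]).
g = np.einsum("ni,nj->nij", a_bar, Ds).reshape(n_samples, -1)
G_exact = g.T @ g / n_samples

rel_err = np.linalg.norm(G_exact - G_kfac) / np.linalg.norm(G_exact)
print(A.shape, S.shape, G_kfac.shape, rel_err)
```

In practice one would never materialise `G_kfac`; storing `A` and `S` costs \((N+1)^{2} + M^{2}\) numbers rather than \((N+1)^{2}M^{2}\), which is the point of the factorisation.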
Eigenvalue-Corrected K-FAC
George et al. (2018) introduce a modification to K-FAC that is provably a better approximation of the GGN while remaining cheap to compute.
The modification consists of exploiting the fact that eigendecompositions distribute over Kronecker products. That is, if we consider two factors \(\rmA\) and \(\rmS\) (we drop the \(\ell\) subscripts to reduce clutter), which have eigendecompositions \(\rmQ_{\rmA}\mLambda_{\rmA}\rmQ_{\rmA}^{\top}\) and \(\rmQ_{\rmS}\mLambda_{\rmS}\rmQ_{\rmS}^{\top}\) (recall that our factors are uncentred covariance matrices, and hence real and symmetric, which gives rise to this form of the eigendecomposition), then applying the mixed product identity twice, the eigendecomposition of their Kronecker product distributes as follows:
\[\begin{align} \label{eq:kron_eigen} \rmA \otimes \rmS &= \rmQ_{\rmA}\mLambda_{\rmA}\rmQ_{\rmA}^{\top} \otimes \rmQ_{\rmS}\mLambda_{\rmS}\rmQ_{\rmS}^{\top} \\ &= \big(\rmQ_{\rmA}\otimes \rmQ_{\rmS}\big)\big(\mLambda_{\rmA}\otimes\mLambda_{\rmS}\big)\big(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\big)^{\top}. \end{align} \]
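As a quick sanity check of Equation \(\ref{eq:kron_eigen}\), the following NumPy sketch builds two random symmetric positive semi-definite matrices as stand-ins for the factors (the helper `random_psd` and the sizes are illustrative assumptions) and confirms that the Kronecker product of the eigenvector matrices diagonalises \(\rmA \otimes \rmS\).

```python
import numpy as np

rng = np.random.default_rng(0)


def random_psd(n):
    # Hypothetical helper: a random symmetric PSD matrix standing in for a factor.
    X = rng.normal(size=(n, 2 * n))
    return X @ X.T / (2 * n)


A, S = random_psd(4), random_psd(3)

# eigh is appropriate because both factors are real and symmetric.
lam_A, Q_A = np.linalg.eigh(A)
lam_S, Q_S = np.linalg.eigh(S)

Q = np.kron(Q_A, Q_S)        # eigenvectors of A (x) S
lam = np.kron(lam_A, lam_S)  # diagonal of Lambda_A (x) Lambda_S

# Q diag(lam) Q^T, written as a column scaling to avoid building diag(lam).
reconstruction = (Q * lam) @ Q.T

print(np.allclose(reconstruction, np.kron(A, S)))
```

Only the two small eigendecompositions are ever computed; the eigendecomposition of the large matrix comes for free.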
Each of the three terms above is an \((N+1)M \times (N+1)M\) matrix; however, the middle term is diagonal, and hence has only \((N+1)M\) nonzero elements. We can thus afford to store this matrix without using a Kronecker approximation. This is the observation behind the Eigenvalue-Corrected K-FAC of George et al. (2018), who approximate the GGN (for a layer \(\ell\)) by replacing the diagonal matrix \(\mLambda_{\rmA}\otimes\mLambda_{\rmS}\) with a general diagonal matrix \(\mLambda\),
\[\begin{equation} \label{eq:ekfac} \rmG \approx \big(\rmQ_{\rmA}\otimes \rmQ_{\rmS}\big)\mLambda\big(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\big)^{\top}, \end{equation} \]
where \(\mLambda\) is chosen to minimise the approximation error between \(\rmG\) and this approximation. Its entries are defined as:
\[\begin{equation} \label{eq:ekfac_diag} \mLambda_{ii} \doteq \E\left[\big((\rmQ_{\rmA}\otimes\rmQ_{\rmS})^{\top}\gD\vtheta\big)^{2}_{i}\right] \end{equation} \]
This captures the second moments (uncentred variances) of the pseudo-gradient \(\gD\vtheta\) projected onto each eigenvector of the K-FAC approximation, and improves the quality of the approximation.
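The projection in Equation \(\ref{eq:ekfac_diag}\) never requires forming \(\rmQ_{\rmA}\otimes\rmQ_{\rmS}\) explicitly, since by the mixed product identity \((\rmQ_{\rmA}\otimes\rmQ_{\rmS})^{\top}(\bar{\rva}\otimes\gD\rvs) = (\rmQ_{\rmA}^{\top}\bar{\rva})\otimes(\rmQ_{\rmS}^{\top}\gD\rvs)\). Below is a minimal NumPy sketch of this, using synthetic activations and output gradients as stand-ins for a real layer (all sizes and variable names are hypothetical); the dense matrices are built only to check the result on a toy problem.

```python
import numpy as np

rng = np.random.default_rng(0)
n, N, M = 2048, 3, 2  # hypothetical sample count and layer sizes

# Synthetic per-example activations (with homogeneous coordinate) and output gradients.
a = rng.normal(size=(n, N))
a_bar = np.concatenate([a, np.ones((n, 1))], axis=1)
Ds = rng.normal(size=(n, M)) * (1.0 + 0.5 * np.tanh(a[:, :1]))

# K-FAC factors and their eigenvectors.
A = a_bar.T @ a_bar / n
S = Ds.T @ Ds / n
_, Q_A = np.linalg.eigh(A)
_, Q_S = np.linalg.eigh(S)

# Project each per-example gradient into the Kronecker-factored eigenbasis
# without forming Q_A (x) Q_S: rows of proj are (Q_A^T a_bar) (x) (Q_S^T Ds).
proj = np.einsum("ni,nj->nij", a_bar @ Q_A, Ds @ Q_S).reshape(n, -1)
lam_ekfac = np.mean(proj**2, axis=0)

# Dense reference, feasible only because the toy problem is tiny.
Q = np.kron(Q_A, Q_S)
g = np.einsum("ni,nj->nij", a_bar, Ds).reshape(n, -1)
G = g.T @ g / n
lam_dense = np.diag(Q.T @ G @ Q)

err_kfac = np.linalg.norm(G - np.kron(A, S))
err_ekfac = np.linalg.norm(G - (Q * lam_ekfac) @ Q.T)
print(err_kfac, err_ekfac)
```

Because K-FAC and its eigenvalue-corrected variant share the same eigenbasis, and the corrected diagonal is the Frobenius-optimal choice in that basis, `err_ekfac` should never exceed `err_kfac` on the same samples.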
Derivation
The diagonal entries in Equation \(\ref{eq:ekfac_diag}\) above are derived by minimising the Frobenius norm of the difference between the GGN \(\rmG\) and our approximation. Let the eigendecomposition of our approximation be \(\hat{\rmG} = \rmQ\mLambda\rmQ^{\top}\), with \(\rmQ\) orthogonal. Then, the error is \(e = \Vert \rmG - \rmQ\mLambda\rmQ^{\top}\Vert_{F}\).
Working from the squared error, and recalling that \(\rmQ^{\top}\rmQ = \rmI\), so that the Frobenius norm is unchanged by multiplying with \(\rmQ^{\top}\) on the left and \(\rmQ\) on the right:
\[\begin{align} e^{2} &= \Vert \rmG - \rmQ\mLambda\rmQ^{\top}\Vert^{2}_{F} \\ &= \Vert \rmQ^{\top}\big(\rmG - \rmQ\mLambda\rmQ^{\top}\big)\rmQ\Vert_{F}^{2} \\ &= \Vert \rmQ^{\top}\rmG\rmQ - \mLambda\Vert_{F}^{2} \\ &= \underbrace{\sum_{i}\big(\rmQ^{\top}\rmG\rmQ - \mLambda\big)^{2}_{ii}}_{\text{diagonal}} + \underbrace{\sum_{i}\sum_{j\ne i}\big(\rmQ^{\top}\rmG\rmQ\big)^{2}_{ij}}_{\text{off-diagonal}}, \end{align} \]
where in the last line we omit \(\mLambda\) from the second, off-diagonal term since \(\mLambda\) is diagonal and so contributes nothing to that sum.
Now, we are interested in finding the diagonal \(\mLambda\) which minimises \(e\). The off-diagonal term does not depend on \(\mLambda\), while the diagonal term is a sum of squares that vanishes exactly when we set \(\mLambda_{ii} = (\rmQ^{\top}\rmG\rmQ)_{ii}\). Expanding definitions:
\[\begin{align} \mLambda_{ii} &= (\rmQ^{\top}\rmG\rmQ)_{ii} \\ &= \big(\rmQ^{\top}\E\left[\gD\vtheta\gD\vtheta^{\top}\right]\rmQ\big)_{ii} \\ &= \big(\E\left[\rmQ^{\top}\gD\vtheta(\rmQ^{\top}\gD\vtheta)^{\top}\right]\big)_{ii} \\ &= \E\left[\big(\rmQ^{\top}\gD\vtheta\big)^{2}_{i}\right]. \end{align} \]
In the Kronecker-factored setting, we merely substitute \((\rmQ_{\rmA} \otimes \rmQ_{\rmS})\) for \(\rmQ\).
\(\square\)
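The result can also be checked numerically for an arbitrary orthogonal basis, not just a Kronecker-factored one. The sketch below is a toy check under stated assumptions: a random symmetric matrix stands in for \(\rmG\), a random orthogonal \(\rmQ\) is drawn via a QR decomposition, and the helper `error` is a hypothetical name. It confirms that the minimal error equals the norm of the off-diagonal part of \(\rmQ^{\top}\rmG\rmQ\), and that random perturbations of the diagonal only increase the error.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6

# Stand-in symmetric PSD matrix playing the role of the GGN.
X = rng.normal(size=(d, 3 * d))
G = X @ X.T / (3 * d)

# An arbitrary orthogonal basis (the Q factor of a random matrix).
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))


def error(lam):
    # Frobenius norm of G - Q diag(lam) Q^T.
    return np.linalg.norm(G - (Q * lam) @ Q.T)


R = Q.T @ G @ Q
lam_star = np.diag(R)  # the claimed minimiser
off_diag_norm = np.linalg.norm(R - np.diag(lam_star))

best = error(lam_star)
perturbed = [error(lam_star + 0.1 * rng.normal(size=d)) for _ in range(100)]

print(np.isclose(best, off_diag_norm), min(perturbed) >= best)
```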