

Of VJPs and JVPs

Some intuitions and visualisations of vector-Jacobian products and Jacobian-vector products, to help you avoid confusing the two again.

November 5, 2023

London, UK


Suppose that we have a function \(f: \R^{n} \to \R^{m}\), which maps an \(n\)-dimensional input \(\rvx \in \R^{n}\) to an \(m\)-dimensional output \(\rvy \in \R^{m}\). [Sidenote: In machine learning, we often work with parametric functions \(f(\rvx; \vtheta)\), in which case we might be more interested in differentiating with respect to the parameters \(\vtheta\) than the data \(\rvx\). For clarity, we will stick to the first case for now; in a later section we will see how to recover a function of \(\vtheta\) alone by essentially averaging out the data term. For now, you could think of ‘partially applying’ an input data point, resulting in a new function \(f_{\rvx}: \R^{n} \to \R^{m}\), where we have assumed \(\vtheta \in \R^{n}\).]

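One way to view this function is as a column vector of \(m\) scalar-valued functions \(f_{i}: \R^{n} \to \R\), stacked one on top of the other:

\[f(\rvx) = \begin{bmatrix} f_{1}(\rvx) \\ f_{2}(\rvx) \\ \vdots \\ f_{m}(\rvx) \end{bmatrix}. \]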

The Jacobian matrix \(\rmJ\) of \(f\) is an \(m \times n\) matrix whose \(i\)th row contains the gradient of the \(i\)th ‘scalar function’ with respect to the inputs \(\rvx\), \(\nabla_{\rvx}f_{i}(\rvx)\):

\[\rmJ = \begin{bmatrix} \nabla_{\rvx}f_{1}(\rvx) \\ \nabla_{\rvx}f_{2}(\rvx) \\ \vdots \\ \nabla_{\rvx}f_{m}(\rvx) \end{bmatrix} = \begin{bmatrix} \frac{\partial \ervy_{1}}{\partial \ervx_{1}} & \frac{\partial \ervy_{1}}{\partial \ervx_{2}} & \cdots & \frac{\partial \ervy_{1}}{\partial \ervx_{n}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \ervy_{m}}{\partial \ervx_{1}} & \frac{\partial \ervy_{m}}{\partial \ervx_{2}} & \cdots & \frac{\partial \ervy_{m}}{\partial \ervx_{n}} \end{bmatrix}. \]

In other words, the \((i, j)\)th entry of \(\rmJ\) is the partial derivative of the \(i\)th output with respect to the \(j\)th input, \(\partial \ervy_{i} / \partial \ervx_{j}\).

At the risk of stating the obvious, the vector-Jacobian product (VJP) is the left-multiplication of the Jacobian by some vector, while the Jacobian-vector product (JVP) is the right-multiplication of the Jacobian by an appropriately shaped vector. Here are the definitions in a bit more detail:

Vector-Jacobian product (VJP)

Given a vector \(\rvu \in \R^{m}\), the VJP is the following row vector:

\[\begin{equation}\label{eq:vjp} \rvu^{\top}\rmJ \in \R^{1 \times n}. \end{equation} \]

Since it is \(n\)-dimensional, the VJP has one element for each of the function inputs \(\rvx\) (it is an ‘input space’ concept).

It tells us “in what direction should each of the inputs change in order to get (as close as possible to) a change of \(\rvu\) in the outputs?” [Sidenote: This caveat, which I thank Arthur Conmy for pointing out, is because there may be directions in the output space that cannot be influenced by any input changes. For instance, our function \(f\) might be singular, and hence \(\rmJ\) not full-rank, as is the case with rank-deficient input features in regression, or the decoder part of an autoencoder’s bottleneck. Alternatively, \(f\) may not be surjective onto \(\R^{m}\), as can occur with saturated activations or pooling layers in CNNs. In either case, achieving an exact change of \(\rvu\) in the outputs may be impossible, and we might have to settle for a change that is merely close to \(\rvu\) (e.g. in a least-squares sense).]

We might think of it as a ‘sensitivity map’ over the inputs: if I want to increase the first output element \(\rvy_{1}\) by \(0.5\) (by setting \(\rvu_{1} = 0.5\), and \(\rvu_{i\ne 1} = 0\)), then the resulting \(n\)-dimensional VJP will tell me how I ought to perturb \(\rvx\).

This is what we do during reverse-mode automatic differentiation.

Similarly, for the JVP:

Jacobian-vector product (JVP)

Given a vector \(\rvv \in \R^{n}\), the JVP is the following column vector:

\[\begin{equation}\label{eq:jvp} \rmJ\rvv \in \R^{m}. \end{equation} \]

Since it is \(m\)-dimensional, the JVP has one element for each of the function outputs \(\rvy\) (it is an ‘output space’ concept).

It tells us “in what direction do the outputs of \(f\) change if I make a perturbation \(\rvv\) to the inputs?”

We might think of it as a directional derivative of \(f\) in the direction of \(\rvv\): if I perturb the first input element \(\rvx_{1}\) by \(0.5\) (by setting \(\rvv_{1} = 0.5\), and \(\rvv_{i\ne 1} = 0\)), then the resulting \(m\)-dimensional JVP will tell me how much the output \(\rvy\) will change.

This is what we do during forward-mode automatic differentiation.

Note that the JVP is exactly the linear term of a first-order Taylor approximation to the function \(f\). If \(f\) is differentiable at some point \(\rvx_{0}\), and \(\rmJ\) is its Jacobian evaluated at \(\rvx_{0}\), then we can write:

\[f(\rvx) = f(\rvx_{0}) + \rmJ(\rvx - \rvx_{0}) + o(\Vert \rvx - \rvx_{0}\Vert), \]

where the little-o notation denotes a remainder that goes to \(0\) faster than \(\Vert \rvx - \rvx_{0}\Vert\) as \(\rvx \to \rvx_{0}\). We often refer to the approximation \(f(\rvx_{0}) + \rmJ(\rvx - \rvx_{0})\) as the linearisation of \(f\) at \(\rvx_{0}\).
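To see the little-o behaviour numerically, here is a minimal sketch assuming PyTorch 2.x with the `torch.func` module. It borrows, for concreteness, the small \(\R^{3} \to \R^{2}\) example function that appears later in this post, computes \(f(\rvx_{0})\) and \(\rmJ\rvv\) in one call to `torch.func.jvp`, and compares the linearisation \(f(\rvx_{0}) + h\rmJ\rvv\) against \(f(\rvx_{0} + h\rvv)\) for shrinking step sizes \(h\). The point \(\rvx_{0}\) and direction \(\rvv\) are arbitrary illustrative choices.

```python
import torch
from torch.func import jvp

def f(x):
    # Example function R^3 -> R^2: [x1^2 + x2 + x3, x1 - x2 + x3]
    return torch.stack([x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]])

# float64 so that rounding error does not swamp the tiny remainders
x0 = torch.tensor([2.0, 3.0, 1.0], dtype=torch.float64)
v = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
f_x0, J_v = jvp(f, (x0,), (v,))  # f(x0) and J v, with J evaluated at x0

errors = []
for h in (1e-1, 1e-2, 1e-3):
    exact = f(x0 + h * v)
    linear = f_x0 + h * J_v  # first-order Taylor approximation at x0 + h v
    err = (exact - linear).norm().item()
    errors.append(err)
    print(f"h={h:g}  error={err:.2e}  error/h={err / h:.2e}")
```

The ratio `error/h` shrinks towards zero along with \(h\), which is what the \(o(\Vert \rvx - \rvx_{0}\Vert)\) term promises. For this particular \(f\), the remainder is exactly \(0.25h^{2}\), coming entirely from the \(\ervx_{1}^{2}\) term.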

Mnemonic

One simple way to remember which is which is to consider a simple linear transformation \(f: \R^{n} \to \R^{m}\) defined as

\[f(\rvx) = \rmW\rvx, \]

the Jacobian is simply \(\rmJ = \rmW \in \R^{m\times n}\). Hence,

  • VJP: left-multiplying \(\rmJ\) or \(\rmW\) by a vector must correspond to an output-space perturbation, and hence returns the “sensitivity map” over the inputs.
  • JVP: right-multiplying \(\rmJ\) or \(\rmW\) by a vector must correspond to an input-space perturbation, and gives us the directional derivative of \(f\).
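The mnemonic can be checked directly in code. The sketch below (assuming PyTorch 2.x with `torch.func`; the random \(\rmW\), \(\rvx\), \(\rvu\), \(\rvv\) are illustrative) confirms that for \(f(\rvx) = \rmW\rvx\) the VJP equals \(\rvu^{\top}\rmW\) and lives in input space, while the JVP equals \(\rmW\rvv\) and lives in output space.

```python
import torch
from torch.func import jvp, vjp

torch.manual_seed(0)
m, n = 2, 3
W = torch.randn(m, n)

def f(x):
    return W @ x  # linear map R^n -> R^m, so its Jacobian is W

x = torch.randn(n)
u = torch.randn(m)  # output-space vector, for the VJP
v = torch.randn(n)  # input-space vector, for the JVP

_, vjp_fn = vjp(f, x)
(u_J,) = vjp_fn(u)           # shape (n,): an input-space object
_, J_v = jvp(f, (x,), (v,))  # shape (m,): an output-space object

print(u_J.shape, J_v.shape)
print(torch.allclose(u_J, u @ W), torch.allclose(J_v, W @ v))
```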

Visual Mnemonic

Here is a rote mnemonic you can apply in times of confusion: just look at whether the ‘V’ in the acronym comes before (input) or after (output) the Jacobian’s ‘J’.

  • VJP: This is the \(n\)-dimensional, input-space object: the sensitivity map telling us how the input affects the output.
  • JVP: This is the \(m\)-dimensional, output-space object: the directional derivative telling us how the function output changes when the inputs are perturbed.

Let’s consider a very simple function with \(n=3\) inputs and \(m=2\) outputs to illustrate what’s going on:

\[f(\rvx) \doteq \begin{bmatrix}\ervx_{1}^{2} + \ervx_{2} + \ervx_{3} \\ \ervx_{1} - \ervx_{2} + \ervx_{3} \end{bmatrix}. \]

The Jacobian \(\nabla_{\rvx}f(\rvx)\vert_{\rvx_1}\), evaluated at some example input point \(\rvx_{1}^{\top} = \begin{bmatrix}2 & 3 & 1\end{bmatrix}\), is the following \(2\times 3\) matrix:

\[\rmJ = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix}. \]
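As a quick sanity check, we can let an autodiff library compute this Jacobian. A minimal sketch, assuming PyTorch 2.x: `torch.func.jacrev` assembles the full \(m \times n\) Jacobian from reverse-mode VJPs (one per output, vectorised internally), which is fine for a tiny function like this one.

```python
import torch
from torch.func import jacrev

def f(x):
    # The example function: f(x) = [x1^2 + x2 + x3, x1 - x2 + x3]
    return torch.stack([x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]])

x1 = torch.tensor([2.0, 3.0, 1.0])  # the example input point
J = jacrev(f)(x1)                   # shape (m, n) = (2, 3)
print(J)                            # should match the 2 x 3 Jacobian above
```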

For a VJP \(\rvu^{\top}\rmJ\), we will require an \(m\)-dimensional vector \(\rvu\). Consider, to start, what happens if \(\rvu\) is a one-hot vector; we get

\[\rvu^{\top}\rmJ = \begin{bmatrix}1 & 0\end{bmatrix}\begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix}4 & 1 & 1 \end{bmatrix}. \]

In other words, with some element \(i \in \{1, 2, \ldots, m\}\) of \(\rvu\) set to \(1\), and all others set to \(0\), we have selected the \(i\)th row of the Jacobian, which is exactly the gradient of the ‘scalar function’ \(f_{i}(\rvx)\). Recall, this is an \(n\)-dimensional vector, which tells us how the \(i\)th output component of \(f\) depends on each of the \(n\) input components: it describes the sensitivity of this \(i\)th output component to infinitesimal changes in each input component.

Now, if each of the \(m\) elements of \(\rvu\) were set to \(1\), we would pay equal attention to the sensitivity of all the output components to changes in the input—doing an element-wise sum over the rows of the Jacobian:

\[\rvu^{\top}\rmJ = \begin{bmatrix}1 & 1\end{bmatrix}\begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix}5 & 0 & 2 \end{bmatrix}. \]

Depending on the importance of each output element, we could also re-weight each row (i.e. gradient vector) differently, giving us a weighted ‘sensitivity map’ over the inputs. [Sidenote: If the elements of \(\rvu\) don’t sum to \(1\), then this simply scales the VJP magnitude up or down, and it is trivial to normalise this again.]

\[\rvu^{\top}\rmJ = \begin{bmatrix}0.5 & 1\end{bmatrix}\begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix}3 & -0.5 & 1.5 \end{bmatrix}. \]
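The three VJPs above can be reproduced with `torch.func.vjp`, assuming PyTorch 2.x. A sketch: `vjp` runs the function once and returns a closure that maps any output-space vector \(\rvu\) to \(\rvu^{\top}\rmJ\) via reverse-mode differentiation, so the Jacobian itself is never materialised.

```python
import torch
from torch.func import vjp

def f(x):
    # Example function R^3 -> R^2
    return torch.stack([x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]])

x1 = torch.tensor([2.0, 3.0, 1.0])
y, vjp_fn = vjp(f, x1)  # y = f(x1); vjp_fn maps u (shape (m,)) to u^T J (shape (n,))

results = {}
for u in ([1.0, 0.0], [1.0, 1.0], [0.5, 1.0]):
    (u_J,) = vjp_fn(torch.tensor(u))  # one cotangent per primal input
    results[tuple(u)] = u_J
    print(u, u_J)
```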

For the JVP \(\rmJ\rvv\), we need an \(n\)-dimensional perturbation vector, which gives the direction along which we wish to calculate the derivative of \(f\). [Sidenote: This perturbation is applied to either the inputs \(\rvx\) or the weights / parameters \(\vtheta\), in the alternative perspective set out in the sidenote in the introduction.]

First, let’s once again consider a one-hot vector. Setting the first dimension to \(1\), we get

\[\rmJ\rvv = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix}\begin{bmatrix}1 \\ 0 \\ 0\end{bmatrix} = \begin{bmatrix}4 \\ 1\end{bmatrix}. \]

Generalising: by setting element \(j \in \{1, 2, \ldots, n\}\) of \(\rvv\) to \(1\), we select the \(j\)th column of the Jacobian, which is the derivative of \(f\) along the \(j\)th standard basis vector \(\rve_{j}\). In the example above, \(\rvv = \rve_{1}\): if we had perturbed the first input value \(\rvx_{1}\) [Sidenote: analogously, the first weight element \(\vtheta_{1}\)] by \(1\), the JVP tells us (to first order) how the function output would change in response.

If we instead set \(\rvv\) to be a vector of all \(1\)s, we perturb every input by one unit and obtain the derivative of \(f\) in that direction, which corresponds to summing up all the columns of the Jacobian [Sidenote: hence the result is \(n\) times the average column of \(\rmJ\)]:

\[\rmJ\rvv = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix}\begin{bmatrix}1 \\ 1 \\ 1\end{bmatrix} = \begin{bmatrix}6 \\ 1\end{bmatrix}. \]

As before, we can also choose an arbitrary direction \(\rvv\) in which we’d like to find the derivative, which reduces to calculating a weighted sum of the columns of \(\rmJ\):

\[\rmJ\rvv = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix}\begin{bmatrix}0.5 \\ 1 \\ 2\end{bmatrix} = \begin{bmatrix}5 \\ 1.5\end{bmatrix}. \]
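Likewise, `torch.func.jvp` (assuming PyTorch 2.x) reproduces the three JVPs above. A sketch: it takes tuples of primals and tangents, pushes the tangent \(\rvv\) forward through the computation alongside the ordinary evaluation, and returns both \(f(\rvx_{1})\) and \(\rmJ\rvv\).

```python
import torch
from torch.func import jvp

def f(x):
    # Example function R^3 -> R^2
    return torch.stack([x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]])

x1 = torch.tensor([2.0, 3.0, 1.0])

results = {}
for v in ([1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.5, 1.0, 2.0]):
    # jvp returns (f(x1), J v); J v has shape (m,) = (2,)
    y, J_v = jvp(f, (x1,), (torch.tensor(v),))
    results[tuple(v)] = J_v
    print(v, J_v)
```

Unlike `vjp`, there is no reusable closure here: each new direction \(\rvv\) requires another forward-mode pass.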

Let’s now consider a supervised learning setting, where we have inputs \(\rvx\) and targets \(\rvt\) which are jointly drawn from some data-generating distribution, \(p_{\text{data}}\). Further suppose that we’ve sampled a training dataset \(\gD = \{(\rvx_{i}, \rvt_{i})\}_{i=1}^{N}\) from this data distribution. Our loss function \(\ell(\rvy, \rvt)\) quantifies how unhappy we are when we make the prediction \(\rvy\) where the true label is \(\rvt\). We make predictions using a parametric function approximator \(f\) with \(n\) parameters \(\vtheta \in \R^{n}\), which outputs \(\rvy = f(\rvx; \vtheta)\).

We wish to find \(\vtheta\) that minimises the generalisation loss or risk, which gives us a function in \(\vtheta\) that we can optimise:

\[\gR(\vtheta) = \E_{(\rvx, \rvt) \sim p_{\text{data}}}\left[\ell(f(\rvx; \vtheta), \rvt)\right]. \]

Since we don’t have access to \(p_{\text{data}}\), and calculating this integral would in any case be expensive, we instead minimise the empirical risk using (Monte-Carlo) samples from the training dataset:

\[\gJ(\vtheta) = \frac{1}{N}\sum_{i=1}^{N}\ell(f(\rvx_{i}; \vtheta), \rvt_{i}). \]

Consider a simple linear regression problem, where the parameter vector is \(n=2\)-dimensional (consisting of the slope \(a\) and the intercept \(b\)), and a scalar-valued loss function (e.g. an \(\ell_2\) loss) gives \(m=1\). The regression function itself might be defined as

\[y = f(x; a, b) = ax + b. \]
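Because the empirical risk here is scalar (\(m = 1\)), the VJP with \(\rvu = 1\) is just its gradient with respect to \((a, b)\). A minimal sketch, assuming PyTorch 2.x and a small hypothetical dataset (the numbers are illustrative only), checks this against the hand-derived gradient of the mean squared error, \(\partial\gJ/\partial a = \frac{2}{N}\sum_{i}(a x_{i} + b - t_{i})\,x_{i}\) and \(\partial\gJ/\partial b = \frac{2}{N}\sum_{i}(a x_{i} + b - t_{i})\):

```python
import torch
from torch.func import vjp

# Hypothetical toy data, for illustration only
x = torch.tensor([0.5, 1.0, 2.0, 3.0])
t = torch.tensor([1.4, 2.6, 4.4, 6.5])

def empirical_risk(params):
    a, b = params[0], params[1]
    y = a * x + b                 # regression function f(x; a, b)
    return ((y - t) ** 2).mean()  # l2 loss averaged over the dataset

theta = torch.tensor([1.0, 0.0])  # (a, b)
risk, vjp_fn = vjp(empirical_risk, theta)
(grad,) = vjp_fn(torch.ones(()))  # u = 1 for a scalar output: the gradient

# Hand-derived gradient of the mean squared error, for comparison
residual = theta[0] * x + theta[1] - t
manual = torch.stack([2 * (residual * x).mean(), 2 * residual.mean()])
print(grad, manual)
```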

To get a toy dataset, suppose that the latent data-generating function is \(2x + 0.5\), and that our observations of it are corrupted by Gaussian noise, \(\epsilon \sim \gN(0, 0.5)\). To keep things interesting, let’s also sample the \(x\) values from a LogNormal distribution.

In PyTorch, we can run stochastic gradient descent (SGD) on the empirical risk using VJPs as follows (note that the code is slightly bloated for this simple example, but it will work with any model, not just this simple linear one). See the torch.func documentation for more details.

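import torch as t
import torch.nn as nn
from functools import partial
from torch.func import functional_call, vjp
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: target = 2x + 0.5 + noise, with x ~ LogNormal. The LogNormal(0, 1)
# parameters, noise std of 0.5, N and batch size are illustrative assumptions.
N = 1000
x_all = t.distributions.LogNormal(0., 1.).sample((N, 1))
t_all = 2. * x_all + 0.5 + 0.5 * t.randn(N, 1)
train_loader = DataLoader(TensorDataset(x_all, t_all), batch_size=32, shuffle=True)

model = nn.Linear(1, 1)
# Plain (detached) tensors: torch.func differentiates them functionally, and
# this avoids in-place updates on leaf tensors that require grad.
model_params = {k: v.detach() for k, v in model.named_parameters()}

def f(x, target, params):
    # Run the model with the given parameters and return the mean squared error
    y = functional_call(model, params, (x,))
    return (y - target).pow(2.).mean()

lr = 1e-3

for x, target in train_loader:
    fx = partial(f, x, target)            # the loss as a function of params only
    loss, vjp_fn = vjp(fx, model_params)  # forward pass, recording for the VJP
    # The loss is scalar (m = 1), so the VJP with u = 1 is the gradient
    param_grad = vjp_fn(t.ones(()))[0]
    model_params = {k: v - lr * param_grad[k] for k, v in model_params.items()}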

Initialising 5 different parameter values and running each for the same duration yields the following parameter trajectories, plotted on top of the empirical risk function. Clearly this is a very naive optimisation algorithm; however, the parameters approach the optimum at \((2, 0.5)\) nonetheless.

The direct use of JVPs in optimisation is less common. You can think of it as proposing directions in which to step (stabbing in the dark) and getting back the rate of change of \(f\) if you were to step in that direction.
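To make the ‘stabbing in the dark’ picture concrete: for a scalar loss, the JVP along a proposed step direction \(\rvv\) is a single number, the rate of change of the loss along \(\rvv\). The sketch below (assuming PyTorch 2.x; the least-squares loss and its data are hypothetical) checks that this number equals \(\nabla\gJ \cdot \rvv\), then takes a naive step along \(\rvv\) or \(-\rvv\), whichever the JVP says goes downhill.

```python
import torch
from torch.func import grad, jvp

torch.manual_seed(0)
# Hypothetical toy data scattered around the line 2x + 0.5
x = torch.rand(100) * 3
t = 2 * x + 0.5 + 0.1 * torch.randn(100)

def loss(theta):
    a, b = theta[0], theta[1]
    return ((a * x + b - t) ** 2).mean()

theta = torch.zeros(2)
v = torch.randn(2)                    # a proposed step direction
_, rate = jvp(loss, (theta,), (v,))   # scalar: d/dh loss(theta + h v) at h = 0
print(rate, grad(loss)(theta) @ v)    # the two should agree

# Naive 'stab in the dark' step: move along v or -v, whichever decreases
# the loss to first order (a larger step could still overshoot)
lr = 0.1
theta_new = theta - lr * torch.sign(rate) * v
print(loss(theta), loss(theta_new))
```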

Generally, if \(n\), the number of inputs to our function, is large (for instance, all the parameters of a neural network) while \(m\), the number of outputs, is small (such as a scalar loss function), then using reverse-mode automatic differentiation with VJPs will be most efficient: it tells us directly how much we ought to update each parameter relative to the others.

Conversely, if \(n\) were small and \(m\) were large, which is less common in machine learning, then it would be more efficient to use forward-mode automatic differentiation with JVPs.
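The asymmetry can be seen directly on a scalar loss with \(n\) inputs: a single VJP (with \(\rvu = 1\)) returns the whole \(n\)-dimensional gradient, whereas forward mode needs \(n\) JVPs, one per standard basis vector \(\rve_{j}\), each recovering a single partial derivative. A small sketch, assuming PyTorch 2.x and a hypothetical loss function:

```python
import torch
from torch.func import jvp, vjp

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)

def loss(theta):
    # Hypothetical scalar loss R^n -> R (so m = 1)
    return (A @ theta).tanh().sum()

theta = torch.randn(n)

# Reverse mode: one VJP with u = 1 gives all n partial derivatives at once
_, vjp_fn = vjp(loss, theta)
(grad_rev,) = vjp_fn(torch.ones(()))

# Forward mode: n separate JVPs, one per basis vector e_j (rows of the identity)
grad_fwd = torch.stack([jvp(loss, (theta,), (e_j,))[1] for e_j in torch.eye(n)])

print(torch.allclose(grad_rev, grad_fwd, atol=1e-5))
```

For \(n\) in the millions, the loop of JVPs is clearly impractical, whereas the single VJP costs roughly a small constant multiple of one evaluation of the loss.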

Note that while we could compute the update \(\vtheta_{t+1} = \vtheta_{t} - \eta\, \rvu^{\top}\rmJ\) by fully materialising the Jacobian matrix, for large neural network layers this would consume unnecessarily large amounts of memory. Indeed, we rarely need the Jacobian itself, and VJPs (and JVPs) let us avoid ever storing the full matrix.

Returning to the linearisation view (first-order Taylor expansion), we can see that the Jacobian merely gives us a tangent line (more generally, a tangent hyperplane) to our objective function. While this is useful for choosing the direction in which to update the parameters, it does not tell us how far to step, and it quickly becomes a poor approximation to the function away from the current point, since it ignores curvature.

In subsequent posts, I will go over some common first-order and second-order optimisation algorithms which seek to make more judicious parameter updates at every step.