Pairwise Distances for Kernel Methods

An esoteric note on computing the pairwise distance between the rows of two matrices (with PyTorch examples).

November 26, 2021

London, UK

When working with kernel methods, particularly with translation-invariant kernels \(k(x, x') = k(x-x')\), it is common to compute a matrix of distances between pairs of inputs \(R_{ij} = \Vert \mathbf{x}_{i} - \mathbf{x}_{j}' \Vert_{\ell}\). (Note: the notation \(\Vert \cdot \Vert_{\ell}\) implies that the distance metric is norm based, e.g. a Euclidean norm. While we only consider this type of distance metric in this article, this needn't be the case, depending on the type of the input points \(\mathbf{X}\).)

While the idea of computing such a matrix is straightforward, the implementation (for instance, in your favourite Python linear algebra package) is not necessarily so obvious. In this Fragment, I want to step through the implementation of some common distance matrices \(\mathbf{R}\) to try to demystify it all.

Perhaps the most common such matrix is that of squared Euclidean distances between the rows:

\[R_{mn} = \Vert \mathbf{x}_{m} - \mathbf{x}_{n}'\Vert_{2}^{2}, \]

where \(\mathbf{X} \in \mathbb{R}^{M\times D}\) and \(\mathbf{X}' \in \mathbb{R}^{N\times D}\) can be thought of as design matrices, and the notation \(\mathbf{x}_{i} \in \mathbb{R}^{1\times D}\) denotes a single row of \(\mathbf{X}\).

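This distance matrix is used in popular kernels such as the squared exponential (RBF) and rational quadratic kernels. It should be clear that \(\mathbf{R}\) is an \(M\times N\) matrix. (Note: when \(\mathbf{X}' = \mathbf{X}\), \(\mathbf{R}\) is square and symmetric, with non-negative entries and a zero diagonal, since the squared Euclidean distance is symmetric, non-negative, and zero between a point and itself. It is not, however, positive semidefinite in general: for the two one-dimensional inputs \(0\) and \(1\), \(\mathbf{R} = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix}\), which has eigenvalues \(\pm 1\).)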

Each element of \(\mathbf{R}\) is a scalar:

\[\begin{align*} R_{mn} &= \Vert \mathbf{x}_{m} - \mathbf{x}_{n}'\Vert_2^2 \\ &= (\mathbf{x}_{m} - \mathbf{x}_{n}')(\mathbf{x}_{m} - \mathbf{x}_{n}')^\top \\ &= \mathbf{x}_{m}\mathbf{x}_{m}^\top - 2 \mathbf{x}_{m}{\mathbf{x}'}_{n}^\top + \mathbf{x}_{n}'{\mathbf{x}'}_{n}^{\top}, \end{align*} \]

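where in the second line we have written the squared Euclidean norm as \(\Vert \mathbf{a} \Vert_2^2 = \sum^D_{d=1} a_d^2 = \mathbf{a}\mathbf{a}^\top \in \mathbb{R}\), for a row vector \(\mathbf{a} \in \mathbb{R}^{1\times D}\). In the third line we have used the symmetry of the inner product, \(\mathbf{x}_n'\mathbf{x}_m^\top = \mathbf{x}_m{\mathbf{x}'}_n^\top\), to combine the two cross terms.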
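As a quick numerical sanity check of the expansion above, the sketch below compares the direct form \((\mathbf{a}-\mathbf{b})(\mathbf{a}-\mathbf{b})^\top\) with the three-term form for two random row vectors. The names `a`, `b` and `D` are illustrative only.

```python
import torch

torch.manual_seed(0)
D = 4
a = torch.randn(1, D, dtype=torch.float64)  # plays the role of x_m, shape (1, D)
b = torch.randn(1, D, dtype=torch.float64)  # plays the role of x'_n, shape (1, D)

# Direct form: (a - b)(a - b)^T, a (1, 1) tensor
direct = (a - b) @ (a - b).T

# Expanded form: a a^T - 2 a b^T + b b^T
expanded = a @ a.T - 2 * a @ b.T + b @ b.T

# Both should also equal the plain sum of squared differences
summed = ((a - b) ** 2).sum()

print(direct.item(), expanded.item(), summed.item())
```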

Implementing this in PyTorch for two row vectors is simple enough (you could just as easily use NumPy, but we'll be using PyTorch for these examples):

from torch import Tensor

def r(xm: Tensor, xn: Tensor) -> Tensor:
    return (xm - xn) @ (xm - xn).T  # (1, D) @ (D, 1) -> (1, 1)

The returned value is not a plain Python number but a \(1\times 1\) tensor holding a single scalar:

>>> import torch
>>> x1, x2 = torch.randn(1, 3), torch.randn(1, 3)  # two row vectors, D = 3
>>> r(x1, x2).shape
torch.Size([1, 1])

While the above is a good start, we're interested in finding a function R(Xa, Xb) that will compute the entire \(\mathbf{R}\) matrix of pairwise squared Euclidean distances between every row of \(\mathbf{X}\) and every row of \(\mathbf{X}'\).
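Before vectorising, a straightforward (if slow) reference is to loop over every pair of rows and apply the two-row computation to each. The sketch below does exactly that. `R_loop` is a hypothetical name used only for comparison and is not part of the original code.

```python
import torch
from torch import Tensor


def R_loop(Xa: Tensor, Xb: Tensor) -> Tensor:
    """Pairwise squared Euclidean distances via an explicit double loop (M * N Python iterations)."""
    M, N = Xa.size(0), Xb.size(0)
    out = torch.empty(M, N, dtype=Xa.dtype)
    for m in range(M):
        for n in range(N):
            diff = Xa[m:m + 1] - Xb[n:n + 1]  # slicing keeps the (1, D) row shape
            out[m, n] = (diff @ diff.T).squeeze()
    return out


Xa = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
Xb = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
print(R_loop(Xa, Xb))  # expected entries: [[1, 4, 25], [1, 2, 13]]
```

This costs one Python-level iteration per entry of \(\mathbf{R}\), which is why a vectorised version is worth deriving.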

We can use PyTorch’s broadcasting semantics to do this efficiently.

Consider the three parts of the solution above for an individual pair of points:

\[R_{mn} = \mathbf{x}_{m}\mathbf{x}_{m}^\top - 2 \mathbf{x}_{m}{\mathbf{x}'}_{n}^\top + \mathbf{x}_{n}'{\mathbf{x}'}_{n}^{\top}. \]

  • The first term is the sum of the squared elements of the first argument \(\mathbf{x}_m\). For the full \(\mathbf{R}\) matrix, if the first argument \(\mathbf{X}\) is an \(M\times D\) matrix, then the \(\mathbf{x}\mathbf{x}^\top\) term shared by every element of the \(i\)th row of \(\mathbf{R}\) is \(\mathbf{x}_i\mathbf{x}_i^\top = \sum_{d=1}^D {x_{id}}^2\), which we can compute in Python as Xa.pow(2).sum(dim=1)[i]. (Note: for those unfamiliar with PyTorch syntax, A.pow(2) squares each element of the matrix A, and A.sum(1) sums over dimension 1, i.e. across the columns, returning one value per row. Dimensions are zero-indexed, so this requires A to have at least 2 dimensions.)

  • The second term is \(-2\) multiplied by the inner product of the two arguments (remember, \(\mathbf{x}\) is a row vector). Computing this for every element of the full \(\mathbf{R}\) matrix can be done with a single matrix multiplication, giving us an \(M\times N\) matrix: -2 * Xa @ Xb.T. (Note: in PyTorch, @ is the matrix multiplication operator.)

  • The last term can be found in the same way as the first. The \(\mathbf{x}'{\mathbf{x}'}^{\top}\) term for each element of the \(j\)th column of \(\mathbf{R}\) is computed as Xb.pow(2).sum(1)[j].
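To make the shapes in the list above concrete, the sketch below computes the three terms separately for random inputs. The sizes \(M = 3\), \(N = 5\) and \(D = 2\) are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
M, N, D = 3, 5, 2
Xa = torch.randn(M, D)
Xb = torch.randn(N, D)

first = Xa.pow(2).sum(dim=1)   # x_m x_m^T for every row of Xa, shape (M,)
cross = -2 * Xa @ Xb.T         # -2 x_m x'_n^T for every pair, shape (M, N)
last = Xb.pow(2).sum(dim=1)    # x'_n x'_n^T for every row of Xb, shape (N,)

print(first.shape, cross.shape, last.shape)
# torch.Size([3]) torch.Size([3, 5]) torch.Size([5])

# Entry (i, j) of R combines first[i], cross[i, j] and last[j]
i, j = 1, 4
r_ij = first[i] + cross[i, j] + last[j]
direct = (Xa[i] - Xb[j]).pow(2).sum()
print(torch.isclose(r_ij, direct, atol=1e-5))
```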

Computing the sum of squared elements for all rows of \(\mathbf{X}\) leaves us with a tensor of shape (M,), whose \(i\)th element contains the first term for every entry in the \(i\)th row of \(\mathbf{R}\). Similarly, doing this for \(\mathbf{X}'\) results in a tensor of shape (N,), whose \(j\)th element is the \(\mathbf{x}'{\mathbf{x}'}^\top\) term for every entry in the \(j\)th column of \(\mathbf{R}\).

Making one of these (e.g. the first) into a column vector while the other remains a row vector gives us tensors of size (M,1) and (N,). (Note: for a tensor T of size (A,), the PyTorch syntax T[:, None], or equivalently T.reshape(-1, 1), results in a tensor of size (A,1), i.e. a column vector.) Due to PyTorch's broadcasting semantics, adding these two results in a tensor of size (M,N), which contains the first and last terms of all the pairwise squared Euclidean distances in \(\mathbf{R}\): that is, \(\mathbf{x}_m\mathbf{x}_m^\top + \mathbf{x}_n'{\mathbf{x}_n'}^\top\) for all \(m\in[1,M], n\in[1,N]\). (Note: the first tensor has 2 dimensions while the second has 1, so PyTorch first makes them the same length by prepending a 1 to the shape of the shorter tensor, giving (M,1) and (1,N). Each dimension of the result is then the larger of the two sizes: (max(M,1), max(1,N)) = (M,N).)
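The broadcasting step can be seen in isolation with small tensors. The values below are arbitrary stand-ins for the two squared-norm vectors, and only their shapes matter.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])  # stands in for the (M,) first term, M = 3
b = torch.tensor([10.0, 20.0])     # stands in for the (N,) last term, N = 2

col = a[:, None]                   # shape (3, 1); same as a.reshape(-1, 1)
s = col + b                        # (3, 1) + (2,) -> (3, 1) + (1, 2) -> (3, 2)

print(s.shape)  # torch.Size([3, 2])
print(s)        # s[i, j] == a[i] + b[j]
```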

We now have all the pieces we need to write the PyTorch method for computing the matrix of pairwise squared Euclidean distances:

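from torch import Tensor

def R(Xa: Tensor, Xb: Tensor) -> Tensor:
    assert Xa.size(1) == Xb.size(1)  # ensure D matches
    # (M, 1) + (N,) broadcasts to (M, N); then subtract the (M, N) cross term
    return Xa.pow(2).sum(1)[:, None] + Xb.pow(2).sum(1) - 2 * Xa @ Xb.T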
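As an illustration of where \(\mathbf{R}\) ends up, the sketch below plugs it into the squared exponential kernel \(k = \sigma^2\exp(-R/2\ell^2)\) and the rational quadratic kernel \(k = \sigma^2(1 + R/2\alpha\ell^2)^{-\alpha}\). The function names and parameters (`R_clamped`, `se_kernel`, `rq_kernel`, `lengthscale`, `variance`, `alpha`) are hypothetical choices for this example. The sketch also clamps \(\mathbf{R}\) at zero. This is an added precaution rather than part of the method above: in floating-point arithmetic the expanded form can leave tiny negative values for nearly identical rows, due to cancellation.

```python
import torch
from torch import Tensor


def R_clamped(Xa: Tensor, Xb: Tensor) -> Tensor:
    assert Xa.size(1) == Xb.size(1)  # ensure D matches
    sq = Xa.pow(2).sum(1)[:, None] + Xb.pow(2).sum(1) - 2 * Xa @ Xb.T
    # Round-off in the expansion can leave tiny negatives; true distances are >= 0
    return sq.clamp(min=0.0)


def se_kernel(Xa: Tensor, Xb: Tensor, lengthscale: float = 1.0, variance: float = 1.0) -> Tensor:
    # Squared exponential (RBF): variance * exp(-R / (2 * lengthscale^2))
    return variance * torch.exp(-R_clamped(Xa, Xb) / (2 * lengthscale**2))


def rq_kernel(Xa: Tensor, Xb: Tensor, lengthscale: float = 1.0,
              variance: float = 1.0, alpha: float = 1.0) -> Tensor:
    # Rational quadratic: variance * (1 + R / (2 * alpha * lengthscale^2))^(-alpha)
    return variance * (1 + R_clamped(Xa, Xb) / (2 * alpha * lengthscale**2)).pow(-alpha)


torch.manual_seed(0)
X = torch.randn(4, 3, dtype=torch.float64)
K_se = se_kernel(X, X, lengthscale=0.5)
K_rq = rq_kernel(X, X, alpha=2.0)
print(K_se.shape, K_rq.shape)
```

Both kernels depend on the inputs only through \(\mathbf{R}\), so the same distance matrix can be reused across kernels and hyperparameter settings.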

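A related case arises when both arguments are the same design matrix \(\mathbf{X} \in \mathbb{R}^{N\times D}\) and we already have its Gram matrix. Recall that a Gram matrix is the following matrix of inner products between all pairs of points, which in our context are the rows of \(\mathbf{X}\):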

\[\mathbf{G} = \begin{bmatrix} \langle \mathbf{x}_{1}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{1}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{1}, \mathbf{x}_{N} \rangle \\ \langle \mathbf{x}_{2}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{2}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{2}, \mathbf{x}_{N} \rangle \\ \vdots & \vdots & \ddots & \vdots \\ \langle \mathbf{x}_{N}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{N}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{N}, \mathbf{x}_{N} \rangle \end{bmatrix} \]

That is, given \(\mathbf{G}\), we can easily find \(\mathbf{x}_i\mathbf{x}_j^\top\) by simply reading off the corresponding element of the Gram matrix: \(G_{ij}\).
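In code, the Gram matrix of a design matrix is \(\mathbf{G} = \mathbf{X}\mathbf{X}^\top\), and each entry is the inner product of two rows. A small sketch with an arbitrary \(\mathbf{X}\):

```python
import torch

X = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [0.0, 1.0]])   # N = 3 rows, D = 2

G = X @ X.T                      # (N, D) @ (D, N) -> (N, N) Gram matrix
print(G)

# G[i, j] is the inner product <x_i, x_j>
i, j = 0, 1
print(G[i, j], torch.dot(X[i], X[j]))  # both equal 1*3 + 2*4 = 11
```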

As we did above, consider finding the squared distance between just two vectors, for \(i, j \in [1, N]\):

\[\begin{align*} R_{ij} &= \Vert \mathbf{x}_{i} - \mathbf{x}_{j}\Vert_2^2 \\ &= \mathbf{x}_{i}\mathbf{x}_i^\top - 2\mathbf{x}_i\mathbf{x}_j^\top + \mathbf{x}_j\mathbf{x}_j^\top \\ &= \langle \mathbf{x}_i, \mathbf{x}_i \rangle - 2\langle \mathbf{x}_i, \mathbf{x}_j \rangle + \langle\mathbf{x}_j, \mathbf{x}_j\rangle \\ &= G_{ii} - 2G_{ij} + G_{jj}. \end{align*} \]

In PyTorch, we can implement this as:

from torch import Tensor

def r(G: Tensor, i: int, j: int) -> Tensor:
    return G[i, i] - 2 * G[i, j] + G[j, j]  # G_ii - 2 G_ij + G_jj
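A quick check that the Gram-matrix route agrees with the direct computation for every pair \((i, j)\). The function is repeated so that the sketch is self-contained, and the data are arbitrary random rows.

```python
import torch
from torch import Tensor


def r(G: Tensor, i: int, j: int) -> Tensor:
    # R_ij = G_ii - 2 G_ij + G_jj
    return G[i, i] - 2 * G[i, j] + G[j, j]


torch.manual_seed(0)
X = torch.randn(4, 3, dtype=torch.float64)
G = X @ X.T  # Gram matrix of the rows of X

for i in range(X.size(0)):
    for j in range(X.size(0)):
        direct = (X[i] - X[j]).pow(2).sum()
        assert torch.isclose(r(G, i, j), direct), (i, j)
print("Gram-based distances match the direct computation")
```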

Now, considering the full \(\mathbf{R}\) matrix, observe that the squared Euclidean norm of each row of \(\mathbf{X}\) lies on the diagonal of the Gram matrix: \(\Vert \mathbf{x}_i\Vert^2_2 =\) G[i,i] = G.diag()[i].

We can make use of PyTorch’s broadcasting semantics once more to compute the \(N\times N\) matrix of all \(G_{ii} + G_{jj}\) terms as G.diag()[:, None] + G.diag().

The matrix of squared Euclidean distances is thus computed from the Gram matrix as:

from torch import Tensor

def R(G: Tensor) -> Tensor:
    g = G.diag()  # squared norms ||x_i||^2, shape (N,)
    return g[:, None] + g - 2 * G  # G_ii + G_jj - 2 G_ij for all i, j
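Finally, a sketch that checks the Gram-matrix version against two references. The first is the two-argument version from earlier, with both arguments set to \(\mathbf{X}\). The second is PyTorch's built-in `torch.cdist`, which returns unsquared Euclidean distances. Both earlier functions are repeated under the hypothetical names `R_pair` and `R_gram` so that the block runs on its own.

```python
import torch
from torch import Tensor


def R_pair(Xa: Tensor, Xb: Tensor) -> Tensor:
    # Two-argument version: (M, 1) + (N,) - 2 (M, N) -> (M, N)
    return Xa.pow(2).sum(1)[:, None] + Xb.pow(2).sum(1) - 2 * Xa @ Xb.T


def R_gram(G: Tensor) -> Tensor:
    # Gram version: G_ii + G_jj - 2 G_ij for all i, j
    g = G.diag()
    return g[:, None] + g - 2 * G


torch.manual_seed(0)
X = torch.randn(6, 3, dtype=torch.float64)
G = X @ X.T

R1 = R_gram(G)
R2 = R_pair(X, X)
R3 = torch.cdist(X, X).pow(2)   # cdist gives Euclidean distances, so square them

print(torch.allclose(R1, R2), torch.allclose(R1, R3))
# With both arguments equal to X, the result is symmetric with a (numerically) zero diagonal
print(torch.allclose(R1, R1.T), R1.diag().abs().max().item())
```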