When working with kernel methods, particularly with translation-invariant kernels \(k(x, x') = k(x - x')\), it is common to compute a matrix of distances between pairs of inputs, \(R_{ij} = \Vert \mathbf{x}_{i} - \mathbf{x}_{j}' \Vert_{\ell}\). (The notation \(\Vert \cdot \Vert_{\ell}\) implies that the distance metric is norm-based, e.g. a Euclidean norm. While we only consider this type of distance metric in this article, this needn't be the case; it depends on the type of the input points \(\mathbf{X}\).)
While the idea of computing such a matrix is straightforward, the implementation (for instance, in your favourite Python linear algebra package) is not necessarily so obvious. In this Fragment, I want to step through the implementation of some common distance matrices \(\mathbf{R}\) to try to demystify it all.
Squared Euclidean Distance (from Features)
Perhaps the most common such matrix of pairwise distances is the squared Euclidean distance between the rows:
\[R_{mn} = \Vert \mathbf{x}_{m} - \mathbf{x}_{n}'\Vert_{2}^{2}, \]
where \(\mathbf{X} \in \mathbb{R}^{M\times D}\) and \(\mathbf{X}' \in \mathbb{R}^{N\times D}\) can be thought of as design matrices, and the notation \(\mathbf{x}_{i} \in \mathbb{R}^{1\times D}\) denotes a single row of \(\mathbf{X}\).
This distance matrix is used in popular kernels such as the squared exponential (RBF) and the rational quadratic. It should be clear that \(\mathbf{R}\) is an \(M\times N\) matrix (when \(\mathbf{X}' = \mathbf{X}\), so that \(M = N\), the square matrix \(\mathbf{R}\) is also symmetric with non-negative entries, since a distance metric is a symmetric, non-negative real-valued mapping \(d: \mathbf{X} \times \mathbf{X} \to \mathbb{R}\)):
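\[\mathbf{R} = \begin{bmatrix} \Vert \mathbf{x}_{1} - \mathbf{x}_{1}'\Vert_{2}^{2} & \Vert \mathbf{x}_{1} - \mathbf{x}_{2}'\Vert_{2}^{2} & \cdots & \Vert \mathbf{x}_{1} - \mathbf{x}_{N}'\Vert_{2}^{2} \\ \Vert \mathbf{x}_{2} - \mathbf{x}_{1}'\Vert_{2}^{2} & \Vert \mathbf{x}_{2} - \mathbf{x}_{2}'\Vert_{2}^{2} & \cdots & \Vert \mathbf{x}_{2} - \mathbf{x}_{N}'\Vert_{2}^{2} \\ \vdots & \vdots & \ddots & \vdots \\ \Vert \mathbf{x}_{M} - \mathbf{x}_{1}'\Vert_{2}^{2} & \Vert \mathbf{x}_{M} - \mathbf{x}_{2}'\Vert_{2}^{2} & \cdots & \Vert \mathbf{x}_{M} - \mathbf{x}_{N}'\Vert_{2}^{2} \end{bmatrix}. \]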
Each element of \(\mathbf{R}\) is a scalar:
\[\begin{align*} R_{mn} &= \Vert \mathbf{x}_{m} - \mathbf{x}_{n}'\Vert_2^2 \\ &= (\mathbf{x}_{m} - \mathbf{x}_{n}')(\mathbf{x}_{m} - \mathbf{x}_{n}')^\top \\ &= \mathbf{x}_{m}\mathbf{x}_{m}^\top - 2 \mathbf{x}_{m}{\mathbf{x}'}_{n}^\top + \mathbf{x}_{n}'{\mathbf{x}'}_{n}^{\top}, \end{align*} \]
where in the second line we have written the squared Euclidean norm as: \(\Vert \mathbf{a} \Vert_2^2 = \sum^D_{d=1} (\mathbf{a}_d)^2 = \mathbf{a}\mathbf{a}^\top \in \mathbb{R}\), for \(\mathbf{a} \in \mathbb{R}^{1\times D}\) a row vector.
Implementing this in PyTorch (you could just as easily use NumPy, but we'll be using PyTorch for these examples) for two row vectors is simple enough:
from torch import Tensor

def r(xm: Tensor, xn: Tensor) -> Tensor:
    return (xm - xn) @ (xm - xn).T
The returned value, although defined as a Tensor, is just a single scalar wrapped in a \(1\times 1\) tensor:

>>> r(x1, x2).shape
torch.Size([1, 1])
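Here x1 and x2 are assumed to be \(1\times D\) row vectors; for instance:

import torch

D = 3
x1 = torch.randn(1, D)  # a single (1, D) row vector
x2 = torch.randn(1, D)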
Computing the Full Matrix
While the above is a good start, we're interested in finding a function R(Xa, Xb) that will compute the entire \(\mathbf{R}\) matrix of pairwise squared Euclidean distances between each row of \(\mathbf{X}\) and \(\mathbf{X}'\).
We can use PyTorch’s broadcasting semantics to do this efficiently.
Consider the three parts of the solution above for an individual point:
\[R_{mn} = \mathbf{x}_{m}\mathbf{x}_{m}^\top - 2 \mathbf{x}_{m}{\mathbf{x}'}_{n}^\top + \mathbf{x}_{n}'{\mathbf{x}'}_{n}^{\top}. \]
- The first term is the sum of the squared elements of the first argument \(\mathbf{x}_m\). For the full \(\mathbf{R}\) matrix, if the first argument \(\mathbf{X}\) is an \(M\times D\) matrix, then this \(\mathbf{x}\mathbf{x}^\top\) term for each element of the \(i\)th row of \(\mathbf{R}\) is \(\mathbf{x}_i\mathbf{x}_i^\top = \sum_{d=1}^D {x_{id}}^2\), which we can compute as Xa.pow(2).sum(dim=1)[i] in Python. (For those unfamiliar with PyTorch syntax, A.pow(2) raises each element of the matrix A to the power of 2, and A.sum(1) sums over dimension 1, the column dimension, of a matrix A with at least 2 dimensions.)
- The second term is \(-2\) multiplied by the inner product of the two arguments (remember, \(\mathbf{x}\) is a row vector). Computing this for every element of the full \(\mathbf{R}\) matrix can be done with a single matrix multiplication, giving us an \(M\times N\) matrix: -2 * Xa @ Xb.T. (In PyTorch, @ is the matrix multiplication operator.)
- The last term can be found in the same way as the first. The \(\mathbf{x}'{\mathbf{x}'}^{\top}\) term for each element of the \(j\)th column of \(\mathbf{R}\) is computed as Xb.pow(2).sum(1)[j]. A short, shape-annotated sketch of all three terms follows this list.
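Putting the three pieces side by side, here is a minimal sketch of the shapes involved (Xa and Xb are assumed to be \(M\times D\) and \(N\times D\) tensors; the concrete sizes are arbitrary):

import torch

M, N, D = 4, 6, 3
Xa = torch.randn(M, D)
Xb = torch.randn(N, D)

first = Xa.pow(2).sum(1)  # shape (M,):   x_m x_m^T for each row of Xa
cross = -2 * Xa @ Xb.T    # shape (M, N): -2 x_m x_n'^T for every pair of rows
last = Xb.pow(2).sum(1)   # shape (N,):   x_n' x_n'^T for each row of Xb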
Computing the sum of squared elements for all rows of \(\mathbf{X}\) leaves us
with a tensor of shape (M,)
—the \(i\)th element of this tensor contains the
first term for all solutions in the \(i\)th row of \(\mathbf{R}\). Similarly,
doing this for \(\mathbf{X}'\) results in a tensor of size (N,)
; the \(j\)th
element of this row vector corresponds to the \(\mathbf{x}'{\mathbf{x}'}^\top\)
term in all solutions for the \(j\)th column of \(\mathbf{R}\).
Making one of these (e.g. the first) into a column vector while the other remains a row vector will give us tensors of size (M,1) and (N,). (For a tensor T of size (A,), i.e. a row vector, the PyTorch syntax T[:, None], or equivalently T.reshape(-1, 1), will result in a tensor of size (A,1), i.e. a column vector.) Due to PyTorch's broadcasting semantics, adding these two will result in a tensor of size (M,N), which contains the first and last terms of all the pairwise squared Euclidean distances \(\mathbf{R}\): that is, \(\mathbf{x}_m\mathbf{x}_m^\top + \mathbf{x}_n'{\mathbf{x}_n'}^\top\) for all \(m\in[1,M], n\in[1,N]\). (The first tensor has 2 dimensions while the second has 1, so PyTorch will first make the shapes equal length by prepending 1 to the shape of the shorter tensor, giving (M,1) and (1,N); each dimension of the result is then the max of the two arguments: (max(M,1), max(1,N)) = (M,N).)
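To see this broadcasting rule in action, here is a tiny sketch (the concrete sizes are arbitrary):

import torch

a = torch.randn(5, 1)  # plays the role of the (M, 1) column vector
b = torch.randn(3)     # plays the role of the (N,) tensor
print((a + b).shape)   # torch.Size([5, 3]), i.e. (M, N)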
We now have all the pieces we need to write the PyTorch method for computing the matrix of pairwise squared Euclidean distances:
def R(Xa: Tensor, Xb: Tensor) -> Tensor:
    assert Xa.size(1) == Xb.size(1)  # ensure D matches
    return Xa.pow(2).sum(1)[:, None] + Xb.pow(2).sum(1) - 2 * Xa @ Xb.T
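As a quick sanity check, we can compare against PyTorch's built-in torch.cdist, which returns pairwise \(p\)-norm distances (so we square its output); a minimal sketch with arbitrary random inputs:

import torch

Xa = torch.randn(4, 3)
Xb = torch.randn(6, 3)
expected = torch.cdist(Xa, Xb, p=2).pow(2)  # squared Euclidean distances
assert torch.allclose(R(Xa, Xb), expected, atol=1e-5)

Note that, owing to floating point cancellation, this formulation can produce tiny negative values where the true distance is (close to) zero; clamping the result with .clamp(min=0) is a common safeguard.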
Squared Euclidean Distance (from a Gram Matrix)
Recall that a Gram matrix is the following matrix of inner products between all pairs of points, which in our context are the rows of a design matrix \(\mathbf{X}\):
\[\mathbf{G} = \begin{bmatrix} \langle \mathbf{x}_{1}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{1}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{1}, \mathbf{x}_{N} \rangle \\ \langle \mathbf{x}_{2}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{2}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{2}, \mathbf{x}_{N} \rangle \\ \vdots & \vdots & \ddots & \vdots \\ \langle \mathbf{x}_{N}, \mathbf{x}_{1} \rangle & \langle \mathbf{x}_{N}, \mathbf{x}_{2} \rangle & \cdots & \langle \mathbf{x}_{N}, \mathbf{x}_{N} \rangle \end{bmatrix} \]
That is, given \(\mathbf{G}\), we can easily find \(\mathbf{x}_i\mathbf{x}_j^\top\) by simply reading off the corresponding element of the Gram matrix: \(G_{ij}\).
As we did above, consider finding the squared distance between just two vectors, for \(i, j \in [1, N]\):
\[\begin{align*} R_{ij} &= \Vert \mathbf{x}_{i} - \mathbf{x}_{j}\Vert_2^2 \\ &= \mathbf{x}_{i}\mathbf{x}_i^\top - 2\mathbf{x}_i\mathbf{x}_j^\top + \mathbf{x}_j\mathbf{x}_j^\top \\ &= \langle \mathbf{x}_i, \mathbf{x}_i \rangle - 2\langle \mathbf{x}_i, \mathbf{x}_j \rangle + \langle\mathbf{x}_j, \mathbf{x}_j\rangle \\ &= G_{ii} - 2G_{ij} + G_{jj}. \end{align*} \]
In PyTorch, we can implement this as:
def r(G: Tensor, i: int, j: int) -> Tensor:
    return G[i, i] - 2 * G[i, j] + G[j, j]
Now considering the full \(\mathbf{R}\) matrix, observe that the squared Euclidean norm
of each row of \(\mathbf{X}\) lies on the diagonal of the Gram matrix; \(\Vert \mathbf{x}_i\Vert^2_2 =\) G[i, i] = G.diag()[i].
We can make use of PyTorch’s broadcasting semantics once more to compute the \(N\times N\)
matrix of all \(G_{ii} + G_{jj}\) terms as G.diag()[:, None] + G.diag().
The matrix of squared Euclidean distances is thus computed from the Gram matrix as:
def R(G: Tensor) -> Tensor:
    g = G.diag()
    return g[:, None] + g - 2 * G
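As a quick check, this agrees (up to floating point error) with the feature-based construction when the Gram matrix is formed from the same design matrix; a minimal sketch with an arbitrary random X:

import torch

X = torch.randn(5, 3)  # an N x D design matrix
G = X @ X.T            # Gram matrix of inner products between rows

R_from_gram = G.diag()[:, None] + G.diag() - 2 * G
R_from_features = X.pow(2).sum(1)[:, None] + X.pow(2).sum(1) - 2 * X @ X.T
assert torch.allclose(R_from_gram, R_from_features, atol=1e-5)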