When working with kernel methods, particularly with translation invariant
kernels, we frequently need to compute a matrix of pairwise distances between
data points.
While the idea of computing such a matrix is straightforward, the
implementation (for instance, in your favourite Python linear algebra package)
is not necessarily so obvious. In this Fragment, I want to step through the
implementation of some common distance
matrices.
Squared Euclidean Distance (from Features)
Perhaps the most common such matrix of pairwise distances is the squared Euclidean distance between the rows:

$$R_{mn} = \lVert \mathbf{x}_m - \mathbf{x}_n \rVert_2^2,$$

where $\mathbf{x}_m$ denotes the $m$th row of a design matrix $X \in \mathbb{R}^{N \times D}$.
This distance matrix is used in popular kernels such as the Squared Exponential
(RBF) and rational quadratic. It should be clear that $R$ is symmetric, with
$R_{mn} = R_{nm}$ and zeros along its diagonal.
Each element of $R$ can be expanded as

$$\begin{aligned}
R_{mn} &= \lVert \mathbf{x}_m - \mathbf{x}_n \rVert_2^2 \\
&= \mathbf{x}_m \mathbf{x}_m^\top - 2\,\mathbf{x}_m \mathbf{x}_n^\top + \mathbf{x}_n \mathbf{x}_n^\top,
\end{aligned}$$

where in the second line we have written the squared Euclidean norm as an inner product of row vectors: $\lVert \mathbf{a} \rVert_2^2 = \mathbf{a}\mathbf{a}^\top$.
Implementing this in PyTorch (you could just as easily use NumPy, but we'll be using PyTorch for these examples) for two row vectors is simple enough:
```python
import torch
from torch import Tensor


def r(xm: Tensor, xn: Tensor) -> Tensor:
    # Squared Euclidean distance between two row vectors
    return (xm - xn) @ (xm - xn).T
```
The returned value, although defined as a Tensor, is just a single element (a 1×1 tensor):

```python
>>> r(x1, x2).shape
torch.Size([1, 1])
```
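As a quick sanity check, the result should agree with squaring PyTorch's built-in `torch.cdist` (a minimal sketch; the row vectors `x1` and `x2` here are just illustrative random data):

```python
x1 = torch.randn(1, 3)  # row vector, D = 3
x2 = torch.randn(1, 3)

# torch.cdist returns Euclidean distances, so square to compare
assert torch.allclose(r(x1, x2), torch.cdist(x1, x2).pow(2), atol=1e-6)
```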
Computing the Full Matrix
While the above is a good start, we're interested in finding a function R(Xa, Xb)
that will compute the entire matrix of pairwise squared distances in one go:
given $X_a \in \mathbb{R}^{M \times D}$ and $X_b \in \mathbb{R}^{N \times D}$, it
should return the $M \times N$ matrix $R$ with
$R_{mn} = \lVert \mathbf{x}_{a,m} - \mathbf{x}_{b,n} \rVert_2^2$.
We can use PyTorch’s broadcasting semantics to do this efficiently.
Consider the three parts of the solution above for an individual point (the shapes of all three pieces are sketched in the snippet after this list):

- The first term, $\mathbf{x}_m \mathbf{x}_m^\top$, is the sum of the squared elements of the first argument. For the full matrix, if the first argument is an $M \times D$ matrix $X_a$, then this term for each element of the $i$th row of $R$ is $\sum_d X_{a,id}^2$, which we can compute as `Xa.pow(2).sum(dim=1)[i]` in Python. (For those unfamiliar with PyTorch syntax, `A.pow(2)` raises each element of the matrix `A` to the power of 2, and `A.sum(1)` sums along the 1st (column) dimension of a matrix `A` with at least 2 dimensions.)
- The second term is $-2$ multiplied by the inner product of the two arguments (remember, each $\mathbf{x}$ is a row vector). Computing this for every element of the full matrix can be done with a single matrix product, giving us an $M \times N$ matrix: `-2 * Xa @ Xb.T`. (In PyTorch, `@` is the matrix multiplication operator.)
- The last term, $\mathbf{x}_n \mathbf{x}_n^\top$, can be found in the same way as the first: this term for each element of the $j$th column of $R$ is computed as `Xb.pow(2).sum(1)[j]`.
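Here is how those three pieces look in isolation (a small sketch; `Xa` and `Xb` are hypothetical 5×3 and 7×3 design matrices):

```python
Xa = torch.randn(5, 3)  # M = 5, D = 3
Xb = torch.randn(7, 3)  # N = 7, D = 3

term1 = Xa.pow(2).sum(1)  # shape (5,): squared norms of the rows of Xa
cross = -2 * Xa @ Xb.T    # shape (5, 7): cross terms for every pair
term3 = Xb.pow(2).sum(1)  # shape (7,): squared norms of the rows of Xb
```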
Computing the sum of squared elements for all rows of $X_a$ gives a tensor of
size (M,); the same for $X_b$ gives a tensor of size (N,). Making one of these
(e.g. the first) into a column vector while the other remains a row vector will
give us tensors of size (M,1) and (N,). (For a tensor `T` of size (A,), i.e. a
row vector, the PyTorch syntax `T[:, None]` or equivalently `T.reshape(-1, 1)`
will result in a tensor of size (A,1), i.e. a column vector.) Due to PyTorch's
broadcasting semantics, adding these two will result in a tensor of size (M,N),
which contains the first and last terms of all the pairwise squared Euclidean
distances. (The first tensor has 2 dimensions while the second has 1, so
PyTorch will first make them equal length by prepending 1 to the dimensions of
the shorter tensor, giving (M,1) and (1,N); each dimension of the result is
then the max of the two arguments: (max(M,1), max(1,N)) = (M,N).)
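To see the broadcasting in action (continuing the sketch above with `Xa` of size (5, 3) and `Xb` of size (7, 3)):

```python
a = Xa.pow(2).sum(1)[:, None]  # size (5, 1): column vector
b = Xb.pow(2).sum(1)           # size (7,): row vector
print((a + b).shape)           # torch.Size([5, 7]), i.e. (M, N)
```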
We now have all the pieces we need to write the PyTorch method for computing the matrix of pairwise squared Euclidean distances:
```python
def R(Xa: Tensor, Xb: Tensor) -> Tensor:
    assert Xa.size(1) == Xb.size(1)  # ensure D matches
    return Xa.pow(2).sum(1)[:, None] + Xb.pow(2).sum(1) - 2 * Xa @ Xb.T
```
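As a quick check (sketch only), this should agree with squaring `torch.cdist`, and the resulting distance matrix can be dropped straight into, for example, a Squared Exponential kernel (the `lengthscale` value below is purely illustrative):

```python
assert torch.allclose(R(Xa, Xb), torch.cdist(Xa, Xb).pow(2), atol=1e-5)

# Example use: Squared Exponential (RBF) kernel
lengthscale = 1.0
K = torch.exp(-R(Xa, Xb) / (2 * lengthscale**2))  # shape (M, N)
```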
Squared Euclidean Distance (from a Gram Matrix)
Recall that a Gram matrix is the following matrix of inner products between all
pairs of points, which in our context are the rows of a design matrix
$X \in \mathbb{R}^{N \times D}$:

$$G = XX^\top.$$

That is, given the rows $\mathbf{x}_i$ of $X$, each element is
$G_{ij} = \mathbf{x}_i \mathbf{x}_j^\top$.
As we did above, consider finding the squared distance between just two
vectors, for which the earlier expansion can be written entirely in terms of
elements of $G$:

$$r_{ij} = \mathbf{x}_i \mathbf{x}_i^\top - 2\,\mathbf{x}_i \mathbf{x}_j^\top + \mathbf{x}_j \mathbf{x}_j^\top = G_{ii} - 2G_{ij} + G_{jj}.$$
In PyTorch, we can implement this as:
```python
def r(G: Tensor, i: int, j: int) -> Tensor:
    # r_ij = G_ii - 2 G_ij + G_jj
    return G[i, i] - 2 * G[i, j] + G[j, j]
```
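A quick sanity check (a sketch with an assumed small random design matrix `X`) that this agrees with computing the distance directly from the features:

```python
X = torch.randn(4, 3)
G = X @ X.T  # Gram matrix of inner products

direct = ((X[0] - X[1]) ** 2).sum()
assert torch.allclose(r(G, 0, 1), direct, atol=1e-6)
```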
Now considering the full matrix, the first and last terms are just elements of
the diagonal of the Gram matrix: `G[i, i] = G.diag()[i]`. We can make use of
PyTorch's broadcasting semantics once more to compute the sum of the first and
last terms for every pair at once: `G.diag()[:, None] + G.diag()`.
The matrix of squared Euclidean distances is thus computed from the Gram matrix as:
```python
def R(G: Tensor) -> Tensor:
    g = G.diag()  # g[i] = G_ii = x_i x_i^T
    return g[:, None] + g - 2 * G
```
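As a final sanity check (sketch only), the Gram-based version should reproduce the squared distances computed from the features; note that floating-point round-off can leave tiny negative values in the result, so in practice you may want to follow this with `.clamp(min=0)`:

```python
X = torch.randn(6, 3)
G = X @ X.T

assert torch.allclose(R(G), torch.cdist(X, X).pow(2), atol=1e-5)
```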