Consider a generic generative modelling task, where we would like to estimate the density of some unknown data distribution \(p_{\text{data}}(\rvx)\) given a dataset of points \(\gD = \{\rvx_{i}\}_{i=1}^{N}\) each drawn from this distribution \(\rvx_{i} \sim p_{\text{data}}(\rvx)\).
Different choices of generative models represent the underlying probability distribution in different ways, and these will have implications for any downstream modelling tasks.
For instance, likelihood-based methods are a broad class of methods which directly model the data likelihood \(p_{\text{data}}(\rvx)\) using a parametric function approximator \(f_{\vtheta}(\rvx) \in \R\) (i.e. a neural network of some sort). One very general way to construct a neural likelihood, inspired by energy-based models, might be to write the likelihood as:
\[\begin{equation} \label{eq:ebm} p_{\vtheta}(\rvx) = \frac{1}{Z_{\vtheta}}\exp\big({-f_{\vtheta}(\rvx)}\big), \end{equation} \]
where the partition function \(Z_{\vtheta} > 0\) ensures that the density normalises; that is, \(\int p_{\vtheta}(\rvx)d\rvx = 1\). For most neural networks \(f_{\vtheta}\), computing this partition function is intractable. Further, for maximum likelihood training of \(\vtheta\) under such a model, where we maximise an objective of the form (i.e. the log-likelihood of the data, which is assumed i.i.d., hence the simple sum over the individual datapoint log-likelihoods in Equation \(\ref{eq:ml}\))
\[\begin{equation} \label{eq:ml} \gL_{\text{ML}}(\vtheta) = \sum_{\rvx \in \gD} \log p_{\vtheta}(\rvx), \end{equation} \]
we must re-compute the partition function for each gradient step, exacerbating our computational woes under this framework.
Likelihood-based model families get around this need to compute an expensive (if not intractable) normalising constant in different ways. VAEs (Kingma et al., 2014; Rezende et al., 2014) optimise a tractable lower bound on the log-likelihood; autoregressive approaches (Germain et al., 2015; Salimans et al., 2016) exploit the structure of the joint; and flow-based models (Rezende et al., 2015; Kingma et al., 2017) impose a restrictive model architecture on \(f_{\vtheta}\) to ensure the invertibility of the learned function and an easily computable Jacobian determinant. Each of these solutions introduces side effects, such as slower sampling speed or less expressive models, that one must trade off when choosing a method.
Diffusion models, on the other hand, have emerged as a successful alternative class of generative model, owing much of their popularity to the conspicuous lack of constraints on the model architecture, their non-adversarial training procedure, and their impressive empirical performance in numerous domains (including, most famously, image generation, but also audio synthesis, time series modelling and many others). These models not only allow us to sample new instances from the approximated distribution, but also to condition on part of a data point to impute the rest (for instance, filling in missing values), or to evaluate the likelihood of a new test point.
Diffusion models, as originally set out (Sohl-Dickstein et al., 2015; Ho et al., 2020), are however only suitable for continuous data. Yet there are many instances of structured, discrete data for which we might want a generative model: graphs, text, genomic sequences and many others. In what follows, we will first consider the SDE view of diffusion models, which will give us a flexible framework in which to discuss several modifications that allow diffusion models to work on the probability simplex, and thus over discrete data.
Preliminaries: Score-Based Generative Modelling with SDEs
By way of setting out notation and keeping this post relatively self-contained, we will take a brief overview of the SDE interpretation of diffusion models (Song et al., 2020).
In a basic score-based generative model, we train a neural network to estimate the Stein score of the data distribution, \(\nabla_{\rvx}\log p_{\text{data}}(\rvx)\) (as opposed to the Fisher score \(\nabla_{\vtheta}\log p_{\vtheta}(\rvx)\), where the gradient is taken wrt. the parameters \(\vtheta\) of the distribution and which we come across in RL, VI and other areas of ML; the Stein score \(\nabla_{\rvx}\log p_{\vtheta}(\rvx)\) takes the gradient wrt. the data, \(\rvx\)), using a score network \(s_{\vtheta}: \R^{D} \to \R^{D}\) optimised such that \(s_{\vtheta}(\rvx) \approx \nabla_{\rvx}\log p_{\text{data}}(\rvx)\). Taking the gradient of the likelihood in Equation \(\ref{eq:ebm}\) wrt. \(\rvx\), we can see how this gets around the normalisation issues faced by likelihood-based models, since the gradient of the log partition function vanishes:
\[s_{\vtheta}(\rvx) = \nabla_{\rvx}\log p_{\vtheta}(\rvx) = - \underbrace{\nabla_{\rvx}\log Z_{\vtheta}}_{=0} -\nabla_{\rvx} f_{\vtheta}(\rvx) = -\nabla_{\rvx}f_{\vtheta}(\rvx). \]
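To make this concrete, here is a minimal sketch (in PyTorch, with a hypothetical per-example energy network `f`) of how the score of such an energy-based parametrisation can be obtained by automatic differentiation, with the partition function nowhere in sight:

```python
import torch

def ebm_score(f, x):
    """Score of p(x) = exp(-f(x)) / Z: the log-partition term is constant in x,
    so its gradient wrt. x drops out entirely."""
    x = x.detach().requires_grad_(True)
    energy = f(x).sum()                       # sum over the batch; grads stay per-example
    return -torch.autograd.grad(energy, x)[0]
```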
One objective we may minimise to train such a score network \(s_{\vtheta}\) (apologies for the inconsistency in notation between the ML objective we maximised in Equation \(\ref{eq:ml}\) and this one, which we minimise) is the squared \(\ell_{2}\) distance between the ground-truth data score and the score network output, a quantity known as the Fisher divergence:
\[\begin{equation} \label{eq:fisher_divergence} \gL_{\text{fisher}}(\vtheta) = \E_{p_{\text{data}}(\rvx)}\left[\Vert \nabla_{\rvx}\log p_{\text{data}}(\rvx) - s_{\vtheta}(\rvx)\Vert_{2}^{2}\right]. \end{equation} \]
While the ground-truth data score is unknown (the \(p_{\text{data}}(\rvx)\) distribution is the very thing we're trying to learn), it can be shown (Hyvärinen, 2005; Vincent, 2011) (also see Appendix A) that, remarkably, an alternative objective exists which does not depend on the ground-truth scores \(\nabla_{\rvx}\log p_{\text{data}}(\rvx)\):
\[\begin{equation} \label{eq:score_matching} \gL(\vtheta) = \E_{p_{\text{data}}(\rvx)}\left[\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert_{2}^{2} + \Tr(\nabla_{x}s_{\vtheta}(\rvx))\right] + \text{const}. \end{equation} \]
While this removes the intractable term, an issue remains: since the objective is weighted by the data density, it will produce score networks \(s_{\vtheta}(\rvx)\) which are very inaccurate in low-density regions where we observe few data points:
\[\begin{equation} \label{eq:score_matching_int} \gL(\vtheta) = \int {\color{#D32F2F}p_{\text{data}}(\rvx)}\left(\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert_{2}^{2} + \Tr\big(\nabla_{x}s_{\vtheta}(\rvx)\big)\right) d\rvx + \text{const}. \end{equation} \]
Diffusion models resolve this issue by introducing a known, time-dependent perturbation of the data, \(p_{t}(\rvx)\), and training the score network to match the score of this perturbed distribution, \(\nabla_{\rvx}\log p_{t}(\rvx)\), at a given time \(t \in [0, T]\), rather than the score of the data distribution \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\). Intuitively, \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\) may not make for a very good learning signal, and we would rather design our own fate by being the ones in control of \(\nabla_{\rvx}\log p_{t}(\rvx)\).
This perturbation is selected such that at \(t = 0\) we recover the original (unperturbed) data; for increasing time \(t > 0\) we progressively add more noise, until at some final time \(T\) we reach a limiting distribution \(p_{T}(\rvx)\), which is ideally chosen to be easy to sample from and independent of the data distribution (for this reason, it is sometimes referred to as a prior distribution, by analogy to the prior in a VAE or normalising flow). By ensuring that \(p_{t}(\rvx)\) is sufficiently noisy for large values of \(t\), we can make sure that there is a sufficiently strong learning signal away from the observed data.
The SDE View
More generally, these perturbations can be flexibly specified through a stochastic differential equation (SDE), whose solution is known as the forward (or noise-corrupting) process, and whose transition density \(p_{t\vert 0}(\rvx_{t} \vert \rvx_{0})\) at any time \(t \in [0, T]\) can be obtained in closed form, given the boundary condition that \(\rvx_{0}\) at \(t=0\) is a point drawn from the dataset.
That is, we have an SDE of the following form
\[\begin{equation} \label{eq:forward_sde} d\rvx_{t} = \rvf(\rvx_{t}, t)dt + \rmG(\rvx_{t}, t)d\rmW_{t}, \end{equation} \]
where \(\rmW\) is the standard Wiener process, \(\rvf(\cdot, t): \R^{D} \to \R^{D}\) is the drift coefficient, and \(\rmG(\cdot, t): \R^{D} \to \R^{D\times D}\) the diffusion coefficient.
For example, one extremely simple SDE is \(d\rvx_{t} = e^{t}\rmI d\rmW_{t}\), corresponding to zero-mean Gaussian noise perturbations with exponentially growing variance in \(t\). Other potential SDEs include the Variance Exploding SDE, the Variance Preserving SDE and the sub-VP SDE (Song et al., 2020).
Our new score-matching objective is to match the score of the forward process \(\nabla_{\rvx} \log p_{t}(\rvx)\) with a time-dependent score network \(s_{\vtheta}(\rvx, t)\) (introducing the time dependence usually involves a fairly straightforward modification to most architectures, where we embed the time value using a learned or fixed embedding and add it to the input activations at various layers), such that \(s_{\vtheta}(\rvx, t) \approx \nabla_{\rvx}\log p_{t}(\rvx)\). We can use the following weighted combination of Fisher divergences as our objective to minimise
\[\begin{equation} \label{eq:weighted_fisher_div} \gL_{\text{TDS}}(\vtheta) = \E_{t\sim U[0, T]}\E_{p_{t}(\rvx)}\Big[\lambda(t) \big\Vert \nabla_{\rvx}\log p_{t}(\rvx) - s_{\vtheta}(\rvx, t)\big\Vert_{2}^{2}\Big], \end{equation} \]
where we sample time values uniformly at random along the time interval, and \(\lambda: \R \to \R_{>0}\), defined as \(\lambda(t) \propto 1/\E\big[\Vert \nabla_{\rvx_{t}}\log p_{t\vert 0}\big(\rvx_{t} \vert \rvx_{0}\big)\Vert_{2}^{2}\big]\), serves to normalise the magnitude of the different \(\ell_{2}\) losses across time.
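In practice the marginal score \(\nabla_{\rvx}\log p_{t}(\rvx)\) is itself unknown, and one instead optimises the equivalent denoising form of this objective (Vincent, 2011), regressing onto the score of the known transition kernel \(p_{t\vert 0}\). As a sketch: for the simple SDE \(d\rvx_{t} = e^{t}\rmI\, d\rmW_{t}\) mentioned earlier, the transition kernel is \(\gN(\rvx_{0}, \frac{1}{2}(e^{2t}-1)\rmI)\), and a training loss might look as follows (PyTorch; `score_net(x, t)` is a stand-in for the time-conditioned network):

```python
import torch

def marginal_var(t):
    # Var[x_t | x_0] = \int_0^t e^{2s} ds = (e^{2t} - 1) / 2 for dx = e^t dW
    return (torch.exp(2 * t) - 1) / 2

def dsm_loss(score_net, x0, T=1.0):
    """Denoising score matching: regress onto the transition-kernel score."""
    t = torch.rand(x0.shape[0], 1) * T     # t ~ U[0, T]
    std = marginal_var(t).sqrt()
    eps = torch.randn_like(x0)
    xt = x0 + std * eps                    # sample x_t ~ p_{t|0}(x_t | x_0)
    target = -eps / std                    # = grad_x log N(x_t; x_0, std^2 I)
    lam = marginal_var(t)                  # lambda(t) ∝ 1 / E[||target||^2]
    return (lam * (score_net(xt, t) - target) ** 2).sum(dim=-1).mean()
```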
With this score function (or the approximation thereof, \(s_{\vtheta}(\rvx, t)\)) in hand, we can draw samples from the learned distribution by solving the time-reversed SDE, which we state in multivariate form for consistency with later sections (the univariate form is somewhat simpler: \(d\rvx_{t} = [\rvf(\rvx_{t}, t) - g^{2}(t)\rvs_{\vtheta}(\rvx_{t}, t)]dt + g(t)d\rvw\))
\[\begin{align} d\rvx_{t} &= \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx} [\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^\top] \nonumber \\ &\hspace{1.6em}-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\Big]dt + \rmG(\rvx_{t}, t)d\overline{\rmW} \label{eq:time_reversed_sde} \end{align} \]
where time flows backwards from \(t=T\) to \(t=0\) and \(\overline{\rmW}\) is the time-reversed Wiener process (whose properties are identical to those of the usual Wiener process; we use the bar notation simply to avoid ambiguity about whether we're referring to the forward or reverse process).
We can solve this reverse SDE and generate a sample using any numerical method we’d like, with the Euler-Maruyama method being the most straightforward. Selecting some small negative time increment \(\Delta t = -\epsilon\) for some \(\epsilon > 0\), initialising \(t \gets T\), and iterating until \(t \approx 0^{+}\), we repeat
\[\begin{align*} \Delta \rvx &\gets \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx}[\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}] \\ &\hspace{1.6em}- \frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}{\color{#3B5E8C}\nabla_{\rvx}\log p_{t}(\rvx_{t})}\Big]\Delta t + \rmG(\rvx_{t}, t)\sqrt{\vert \Delta t\vert}\rvz_{t} \\ \rvx &\gets \rvx + \Delta \rvx \\ t &\gets t + \Delta t, \end{align*} \]
where \(\rvz_{t} \sim \gN(0, \rmI)\). In practice, we substitute \({\color{#3B5E8C}\nabla_{\rvx}\log p_{t}(\rvx_{t})}\) with the output of our time-conditioned score network \(s_{\vtheta}(\rvx_{t}, t)\).
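As a concrete sketch of this loop for the scalar-diffusion special case \(\rmG(\rvx_{t}, t) = g(t)\rmI\) (the univariate form stated earlier), assuming a trained `score_net` and a standard Gaussian limiting distribution at \(t = T\):

```python
import numpy as np

def em_sample(score_net, f, g, dim, T=1.0, n_steps=1000, rng=np.random.default_rng()):
    """Euler-Maruyama for the reverse SDE dx = [f - g^2 * score] dt + g dW-bar."""
    dt = -T / n_steps                    # negative increment: t runs from T down to 0
    x = rng.standard_normal(dim)         # assume the prior at t = T is N(0, I)
    t = T
    for _ in range(n_steps):
        z = rng.standard_normal(dim)
        drift = f(x, t) - g(t) ** 2 * score_net(x, t)
        x = x + drift * dt + g(t) * np.sqrt(abs(dt)) * z
        t = t + dt
    return x
```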
The key challenge of applying the above on discrete data is that the score function \(\nabla_{\rvx}\log p(\rvx)\) is undefined: there is no notion of the direction in which the probability mass function changes smoothly, or any smoothness at all for that matter.
There have been several proposals to adapt this model class to work with discrete data. In what follows, we will consider two approaches to performing diffusion on the probability simplex (Richemond et al., 2023; Floto et al., 2023), as well as the Diffusion Bridges approach of Liu et al. (2022). Other promising approaches include Concrete Score Matching (Meng et al., 2022) and VQ-Diffusion (Gu et al., 2022), which we do not treat here. In a related line of work, Bayesian flow networks (Graves et al., 2023) also provide a way to use a 'diffusion-like' model with discrete data (in the sense that the BFN is a generative model class which is not autoregressive and places few constraints on the network).
Categorical SDEs with Simplex Diffusion
In Categorical SDEs with Simplex Diffusion, Richemond et al. (2023) propose a discrete diffusion construction by deriving a tractable multivariate stochastic process that operates on the probability simplex itself. That is, the limiting distribution of each dimension of this multivariate process approaches a certain Gamma distribution, and such Gamma samples may in turn be used to draw samples from a Dirichlet distribution on the simplex.
Dirichlet Distributions
The \(k\)-dimensional probability simplex is defined as
\[\begin{equation} \label{eq:k_simplex} \gS^{k} \doteq \Big\{\rvx \in \R^{k}: 0 \le \rvx_{i} \le 1, \sum_{i=1}^{k} \rvx_{i} = 1\Big\}. \end{equation} \]
Each point on the simplex can be seen to define a categorical distribution over \(k\) outcomes.
Definition (Dirichlet distribution). The Dirichlet distribution, denoted \(\gD(\valpha)\), is a multivariate, continuous distribution parametrised by a vector of positive scalars \(\valpha \doteq (\alpha_{1}, \ldots, \alpha_{D})\), whose probability density function is defined (using the standard Lebesgue measure on \(\R^{D}\)) as
\[\begin{equation} \label{eq:dirichlet_density} f_{\gD}(\rvx, \valpha) = \frac{1}{Z_{\valpha}}\prod_{i=1}^{D}\ervx_{i}^{\alpha_{i}-1}, \end{equation} \]
for \(\rvx \in \R^{D}\), \(\alpha_{i} \in \R^{+}\ \ \forall i\) and \(Z_{\valpha}\) the partition function or normalising constant.
\(\blacksquare\)
Note that setting \(\alpha_{i} = 1\) for all \(i\) gives us a uniform distribution over the simplex. Further, sampling from a Dirichlet can be done with the following two-step procedure:
- Sample \(D\) independent Gamma random variables \(\ervy_{1} \sim \gG(\alpha_{1}, \beta), \ldots, \ervy_{D} \sim \gG(\alpha_{D}, \beta)\) where \(\alpha_{i}\) are the individual shape parameters, and \(\beta\) is a shared rate parameter.
- Normalise these random variables to sum to \(1\), resulting in a Dirichlet-distributed random vector:
\[\left(\frac{\ervy_{1}}{\sum_{i=1}^{D}\ervy_{i}}, \ldots, \frac{\ervy_{D}}{\sum_{i=1}^{D}\ervy_{i}}\right) \sim \gD(\valpha). \]
The above result holds for any positive \(\beta\). Richemond et al. (2023) choose to use a shared rate parameter of \(\beta = 1\).
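This two-step procedure is a couple of lines of NumPy (note that NumPy's Gamma sampler takes a shape and a scale, so rate \(\beta\) corresponds to scale \(1/\beta\)):

```python
import numpy as np

def sample_dirichlet(alpha, beta=1.0, rng=np.random.default_rng()):
    """Dirichlet(alpha) via normalised independent Gamma(alpha_i, rate=beta) draws."""
    y = rng.gamma(shape=alpha, scale=1.0 / beta)  # step 1: independent Gammas
    return y / y.sum()                            # step 2: normalise onto the simplex

x = sample_dirichlet(np.array([1.0, 2.0, 0.5]))   # any beta > 0 gives the same law
```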
Cox-Ingersoll-Ross Process
In order to define a suitable forward process, we use a Cox-Ingersoll-Ross (CIR) process, whose limiting distribution can be chosen to be the Gamma distribution \(\gG(\alpha_{i}, 1)\). The key insight is that we can treat this as one of the \(D\) Gamma distributions required to sample from the Dirichlet, as set out in the sampling procedure above.
In more detail, the CIR process is a real-valued diffusion process defined by the following SDE in \(\theta_{t}\)
\[\begin{equation} \label{eq:cir_process_sde} d\theta_{t} = b(a - \theta_{t})dt + \sigma \sqrt{\theta_{t}}dW_{t}, \end{equation} \]
which holds for any \(\theta_{0} \ge 0\), \(a, b, \sigma > 0\) and where \((W_{t})_{t\ge 0}\) is a Wiener process. The invariant limiting distribution is \(\gG(2ab/\sigma^{2}, 2b/\sigma^{2})\), and if \(2ab \ge \sigma^{2}\) and \(\theta_{0} > 0\), then the process remains strictly positive. Setting \(2b = \sigma^{2}\), the SDE in Equation \(\ref{eq:cir_process_sde}\) becomes
\[\begin{equation} d\theta_{t} = b(a - \theta_{t})dt + \sqrt{2b\theta_{t}}dW_{t}, \end{equation} \]
which admits the Gamma distribution \(\gG(a, 1)\) as its limiting distribution, as required for constructing one dimension of the Dirichlet random vector.
Using the CIR Process for Diffusion on the Simplex
A forward process whose marginal distribution in the large time limit provides samples from a Dirichlet distribution \(\gD(\valpha)\) would allow us to perform diffusion on the simplex. Using \(D\) independent CIR processes (as described in the previous section) and simulating them in parallel, we obtain a process \(\rvy_{t}\) with positive marginals, each following
\[\begin{equation} \label{eq:cir_fwd_sde} d\ervy_{t}^{i} = b(\alpha_{i} - \ervy_{t}^{i})dt + \sqrt{2b\ervy_{t}^{i}}dW_{t}^{i}, \end{equation} \]
where, as previously, each \(\ervy^{i}\) has limiting distribution \(\gG(\alpha_{i}, 1)\). The normalised, unit-sum vector is then Dirichlet distributed as \(t\to \infty\)
\[\rvz_{t} = \left(\frac{\ervy_{t}^{1}}{\sum_{i=1}^{D}\ervy_{t}^{i}}, \ldots, \frac{\ervy_{t}^{D}}{\sum_{i=1}^{D}\ervy_{t}^{i}}\right) \sim \gD(\valpha). \]
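A naive Euler-Maruyama simulation of this forward process might look as follows (a sketch only: the clipping is a crude guard against the discretised process stepping below zero, and for large \(T\) the normalised output is approximately \(\gD(\valpha)\)-distributed):

```python
import numpy as np

def simplex_forward(y0, alpha, b=1.0, T=10.0, n_steps=10_000, rng=np.random.default_rng()):
    """Simulate D independent CIR processes and normalise onto the simplex."""
    dt = T / n_steps
    y = y0.astype(float).copy()                  # positive initial state (the data)
    for _ in range(n_steps):
        dW = rng.normal(scale=np.sqrt(dt), size=y.shape)
        y += b * (alpha - y) * dt + np.sqrt(2 * b * y) * dW
        y = np.clip(y, 1e-8, None)               # keep each coordinate positive
    return y / y.sum()                           # z_t, approximately Dirichlet(alpha)
```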
Reversing the Simplex Diffusion
In the previous section we obtained a forward process on the simplex; its time-reversal will allow us to diffuse back towards a vertex of the simplex (that is, a one-hot vector corresponding to a categorical sample).
Note that the forward process SDE in Equation \(\ref{eq:cir_fwd_sde}\) is in the general vector form given in Equation \(\ref{eq:forward_sde}\), where the drift coefficient is \(\rvf(\rvy_{t}, t) = b(\valpha - \rvy_{t}) \in \R^{D}\) and the diffusion coefficient is \(\rmG(\rvy_{t}, t) = \sqrt{2b} \cdot \text{diag}\big(\sqrt{\ervy_{t}^{1}}, \ldots, \sqrt{\ervy_{t}^{D}}\big)\in \R^{D\times D}\). Further, let \(p_{t}(\rvy_{t})\) denote the probability density function (the law) of \(\rvy_{t}\), and note that \(\mSigma(\rvy_{t}, t) = \rmG(\rvy_{t}, t)\rmG(\rvy_{t}, t)^{\top} = 2b\,\text{diag}(\rvy_{t})\).
Now, by Equation \(\ref{eq:time_reversed_sde}\), the time-reversal of the multidimensional CIR given above is \((\rvx_{t})_{t\in [0, T]}\) such that \(\rvx_{t} = \rvy_{T-t}\) satisfies
\[\begin{align} d\rvx_{t} &= \big[-b(\valpha - \rvx_{t}) + 2b\,\text{diag}(\rvx_{t})\, \nabla_{\rvx}\log p_{T-t}(\rvx_{t}) + 2b\mathbf{1}\big]dt \nonumber \\ &\hspace{1.2em}+ \rmG(\rvx_{t}, T-t) d\rmW_{t} \label{eq:time_reversed_cir} \end{align} \]
where \(\rvx_{0} \sim p_{T}\), the marginal of the forward process at time \(T\). In practice, we replace the unknown score with the score network:
\[\begin{align} d\rvx_{t} &= \big[-b(\valpha - \rvx_{t}) + 2b\,\text{diag}(\rvx_{t})\, s_{\vtheta}(\rvx_{t}, T-t) + 2b\mathbf{1}\big]dt \\ &\hspace{1.2em}+\rmG(\rvx_{t}, T-t) d\rmW_{t}, \end{align} \]
where at \(t = 0\) the reverse process is initialised from the reference distribution, the joint over the \(D\) independent Gamma distributions, \(\rvx_{0} \sim p_{\text{ref}}(\rvx) = \prod_{i=1}^{D} \gG(\ervx^{i}; \alpha_{i}, 1)\), and \(s_{\vtheta}(\rvx, t)\) is the score network approximating \(\nabla_{\rvx}\log p_{t}(\rvx)\).
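Written out as a single Euler-Maruyama step, the reverse dynamics above might be sketched as follows (with a hypothetical `score_net` taking the \(T - t\) time argument, and \(\text{diag}(\rvx_{t})\,s_{\vtheta}\) reducing to an elementwise product):

```python
import numpy as np

def reverse_cir_step(x, t, score_net, alpha, b, T, dt, rng=np.random.default_rng()):
    """One Euler step of the reverse CIR SDE stated above."""
    s = score_net(x, T - t)                           # approximates the forward score
    drift = -b * (alpha - x) + 2 * b * x * s + 2 * b  # bracketed drift, incl. 2b * ones
    noise = np.sqrt(2 * b * x) * rng.normal(scale=np.sqrt(dt), size=x.shape)
    return np.clip(x + drift * dt + noise, 1e-8, None)  # keep the state positive
```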
Training
In order to use the usual denoising score matching training objectives, we need a closed form transition density for the CIR process. Omitting derivations, this can be obtained as
\[\begin{align} \label{eq:cir_density} p_{t\vert 0}(\ervy_{t}^{i} \vert \ervy_{0}^{i}) &= c \exp\left(-c\big(\ervy_{0}^{i}\exp(-bt) + \ervy_{t}^{i}\big)\right)\left(\frac{\ervy_{t}^{i}\exp(bt)}{\ervy_{0}^{i}}\right)^{\frac{\alpha_{i}-1}{2}} \nonumber \\ &\hspace{1em} \cdot I_{\alpha_{i}-1}\left(2c\sqrt{\ervy_{0}^{i}\ervy_{t}^{i}\exp(-bt)}\right), \end{align} \]
where \(c = (1-\exp(-bt))^{-1}\) and \(I_{\alpha_{i}-1}\) is the modified Bessel function of the first kind of order \(\alpha_{i} - 1\); the full transition density factorises over the \(D\) independent coordinates.
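Usefully, this is the density of a scaled noncentral chi-squared random variable, so we can draw exact samples from \(p_{t\vert 0}\) without simulating the SDE; a sketch using the classical identity \(2c\,\ervy_{t}^{i} \sim \chi'^{2}\big(2\alpha_{i},\, 2c\,\ervy_{0}^{i}e^{-bt}\big)\):

```python
import numpy as np

def sample_cir_transition(y0, alpha, b, t, rng=np.random.default_rng()):
    """Exact sample of y_t | y_0 under the CIR transition density above."""
    c = 1.0 / (1.0 - np.exp(-b * t))
    nonc = 2.0 * c * y0 * np.exp(-b * t)    # noncentrality parameter
    return rng.noncentral_chisquare(df=2.0 * alpha, nonc=nonc) / (2.0 * c)
```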
Hence, we can evaluate the following objective to train the score network:
\[\begin{equation} \gL(\vtheta) = \E_{t\sim U[0, T]}\, \mathbb{E}_{\substack{ \rvy_{0} \sim \gD \\ \rvy_{t} \sim p_{t\vert 0}(\rvy_{t}\vert \rvy_{0})}} \Big[\big(\nabla_{\rvy}\log p_{t \vert 0}(\rvy_t \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\big)^{\top}\text{diag}(\rvy_t)\big(\nabla_{\rvy}\log p_{t\vert 0}(\rvy_{t} \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\big)\Big]. \end{equation} \]
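Note the \(\text{diag}(\rvy_{t})\) weighting sandwiched between the residuals, which adapts the \(\ell_{2}\) metric to the state-dependent diffusion coefficient. A minimal PyTorch sketch, assuming the target transition score has been precomputed (e.g. by differentiating the closed-form log-density above):

```python
import torch

def cir_dsm_loss(score_net, yt, t, target):
    """Quadratic-form loss err^T diag(y_t) err, averaged over the batch."""
    err = target - score_net(yt, t)     # (batch, D) residual between scores
    return (yt * err ** 2).sum(dim=-1).mean()
```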
Diffusion on the Probability Simplex
In Diffusion on the Probability Simplex, Floto et al. (2023) also propose an alternative forward process that allows diffusion to proceed on the probability simplex in the discrete data setting. Whereas Richemond et al. (2023) use a forward process with a Dirichlet limiting distribution, in this section we will see a forward process with a Logistic-Gaussian limiting distribution.
Logistic-Gaussian Distributions
Recall the definition of the \(k\)-dimensional probability simplex \(\gS^{k}\) of Equation \(\ref{eq:k_simplex}\).
Definition (Logistic-Gaussian distribution). The Logistic-Gaussian distribution has bounded support on the simplex and is the distribution of a softmax-like (additive logistic) transformation applied to a Gaussian. That is, if \(\rvy\) is a Gaussian-distributed random vector and \(\sigma\) is the additive logistic transformation defined below, then \(\rvx = \sigma(\rvy)\) is logistic-Gaussian distributed. The PDF is given by
\[\begin{align} f_{\gL\gG}(\rvx; \mu, \mSigma) &= \big\vert(2\pi)^{D-1}\mSigma\big\vert^{-1/2} \Big[\prod_{i=1}^{D}\ervx_{i}\Big]^{-1} \nonumber \\ &\hspace{1.2em}\exp\left(-\frac{1}{2}\left[\log \left(\frac{\bar{\rvx}}{\ervx_{D}}\right) - \mu\right]^{\top}\mSigma^{-1}\left[\log \left(\frac{\bar{\rvx}}{\ervx_{D}}\right) - \mu\right]\right), \label{eq:logistic_gaussian} \end{align} \]
where \(\rvx \in \gS^{D}\), the logarithm is applied elementwise, and we use \(\bar{\rvx}\) to denote the first \(D-1\) components of \(\rvx\); \(\bar{\rvx} = [\ervx_{1}, \ldots, \ervx_{D-1}]^{\top}\).
\(\blacksquare\)
To draw samples from the logistic-Gaussian, we map a point \(\rvy \in \R^{D-1}\) sampled as \(\rvy \sim \gN(\mu, \mSigma)\) to a point \(\rvx \in \gS^{D}\) using an additive logistic transformation \(\sigma: \R^{D-1} \to \gS^{D}\), defined for dimension \(i \in \{1, \ldots, D\}\) as:
\[\begin{equation} \label{eq:add_one_logistic} \rvx_{i} = \sigma_{i}(\rvy) \doteq \begin{cases} \frac{\exp({\rvy_{i}})}{1 + \sum_{k=1}^{D-1}\exp({\rvy_{k}})},\ \ &\text{if }i \in \{1, \ldots, D-1\} \\ \frac{1}{1 + \sum_{k=1}^{D-1}\exp({\rvy_{k}})},\ \ &\text{if }i = D. \end{cases} \end{equation} \]
Similarly, note that the inverse \(\sigma^{-1}: \gS^{D} \to \R^{D-1}\) is
\[\rvy_{i} = \log \frac{\rvx_{i}}{\rvx_{D}} \ \text{for } i \in \{1, \ldots, D-1\}. \]
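Both directions of this transformation are a few lines of NumPy (a sketch; a numerically careful version would subtract \(\max_{k} \ervy_{k}\) before exponentiating):

```python
import numpy as np

def additive_logistic(y):
    """Map y in R^{D-1} onto the simplex S^D, per the piecewise definition above."""
    e = np.exp(y)
    denom = 1.0 + e.sum()
    return np.append(e, 1.0) / denom    # last component is 1 / (1 + sum_k exp(y_k))

def additive_logistic_inv(x):
    """Inverse map back to R^{D-1}: y_i = log(x_i / x_D)."""
    return np.log(x[:-1] / x[-1])
```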
Ornstein-Uhlenbeck Process
To construct the forward process, Floto et al. (2023) make use of an Ornstein-Uhlenbeck (OU) process. This is a real-valued, multidimensional stochastic process which can be described using the following SDE:
\[\begin{equation} \label{eq:ou_sde} d\rvy_{t} = -\theta\rvy_{t} dt + \sigma d\rmW_{t}, \end{equation} \]
for \(\theta > 0\) and \(\sigma > 0\), where \((\rmW_{t})_{t\ge 0}\) is once again a Wiener process. Taking \(\sigma = 1\) (as we do for the rest of this section), the distribution of the OU process at time \(t\) is the following Gaussian:
\[\begin{equation} \label{eq:ou_t} \rvy_{t} \sim \gN\left(\rvy_{0}\exp(-\theta t), \frac{1}{2\theta}\big(1-\exp(-2\theta t)\big)\rmI\right), \end{equation} \]
and the limiting distribution of the OU process as \(t \to \infty\) is \(\gN(0, \frac{1}{2\theta}\rmI)\). Note that \(\theta\) uniquely determines this limiting distribution; it is independent of the data \(\rvy_{0}\).
Using the Transformed OU Process for Diffusion on the Simplex
The method proposed by Floto et al. (2023) defines a diffusion process that operates on the probability simplex \(\gS^{D}\) as follows: the forward process uses the additive logistic transformation to map an OU process from \(\R^{D-1}\) to \(\gS^{D}\):
\[\begin{equation} \label{eq:logistic_ou_process} \rvx_{t} = \sigma(\rvy_{t}). \end{equation} \]
Note that the marginal at timestep \(t\) is given by
\[\begin{equation} \label{eq:logistic_ou_process_marginal} \rvx_{t} \sim \sigma\bigg(\gN\Big(\rvy_{0}\exp(-\theta t), \frac{1}{2\theta}\big(1 - \exp(-2\theta t)\big)\Big)\bigg). \end{equation} \]
In other words, the transition density \(p_{t\vert 0}(\rvx_{t} \vert \rvx_{0})\) of the forward process is a logistic-Gaussian distribution from which we can easily sample.
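Putting the pieces together, sampling \(\rvx_{t} \vert \rvx_{0}\) amounts to inverting the logistic map, sampling the OU transition, and mapping back onto the simplex (a self-contained sketch, again with \(\sigma = 1\)):

```python
import numpy as np

def forward_marginal(x0, theta, t, rng=np.random.default_rng()):
    """Sample x_t | x_0 for the transformed OU process on the simplex."""
    y0 = np.log(x0[:-1] / x0[-1])                    # sigma^{-1}: simplex -> R^{D-1}
    mean = y0 * np.exp(-theta * t)                   # OU transition mean
    std = np.sqrt((1 - np.exp(-2 * theta * t)) / (2 * theta))
    y = mean + std * rng.standard_normal(y0.shape)   # Gaussian OU transition sample
    e = np.exp(y)
    return np.append(e, 1.0) / (1.0 + e.sum())       # sigma: back onto the simplex
```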
We can obtain the SDE for \(\rvx_{t}\) defined in Equation \(\ref{eq:logistic_ou_process}\) by applying Itô's lemma to the SDE for \(\rvy_{t}\), which yields
\[\begin{equation} \label{eq:logistic_ou_sde} d\rvx_{t} = \rvf(\rvx_{t}, t)dt + \rmG(\rvx_{t}, t)d\rmW_{t}, \end{equation} \]
where the diffusion coefficient matrix is
\[\begin{equation} \label{eq:logistic_ou_sde_diffusion_coef} \rmG_{ij}(\rvx, t) = \begin{cases} \rvx_{i}(1 - \rvx_{i}), &i = j \\ -\rvx_{i}\rvx_{j}, &i \ne j \end{cases} \end{equation} \]
and the drift term is
\[\begin{equation} \label{eq:logistic_ou_sde_drift_coef} \rvf_{i}(\rvx, t) = -\theta \rvx_{i} \left[(1-\rvx_{i})\rva_{i} + \sum_{j\ne i}\rvx_{j}\rva_{j}\right], \end{equation} \]
for \(\rva_{j} = \rvx_{j} + \frac{1}{2}(1-2\rvx_{j})\).
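As an aside, the diffusion coefficient above has a compact outer-product form, \(\rmG(\rvx, t) = \text{diag}(\rvx) - \rvx\rvx^{\top}\) (the same form as the Jacobian of a softmax-style map), which makes it trivial to construct:

```python
import numpy as np

def diffusion_matrix(x):
    """G(x): x_i (1 - x_i) on the diagonal, -x_i x_j off it; i.e. diag(x) - x x^T."""
    return np.diag(x) - np.outer(x, x)
```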
Reversing the Simplex Diffusion
From the previous section, we have identified the terms of the general vector form of the SDE in Equation \(\ref{eq:forward_sde}\). We can hence obtain the time-reversal of the forward process, \((\rvx_{t})_{t\in[0, T]}\), where once again \(\rvx_{t} = \rvy_{T-t}\), by substituting these drift and diffusion coefficients into the expression for the time-reversed SDE (Equation \(\ref{eq:time_reversed_sde}\)), which we re-state here:
\[\begin{align} d\rvx_{t} &= \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx} [\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^\top] \nonumber \\ &\hspace{1.6em}-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\Big]dt + \rmG(\rvx_{t}, t)d\overline{\rmW} \end{align} \]
Note that the term \(\nabla_{\rvx}[\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}]\) can be found in closed form (see Appendix 3 in Floto et al. (2023)), and we can also obtain the score of the logistic-Gaussian in closed form as
\[\begin{align} \nabla_{\rvx}\log p(\rvx)_{i} = &-\frac{1}{v} \left(\frac{1}{\ervx_{D}} \sum_{k=1}^{D-1}\sigma^{\mu}_{k}(\rvx) + \frac{1}{\ervx_{i}}\sigma^{\mu}_{i}(\rvx)\right) \nonumber \\ &+ \frac{\ervx_{i} - \ervx_{D}}{\ervx_{i}\ervx_{D}} \end{align} \]
where \(\sigma_{k}^{\mu}(\rvx) \doteq \log \frac{\ervx_{k}}{\ervx_{D}} - \mu_{k}\) and \(v\) is the variance of the underlying (isotropic) Gaussian.
Training
We may insert the above terms into a standard weighted Fisher divergence objective to obtain
\[\begin{equation} \gL(\vtheta) = \E_{t\sim U[0, T]}\, \mathbb{E}_{\substack{ \rvy_{0} \sim \gD \\ \rvy_{t} \sim p_{t\vert 0}(\rvy_{t}\vert \rvy_{0})}} \Big[\lambda(t)\big\Vert \nabla_{\rvy}\log p_{t\vert 0}(\rvy_{t} \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\big\Vert_{2}^{2}\Big]. \end{equation} \]
While in theory, for data with \(D\) categories, each datapoint would correspond to a one-hot vector at \(t=0\), in practice this leads to instabilities, and the authors instead map samples to \(\rvy = [\alpha, \beta, \ldots, \beta]^{\top}\), where \(\beta \doteq \frac{1-\alpha}{D-2}\) and \(\alpha = 1-\epsilon\) for a small \(\epsilon > 0\).
The authors also report that the score suffers from numerical instabilities near the boundary of the simplex. To rectify this, they train the network to directly predict \(-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\), which remains bounded at the boundary and empirically varies more slowly.
Conclusion
While we have seen a couple of methods for performing diffusion over the probability simplex, and hence for using diffusion models with discrete data, these are not without their limitations. The methods above often require ad-hoc \(\epsilon\) stabilisation parameters to avoid discretisation errors, among other engineering tricks such as carefully selected warm-up schedules and careful control of the noise to respect the geometric constraints of the simplex.
Further, these methods may not scale elegantly to very high-dimensional discrete data, where the inputs become extremely sparse owing to the one-hot representation of the discrete data. This is a crucial consideration for applications such as language modelling, where tokenisation schemes often result in vocabularies in excess of \(50{,}000\) tokens.
Nonetheless these methods show promise for a range of generative modelling tasks in discrete data settings.
Appendix A: Score Network Objective
The Fisher divergence
\[\frac{1}{2}\E_{p_{\text{data}}(\rvx)}\Big[\big\Vert \nabla_{\rvx} \log p_{\text{data}}(\rvx) - s_{\vtheta}(\rvx)\big\Vert_{2}^{2}\Big], \]
is not computable due to the presence of the \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\) term.
To eliminate this term, we may use score matching and integration by parts. Taking the 1D setting for simplicity, we begin by expanding out the brackets in the Fisher divergence
\[\begin{align*} &\hspace{1.2em} \frac{1}{2}\E_{p_{\text{data}}(x)}\Big[\big(\nabla_{x}\log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^2\Big] \\ &= \frac{1}{2} \int p_{\text{data}}(x) \big(\nabla_{x} \log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^{2}dx \\ &= \underbrace{\frac{1}{2} \int p_{\text{data}}(x) \big(\nabla_{x} \log p_{\text{data}}(x)\big)^{2}dx}_{\text{const wrt. }\vtheta} + \frac{1}{2} \int p_{\text{data}}(x)\big(s_{\vtheta}(x)\big)^{2}dx \\ &\hspace{1.2em}-\int p_{\text{data}}(x) s_{\vtheta}(x) \nabla_{x}\log p_{\text{data}}(x)dx. \end{align*} \]
We note from the chain rule that \(p(x)\nabla_{x}\log p(x) = \nabla_{x}p(x)\) (sometimes called the 'log-derivative trick' in machine learning), and hence the final term above becomes
\[- \int s_{\vtheta}(x)\nabla_{x}p_{\text{data}}(x) dx. \]
Integrating the above by parts is very straightforward; selecting \(u = s_{\vtheta}(x)\) and \(dv = \nabla_{x}p_{\text{data}}(x)dx\), we get \(du = \nabla_{x}s_{\vtheta}(x)dx\) and \(v = p_{\text{data}}(x)\), and thus
\[\begin{align*} -\int s_{\vtheta}(x)\nabla_{x}p_{\text{data}}(x)dx &= -p_{\text{data}}(x) s_{\vtheta}(x)\bigg\vert_{-\infty}^{\infty} + \int p_{\text{data}}(x) \nabla_{x}s_{\vtheta}(x)dx \\ &= \E_{p_{\text{data}}(x)}\big[\nabla_{x}s_{\vtheta}(x)\big], \end{align*} \]
where the last line holds if we assume that \(p_{\text{data}}(x)s_{\vtheta}(x) \to 0\) as \(\vert x \vert \to \infty\); a reasonable assumption, since a normalised density must vanish at infinity (provided \(s_{\vtheta}\) does not grow too quickly).
Substituting this back into the Fisher divergence,
\[\begin{align*} &\hspace{1.2em}\frac{1}{2}\E_{p_{\text{data}}(x)}\big[\big(\nabla_{x}\log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^2\big] \\ &= \E_{p_{\text{data}}(x)}\left[\frac{1}{2}s_{\vtheta}(x)^{2} +\nabla_{x}s_{\vtheta}(x)\right] + \text{const}. \end{align*} \]
we see that we have removed the dependence on the score of the data distribution in the 1D case.
Hyvärinen (2005) shows how we can extend the above to the more general multidimensional setting:
\[\E_{p_{\text{data}}(\rvx)}\left[\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert^{2}_{2} + \Tr\big(\nabla_{\rvx}s_{\vtheta}(\rvx)\big)\right] + \text{const}. \]
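For completeness, a sketch of estimating this multidimensional objective with automatic differentiation (PyTorch; computing the Jacobian trace exactly costs one backward pass per dimension, which is part of why the denoising form is preferred in practice):

```python
import torch

def hyvarinen_loss(score_net, x):
    """Monte-Carlo estimate of E[ 0.5 ||s(x)||^2 + tr(grad_x s(x)) ] over a batch."""
    x = x.detach().requires_grad_(True)
    s = score_net(x)                       # (batch, D) score predictions
    loss = 0.5 * (s ** 2).sum(dim=-1)
    for i in range(x.shape[1]):            # accumulate the Jacobian trace term
        grad_i = torch.autograd.grad(s[:, i].sum(), x, create_graph=True)[0]
        loss = loss + grad_i[:, i]         # diagonal entry ds_i / dx_i, per example
    return loss.mean()
```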