Consider a generic generative modelling task, where we would like to estimate the density of some unknown data distribution \(p_{\text{data}}(\rvx)\) given a dataset of points \(\gD = \{\rvx_{i}\}_{i=1}^{N}\) each drawn from this distribution \(\rvx_{i} \sim p_{\text{data}}(\rvx)\).

Different choices of generative models represent the underlying probability distribution in different ways, and these will have implications for any downstream modelling tasks.

For instance, likelihood-based methods are a broad class of approaches which
directly model the data likelihood \(p_{\text{data}}(\rvx)\) using a parametric
function approximator \(f_{\vtheta}(\rvx) \in \R\) (i.e. a neural network of
some sort).
One very general way to construct a *neural* likelihood, inspired
by energy-based models, might be to write the likelihood as:

\[\begin{equation} \label{eq:ebm} p_{\vtheta}(\rvx) = \frac{1}{Z_{\vtheta}}\exp\big({-f_{\vtheta}(\rvx)}\big), \end{equation} \]

where the partition function \(Z_{\vtheta} > 0\) ensures that the density
normalises; that is, \(\int p_{\vtheta}(\rvx)d\rvx = 1\). For most neural networks
\(f_{\vtheta}\), computing this partition function is intractable. Further, for
maximum likelihood training of \(\vtheta\) under such a model, where we maximise
an objective of the form
(i.e. maximising the log-likelihood of the data, which is assumed i.i.d., hence
the simple sum over the individual datapoint log-likelihoods in Equation
\(\ref{eq:ml}\))

\[\begin{equation} \label{eq:ml} \gL_{\text{ML}}(\vtheta) = \sum_{\rvx \in \gD} \log p_{\vtheta}(\rvx), \end{equation} \]

we must re-compute the partition function for each gradient step, exacerbating our computational woes under this framework.

Likelihood-based model families get around this need to compute an expensive
(if not intractable) normalising constant in different ways. VAEs
(Kingma *et al.*, 2014; Rezende *et al.*, 2014) optimise a tractable lower
bound on the log-likelihood; autoregressive approaches
(Germain *et al.*, 2015; Salimans *et al.*, 2016)
exploit the structure of the joint; and flow-based models
(Rezende *et al.*, 2015; Kingma *et al.*, 2017) impose a restrictive model
architecture on \(f_{\vtheta}\) to ensure the invertibility of the learned
function and an easily computable Jacobian determinant. Each of these
solutions introduces side effects—such as slower sampling speed or less
expressive models—that one must trade off when choosing a method.

Diffusion models, on the other hand, have emerged as a successful alternative
class of generative models, owing much of their popularity to the
conspicuous lack of constraints on the model architecture, a non-adversarial
training procedure, and impressive empirical performance in numerous
domains (including, most famously, image generation, but also audio
synthesis, time series modelling and many others). These generative models
not only allow us to sample new instances from the approximated distribution,
but also to condition on part of a data point to impute the rest (perhaps to fill
in missing values) or to evaluate the likelihood of a new test point.

Diffusion models, as originally set out
(Sohl-Dickstein *et al.*, 2015; Ho *et al.*, 2020), are however only suitable for
continuous data. Yet there are many instances of structured, discrete data
for which we might want a generative model: graphs, text, genomic sequences
and many others. In what follows, we will first consider the SDE view of
diffusion models, which will give us a flexible framework in which to discuss
several modifications that allow diffusion models to work on the probability
simplex, and thus over discrete data.

## Preliminaries: Score-Based Generative Modelling with SDEs

By way of setting out notation and keeping this post relatively self-contained,
we will take a brief overview of the SDE interpretation of diffusion
models (Song *et al.*, 2020).

In a basic score-based generative model, we train a neural network to estimate the
Stein score of the data distribution, \(\nabla_{\rvx}\log p_{\text{data}}(\rvx)\).
(This is as opposed to the *Fisher* score \(\nabla_{\vtheta}\log p_{\vtheta}(\rvx)\), where the gradient is taken wrt. the parameters \(\vtheta\) of
the distribution, and which we come across in RL, VI and other areas of ML; the
*Stein* score \(\nabla_{\rvx}\log p_{\vtheta}(\rvx)\) takes the gradient wrt. the
data, \(\rvx\).)
We use a *score network* \(s_{\vtheta}: \R^{D} \to \R^{D}\)
optimised such that \(s_{\vtheta}(\rvx) \approx \nabla_{\rvx}\log p_{\text{data}}(\rvx)\). Taking the gradient of the likelihood in Equation
\(\ref{eq:ebm}\) wrt. \(\rvx\), we
can see how this gets around the normalisation issues faced by likelihood-based
models, since the gradient of the log partition function goes to \(0\)

\[s_{\vtheta}(\rvx) = \nabla_{\rvx}\log p_{\vtheta}(\rvx) = - \underbrace{\nabla_{\rvx}\log Z_{\vtheta}}_{=0} -\nabla_{\rvx} f_{\vtheta}(\rvx) = -\nabla_{\rvx}f_{\vtheta}(\rvx). \]
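As a quick numerical sanity check of this fact, the sketch below (pure Python, with a hypothetical quadratic energy \(f(x) = x^{2}/2\) shifted by a constant \(c\), which only rescales \(Z_{\vtheta}\)) confirms that the finite-difference score is unaffected by the normalisation:

```python
def unnormalised_logp(x, c):
    # log p(x) = -f(x) - log Z; here f(x) = x^2 / 2 + c, so changing c
    # only rescales the partition function Z.
    return -(0.5 * x ** 2 + c)

def score(x, c, h=1e-5):
    # Central finite difference of log p wrt x; the unknown log Z is a
    # constant in x, so it drops out of the derivative.
    return (unnormalised_logp(x + h, c) - unnormalised_logp(x - h, c)) / (2 * h)

# The score -f'(x) = -x is the same for any value of c.
scores = [score(1.3, c) for c in (0.0, 3.0, -7.5)]
```

Whatever the normaliser, the score is identical, which is exactly why score-based models can sidestep \(Z_{\vtheta}\).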

One objective we may *minimise* (note the inconsistency with the ML objective
of Equation \(\ref{eq:ml}\), which we *maximised*)
to train such a score network \(s_{\vtheta}\) is
the squared \(\ell_{2}\) distance between the ground-truth data score and the
score network output, a quantity known as the *Fisher divergence*:

\[\begin{equation} \label{eq:fisher_divergence} \gL_{\text{fisher}}(\vtheta) = \E_{p_{\text{data}}(\rvx)}\left[\Vert \nabla_{\rvx}\log p_{\text{data}}(\rvx) - s_{\vtheta}(\rvx)\Vert_{2}^{2}\right]. \end{equation} \]

While the ground-truth data score is unknown (the \(p_{\text{data}}(\rvx)\)
distribution is the very thing we're trying to learn), it can be shown
(Hyvärinen, 2005; Vincent, 2011) (also see Appendix A ) that
an alternative objective which does not depend on the ground truth scores
\(\nabla_{\rvx}\log p_{\text{data}}(\rvx)\) can remarkably be found as:

\[\begin{equation} \label{eq:score_matching} \gL(\vtheta) = \E_{p_{\text{data}}(\rvx)}\left[\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert_{2}^{2} + \Tr(\nabla_{x}s_{\vtheta}(\rvx))\right] + \text{const}. \end{equation} \]
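To see this objective in action, the following sketch evaluates the 1D version of the objective above (where the trace reduces to a scalar derivative) for a hypothetical score model \(s(x) = -x/v\), the exact score of \(\gN(0, v)\), on standard normal data; the loss should be smallest at \(v = 1\):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)  # samples from p_data = N(0, 1)

def score_matching_loss(v):
    # Hypothetical 1D score model s(x) = -x / v: the exact score of N(0, v).
    s = -x / v            # s_theta(x)
    ds = -1.0 / v         # ds/dx; in 1D the trace term is just this derivative
    return np.mean(0.5 * s ** 2 + ds)

# The objective should be minimised when the model matches p_data, i.e. v = 1.
losses = {v: score_matching_loss(v) for v in (0.5, 1.0, 2.0)}
```

Note that no ground-truth scores appear anywhere in the loss: only the model's output and its derivative.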

This resolves the intractability of the ground-truth score; however, an issue remains: the objective will produce score networks \(s_{\vtheta}(\rvx)\) which are very inaccurate in low-density regions, since those regions are down-weighted by \(p_{\text{data}}(\rvx)\) in the expectation:

\[\begin{equation} \label{eq:score_matching_int} \gL(\vtheta) = \int {\color{#D32F2F}p_{\text{data}}(\rvx)}\left(\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert_{2}^{2} + \Tr\big(\nabla_{x}s_{\vtheta}(\rvx)\big)\right) d\rvx + \text{const}. \end{equation} \]

Instead, diffusion models resolve this issue by introducing a known, time-dependent perturbation to the data \(p_{t}(\rvx)\), and instead training the score network to match the score of this perturbation \(\nabla_{\rvx}\log p_{t}(\rvx)\) at a given time \(t \in [0, T]\) rather than the score of the data distribution \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\). Intuitively, \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\) may not make for a very good learning signal, and we would rather design our own fate by being the ones in control of \(\nabla_{\rvx}\log p_{t}(\rvx)\).

This perturbation is selected such that at \(t = 0\) we recover the original
(unperturbed) data; for increasing time \(t > 0\) we progressively add more noise
until at some final time \(T\) we reach a limiting distribution
\(p_{t=T}(\rvx)\), which is ideally chosen to be easily sampled from and
independent of the data distribution (for this reason, it is sometimes
referred to as a *prior* distribution, by analogy to the prior in a VAE
or normalising flow).
By ensuring that \(p_{t}(\rvx)\) is sufficiently noisy for large values of \(t\),
we can make sure that there is a sufficiently strong learning signal away from
the observed data.

### The SDE View

More generally, these perturbations can be flexibly specified
through a stochastic differential equation (SDE), whose solution is known
as the forward (or *noise corrupting*) process, and whose marginal at any time
\(t \in [0, T]\) can be obtained in closed form as \(p_{t\vert 0}(\rvx_{t} \vert \rvx_{0})\), given the boundary condition that \(\rvx_{0}\) at \(t=0\) is a point
drawn from the dataset.

That is, we have an SDE of the following form

\[\begin{equation} \label{eq:forward_sde} d\rvx_{t} = \rvf(\rvx_{t}, t)dt + \rmG(\rvx_{t}, t)d\rmW_{t}, \end{equation} \]

where \(\rmW\) is the standard Wiener process, \(\rvf(\cdot, t): \R^{D} \to \R^{D}\) is the drift term, and \(\rmG(\cdot, t): \R^{D} \to \R^{D\times D}\) the diffusion coefficient.

For example, one extremely simple SDE is \(d\rvx_{t} = e^{t}\rmI d\rmW_{t}\),
corresponding to zero-mean Gaussian noise perturbations with exponentially
growing variance in \(t\). Other potential SDEs include the *Variance
Exploding SDE*, the *Variance Preserving SDE* and the *sub-VP
SDE* (Song *et al.*, 2020).

Our new score-matching objective is to match the score of the forward process
\(\nabla_{\rvx} \log p_{t}(\rvx)\) with a *time-dependent* score
network \(s_{\vtheta}(\rvx, t)\), such that \(s_{\vtheta}(\rvx, t) \approx \nabla_{\rvx}\log p_{t}(\rvx)\). (Introducing the time dependence usually involves a fairly
straightforward modification to most architectures, where we embed the time
value using a learned or fixed embedding, and add it to the input activations
at various layers.) We can use the following weighted
combination of Fisher divergences as our objective to minimise

\[\begin{equation} \label{eq:weighted_fisher_div} \gL_{\text{TDS}}(\vtheta) = \E_{t\sim U[0, T]}\E_{p_{t}(\rvx)}\Big[\lambda(t) \big\Vert \nabla_{\rvx}\log p_{t}(\rvx) - s_{\vtheta}(\rvx, t)\big\Vert_{2}^{2}\Big], \end{equation} \]

where we sample time values uniformly at random along the time interval, and \(\lambda: \R \to \R_{>0}\), defined as \(\lambda(t) \propto 1/\E\big[\Vert \nabla_{\rvx_{t}}\log p_{t\vert 0}\big(\rvx_{t} \vert \rvx_{0}\big)\Vert_{2}^{2}\big]\), serves to normalise the magnitude of the different \(\ell_{2}\) losses across time.
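In practice, the intractable \(\nabla_{\rvx}\log p_{t}(\rvx)\) above is replaced by the tractable conditional score \(\nabla_{\rvx_{t}}\log p_{t\vert 0}(\rvx_{t}\vert \rvx_{0})\) as a regression target (denoising score matching; Vincent, 2011). A minimal numpy sketch of this target construction, for the toy SDE \(d\rvx_{t} = e^{t}\rmI\, d\rmW_{t}\) mentioned earlier (function names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)

def marginal_std(t):
    # For the toy SDE dx = e^t dW, Var[x_t | x_0] = \int_0^t e^{2s} ds.
    return np.sqrt((np.exp(2 * t) - 1) / 2)

def dsm_pair(x0, t):
    # Sample x_t ~ p_{t|0}(x_t | x_0) and return the conditional score,
    # which stands in for the intractable grad log p_t(x) as a target.
    std = marginal_std(t)
    xt = x0 + std * rng.standard_normal(x0.shape)
    target = -(xt - x0) / std ** 2   # grad_{x_t} log p_{t|0}(x_t | x_0)
    lam = std ** 2                   # lambda(t) ∝ 1 / E[||target||^2] per dim
    return xt, target, lam

x0 = rng.standard_normal(8)          # a hypothetical 'data' point
xt, target, lam = dsm_pair(x0, t=0.5)
```

The weight \(\lambda(t)\) here makes \(\lambda(t)\,\E\Vert\text{target}\Vert^{2}\) constant in \(t\), matching the normalising role described above.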

With this score function (or the approximation thereof, \(s_{\vtheta}(\rvx, t)\))
in hand, we can draw samples from the learned distribution by solving the
time-reversed SDE, which we state in multivariate form for consistency with
later sections (the univariate form is somewhat simpler: \(d\rvx_{t} = [\rvf(\rvx_{t}, t) - g^{2}(t)s_{\vtheta}(\rvx_{t}, t)]dt + g(t)d\overline{\rvw}\))

\[\begin{align} d\rvx_{t} &= \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx} [\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^\top] \nonumber \\ &\hspace{1.2em}-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\Big]dt + \rmG(\rvx_{t}, t)d\overline{\rmW} \label{eq:time_reversed_sde} \end{align} \]

where time flows backwards from \(t=T\) to \(0\) and \(\overline{\rmW}\) is the
time-reversed Wiener process (whose properties are identical to the usual
Wiener process; we use the bar notation only to avoid ambiguity about whether
we're referring to the forward or reverse process).

We can solve this reverse SDE and generate a sample using any numerical method we’d like, with the Euler-Maruyama method being the most straightforward. Selecting some small negative time increment \(\Delta t = -\epsilon\) for some \(\epsilon > 0\), initialising \(t \gets T\), and iterating until \(t \approx 0^{+}\), we repeat

\[\begin{align*} \Delta \rvx &\gets \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx}[\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}] \\ &\hspace{1.2em}- \frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}{\color{#3B5E8C}\nabla_{\rvx}\log p_{t}(\rvx_{t})}\Big]\Delta t + \rmG(\rvx_{t}, t)\sqrt{\vert \Delta t\vert}\rvz_{t} \\ \rvx &\gets \rvx + \Delta \rvx \\ t &\gets t + \Delta t, \end{align*} \]

where \(\rvz_{t} \sim \gN(0, \rmI)\). In practice, we substitute \({\color{#3B5E8C}\nabla_{\rvx}\log p_{t}(\rvx_{t})}\) with the output of our time-conditioned score network \(s_{\vtheta}(\rvx_{t}, t)\).
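The loop above can be sketched end-to-end in the univariate case. Here we assume a toy Gaussian data distribution so that the true score of \(p_{t}\) is available in closed form (standing in for \(s_{\vtheta}\)); names and parameter values are illustrative:

```python
import numpy as np

rng = np.random.default_rng(2)

sigma0_sq = 4.0                          # toy data distribution: N(0, 4)
g = lambda t: np.exp(t)                  # diffusion coefficient of dx = e^t dW
v = lambda t: (np.exp(2 * t) - 1) / 2    # noise variance accumulated by time t

def true_score(x, t):
    # For Gaussian data, p_t = N(0, sigma0_sq + v(t)) is known exactly;
    # this is the quantity a trained score network would approximate.
    return -x / (sigma0_sq + v(t))

def reverse_em_sample(n, T=2.0, steps=1000):
    dt = -T / steps                      # small negative time increment
    x = rng.standard_normal(n) * np.sqrt(sigma0_sq + v(T))  # x_T ~ p_T
    t = T
    for _ in range(steps):
        z = rng.standard_normal(n)
        x = x + (-g(t) ** 2 * true_score(x, t)) * dt + g(t) * np.sqrt(-dt) * z
        t = t + dt
    return x

samples = reverse_em_sample(20_000)      # should look like draws from N(0, 4)
```

Running the reverse chain recovers the data distribution's variance, illustrating that an accurate score is all the sampler needs.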

The key challenge of applying the above on discrete data is that the score
function \(\nabla_{\rvx}\log p(\rvx)\) is undefined: there is no notion of the
direction in which the probability mass function changes smoothly, or any
*smoothness* at all for that matter.

There have been several proposals to adapt this model class to work with
discrete data. In what follows, we will consider two approaches to perform
diffusion on the probability simplex (Richemond *et al.*, 2023; Floto *et al.*, 2023) as well as the Diffusion Bridges approach of Liu *et al.* (2022).
Other promising approaches, which we do not treat here, include
Concrete Score Matching (Meng *et al.*, 2022) and VQ-Diffusion
(Gu *et al.*, 2022).
In a related line of work, Bayesian flow networks
(Graves *et al.*, 2023) also provide a way to use a
‘diffusion-like’ model with discrete data (in the sense that the BFN is a
generative model class which is not autoregressive and places few constraints
on the network).

## Categorical SDEs with Simplex Diffusion

In *Categorical SDEs with Simplex Diffusion*,
Richemond *et al.* (2023) propose a discrete diffusion
construction by deriving a tractable multivariate stochastic process that
operates on the probability simplex itself. That is, the limiting distribution
of each dimension of this multivariate process approaches a certain Gamma
distribution, which may be used to draw samples from a Dirichlet distribution
on the simplex.

### Dirichlet Distributions

The \(k\)-dimensional probability simplex is defined as

\[\begin{equation} \label{eq:k_simplex} \gS^{k} \doteq \Big\{\rvx \in \R^{k}: 0 \le \rvx_{i} \le 1, \sum_{i=1}^{k} \rvx_{i} = 1\Big\}. \end{equation} \]

Each point on the simplex can be seen to define a categorical distribution over \(k\) categories.

**Definition** (Dirichlet distribution)**.** The *Dirichlet distribution*,
denoted \(\gD(\valpha)\), is a multivariate, continuous distribution parametrised
by a vector of positive scalars \(\valpha \doteq (\alpha_{1}, \ldots, \alpha_{D})\), whose probability density function (using the
standard Lebesgue measure on \(\R^{D}\)) is defined as

\[\begin{equation} \label{eq:dirichlet_density} f_{\gD}(\rvx, \valpha) = \frac{1}{Z_{\valpha}}\prod_{i=1}^{D}\ervx_{i}^{\alpha_{i}-1}, \end{equation} \]

for \(\rvx \in \gS^{D}\), \(\alpha_{i} \in \R^{+}\ \ \forall i\) and \(Z_{\valpha}\) the partition function or *normalising constant*.

\(\blacksquare\)

Note that setting \(\alpha_{i} = 1\) for all \(i\) gives us a uniform distribution over the simplex. Further, sampling from a Dirichlet can be done with the following two-step procedure:

- Sample \(D\) *independent* Gamma random variables \(\ervy_{1} \sim \gG(\alpha_{1}, \beta), \ldots, \ervy_{D} \sim \gG(\alpha_{D}, \beta)\), where the \(\alpha_{i}\) are the individual shape parameters, and \(\beta\) is a shared rate parameter.
- Normalise these random variables to sum to \(1\), resulting in a Dirichlet-distributed random vector:

\[\left(\frac{\ervy_{1}}{\sum_{i=1}^{D}\ervy_{i}}, \ldots, \frac{\ervy_{D}}{\sum_{i=1}^{D}\ervy_{i}}\right) \sim \gD(\valpha). \]

The above result holds for any positive \(\beta\).
Richemond *et al.* (2023) choose to use a shared rate
parameter of \(\beta = 1\).
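This two-step procedure is a few lines of numpy (the concentration values below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(3)
alpha = np.array([2.0, 3.0, 5.0])  # example concentration parameters

def dirichlet_via_gamma(alpha, n, beta=1.0):
    # Step 1: D independent Gamma(alpha_i, beta) draws per sample
    # (beta is the shared rate; numpy's `scale` argument is 1 / rate).
    y = rng.gamma(shape=alpha, scale=1.0 / beta, size=(n, alpha.size))
    # Step 2: normalise each draw to sum to one.
    return y / y.sum(axis=1, keepdims=True)

samples = dirichlet_via_gamma(alpha, n=100_000)
# The Dirichlet mean is alpha / alpha.sum() = [0.2, 0.3, 0.5].
```

The sample mean of the normalised draws matches the known Dirichlet mean \(\valpha / \sum_{i}\alpha_{i}\), regardless of the shared rate \(\beta\).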

### Cox-Ingersoll-Ross Process

In order to define a suitable forward process, we use a Cox-Ingersoll-Ross (CIR) process, whose limiting distribution can be made to be the Gamma distribution \(\gG(\alpha_{i}, 1)\). The key insight is that we can treat this as one of the \(D\) Gamma distributions required to sample from the Dirichlet, as set out in the sampling procedure above.

In more detail, the CIR process is a real-valued diffusion process defined by the following SDE in \(\theta_{t}\)

\[\begin{equation} \label{eq:cir_process_sde} d\theta_{t} = b(a - \theta_{t})dt + \sigma \sqrt{\theta_{t}}dW_{t}, \end{equation} \]

which holds for any \(\theta_{0} \ge 0\), \(a, b, \sigma > 0\) and where \((W_{t})_{t\ge 0}\) is a Wiener process. The invariant limiting distribution is \(\gG(2ab/\sigma^{2}, 2b/\sigma^{2})\), and if \(2ab \ge \sigma^{2}\) with \(\theta_{0} > 0\), then the process remains strictly positive. Setting \(2b = \sigma^{2}\), the SDE in Equation \(\ref{eq:cir_process_sde}\) becomes

\[\begin{equation} d\theta_{t} = b(a - \theta_{t})dt + \sqrt{2b\theta_{t}}dW_{t}, \end{equation} \]

which admits the Gamma distribution \(\gG(a, 1)\) as its limiting distribution, as required for constructing one dimension of the Dirichlet random vector.
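A quick Euler-Maruyama simulation (a discretised sketch, not exact transition sampling) illustrates that the process does settle into the claimed \(\gG(a, 1)\) limit:

```python
import numpy as np

rng = np.random.default_rng(4)

def simulate_cir(a, b, n, T=10.0, steps=2000):
    # Euler-Maruyama for d theta = b(a - theta) dt + sqrt(2 b theta) dW.
    # Clipping at zero guards against the discretisation stepping below 0
    # (the exact process stays positive when 2ab >= sigma^2).
    dt = T / steps
    theta = np.full(n, a, dtype=float)   # arbitrary positive initial condition
    for _ in range(steps):
        z = rng.standard_normal(n)
        theta += b * (a - theta) * dt + np.sqrt(2 * b * theta * dt) * z
        theta = np.clip(theta, 0.0, None)
    return theta

theta_T = simulate_cir(a=3.0, b=1.0, n=20_000)
# After T >> 1/b, theta_T should be close to Gamma(3, 1): mean 3, variance 3.
```

The empirical mean and variance both approach \(a\), consistent with a \(\gG(a, 1)\) stationary law.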

### Using the CIR Process for Diffusion on the Simplex

A forward process whose marginal distribution in the large time limit provides samples from a Dirichlet distribution \(\gD(\valpha)\) would allow us to perform diffusion on the simplex. Using \(D\) independent CIR processes (as described in the previous section) and simulating them in parallel, we obtain a process \(\rvy_{t}\) with positive marginals, each following

\[\begin{equation} \label{eq:cir_fwd_sde} d\ervy_{t}^{i} = b(\alpha_{i} - \ervy_{t}^{i})dt + \sqrt{2b\ervy_{t}^{i}}dW_{t}^{i}, \end{equation} \]

where, as previously, each \(\ervy^{i}\) has limiting distribution \(\gG(\alpha_{i}, 1)\). The normalised, unit-sum vector is then Dirichlet distributed as \(t\to \infty\)

\[\rvz_{t} = \left(\frac{\ervy_{t}^{1}}{\sum_{i=1}^{D}\ervy_{t}^{i}}, \ldots, \frac{\ervy_{t}^{D}}{\sum_{i=1}^{D}\ervy_{t}^{i}}\right) \sim \gD(\valpha). \]

### Reversing the Simplex Diffusion

In the previous section we obtained a forward process which diffuses a point on the simplex (initially a vertex, that is, a one-hot vector corresponding to a categorical sample) towards the Dirichlet limiting distribution. We now seek its time-reversal.

Note that the forward process SDE in Equation \(\ref{eq:cir_fwd_sde}\) is in the general vector form given in Equation \(\ref{eq:forward_sde}\), where the drift coefficient is \(\rvf(\rvy_{t}, t) = b(\valpha - \rvy_{t}) \in \R^{D}\) and the diffusion coefficient is \(\rmG(\rvy_{t}, t) = \sqrt{2b} \cdot \text{diag}\big(\sqrt{\ervy_{t}^{1}}, \ldots, \sqrt{\ervy_{t}^{D}}\big)\in \R^{D\times D}\). Further, let \(p_{t}(\rvy_{t})\) denote the law (probability density function) of \(\rvy_{t}\), and note that \(\mSigma(\rvy_{t}, t) = \rmG(\rvy_{t}, t)\rmG(\rvy_{t}, t)^{\top} = 2b \cdot \text{diag}(\rvy_{t})\).

Now, by Equation \(\ref{eq:time_reversed_sde}\), the time-reversal of the multidimensional CIR given above is \((\rvx_{t})_{t\in [0, T]}\) such that \(\rvx_{t} = \rvy_{T-t}\) satisfies

\[\begin{align} d\rvx_{t} &= [-b(\valpha - \rvx_{t}) + 2b \cdot \text{diag}(\rvx_{t}) \nabla_{\rvx}\log p_{T-t}(\rvx_{t}) + 2b\mathbf{1}]dt \nonumber \\ &+ \rmG(\rvx_{t}, T-t) d\rmW_{t} \label{eq:time_reversed_cir} \end{align} \]

where \(\rvx_{0} \sim p_{T}\) is a draw from (an approximation of) the limiting distribution. In practice, we use the score network

\[\begin{align} d\rvx_{t} &= [-b(\valpha - \rvx_{t}) + 2b \cdot \text{diag}(\rvx_{t}) s_{\vtheta}(\rvx_{t}, T-t) + 2b\mathbf{1}]dt \\ &+\rmG(\rvx_{t}, T-t) d\rmW_{t}, \end{align} \]

where at \(t = 0\) the reverse process is initialised from the joint over the \(D\) independent Gamma distributions, \(\rvx_{0} \sim p_{\text{ref}}(x^{1}, \ldots, x^{D}) = \prod_{i=1}^{D} \gG(x^{i}; \alpha_{i}, 1)\), and \(s_{\vtheta}(\rvx, t)\) is the score network approximating \(\nabla_{\rvx}\log p_{t}(\rvx)\).

### Training

In order to use the usual denoising score matching training objectives, we need a closed-form transition density for the CIR process. Omitting derivations, this can be obtained for each coordinate \(\ervy^{i}\) as

\[\begin{align} \label{eq:cir_density} p_{t\vert 0}(\ervy_{t}^{i} \vert \ervy_{0}^{i}) &= c \exp\left(-c(\ervy_{0}^{i}\exp(-bt) + \ervy_{t}^{i})\right)\left(\frac{\ervy_{t}^{i}\exp(bt)}{\ervy_{0}^{i}}\right)^{\frac{\alpha_{i}-1}{2}} \nonumber \\ &\hspace{1em} \cdot I_{\alpha_{i}-1}\left(2c\sqrt{\ervy_{0}^{i}\ervy_{t}^{i}\exp(-bt)}\right), \end{align} \]

where \(c = (1-\exp(-bt))^{-1}\) and \(I_{\alpha_{i}-1}\) is the modified Bessel function of the first kind of order \(\alpha_{i} - 1\).
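As a sanity check on this density, the sketch below evaluates it on a grid and verifies that it integrates to one; the power-series implementation of the Bessel function is a stand-in for a library routine such as `scipy.special.iv`, and the parameter values are arbitrary:

```python
import math

def bessel_i(nu, z, terms=80):
    # Modified Bessel function of the first kind via its power series;
    # adequate for the moderate arguments used here.
    return sum((z / 2) ** (2 * k + nu) / (math.factorial(k) * math.gamma(k + nu + 1))
               for k in range(terms))

def cir_transition_density(yt, y0, t, alpha, b=1.0):
    c = 1.0 / (1.0 - math.exp(-b * t))
    return (c * math.exp(-c * (y0 * math.exp(-b * t) + yt))
            * (yt * math.exp(b * t) / y0) ** ((alpha - 1) / 2)
            * bessel_i(alpha - 1, 2 * c * math.sqrt(y0 * yt * math.exp(-b * t))))

# Sanity check: as a density in y_t it should integrate to one
# (trapezoidal rule on [0, 25]; the tail beyond is negligible here).
ys = [25.0 * i / 2000 for i in range(1, 2001)]
vals = [cir_transition_density(y, y0=1.5, t=1.0, alpha=2.0) for y in ys]
h = ys[1] - ys[0]
integral = h * (sum(vals) - 0.5 * (vals[0] + vals[-1]))
```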

Hence, we can evaluate the following objective to train the score network:

\[\begin{equation} \gL(\vtheta) = \E_{t\sim U[0, T]} \mathbb{E}_{\substack{ \rvy_{0} \sim \gD \\ \rvy_{t} \sim p_{t\vert 0}(\rvy_{t}\vert \rvy_{0}) \\ }} \Big[\big(\nabla_{\rvy}\log p_{t \vert 0}(\rvy_t \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\big)^{\top}\text{diag}(\rvy_t)\big(\nabla_{\rvy}\log p_{t\vert 0}(\rvy_{t} \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\big)\Big]. \end{equation} \]

## Diffusion on the Probability Simplex

In *Diffusion on the Probability Simplex*,
Floto *et al.* (2023) also propose an alternative
forward process that allows diffusion to proceed on the probability simplex in
the discrete data setting. Whereas Richemond *et al.* (2023)
use a forward process with a Dirichlet limiting distribution, in this section
we will see a forward process with a Logistic-Gaussian limiting distribution.

### Logistic-Gaussian Distributions

Recall the definition of the \(k\)-dimensional probability simplex \(\gS^{k}\) of Equation \(\ref{eq:k_simplex}\).

**Definition** (Logistic-Gaussian distribution)**.** The *Logistic-Gaussian
distribution* has bounded support on the simplex and is the distribution of
the softmax function applied to a Gaussian. That is, if \(\ry\) is a
Gaussian-distributed random variable, and \(\sigma\) is the (additive) logistic
function, then \(\rx = \sigma(\ry)\) is logistic-Gaussian distributed. The
PDF is given by

\[\begin{align} f_{\gL\gG}(\rvx; \mu, \mSigma) &= \big\vert(2\pi)^{D-1}\mSigma\big\vert^{-1/2} \Big[\prod_{i=1}^{D}\rvx_{i}\Big]^{-1} \nonumber \\ &\hspace{1.2em}\exp\left(-\frac{1}{2}\left[\log \left(\frac{\bar{\rvx}}{\ervx_{D}}\right) - \mu\right]^{\top}\mSigma^{-1}\left[\log \left(\frac{\bar{\rvx}}{\ervx_{D}}\right) - \mu\right]\right), \label{eq:logistic_gaussian} \end{align} \]

where \(\rvx \in \gS^{D}\) and we use \(\bar{\rvx}\) to denote the first \(D-1\)
components of \(\rvx\); \(\bar{\rvx} = [\ervx_{1}, \ldots, \ervx_{D-1}]^{\top}\).

\(\blacksquare\)

To draw samples from the logistic-Gaussian, we map a point \(\rvy \in \R^{D-1}\), sampled as \(\rvy \sim \gN(\mu, \mSigma)\), to a point \(\rvx \in \gS^{D}\) using the additive logistic transformation \(\sigma: \R^{D-1} \to \gS^{D}\), which is defined for dimension \(i \in \{1, \ldots, D\}\) as:

\[\begin{equation} \label{eq:add_one_logistic} \rvx_{i} = \sigma_{i}(\rvy) \doteq \begin{cases} \frac{\exp({\rvy_{i}})}{1 + \sum_{k=1}^{D-1}\exp({\rvy_{k}})},\ \ &\text{if }i \in \{1, \ldots, D-1\} \\ \frac{1}{1 + \sum_{k=1}^{D-1}\exp({\rvy_{k}})},\ \ &\text{if }i = D. \end{cases} \end{equation} \]

Similarly, note that the inverse \(\sigma^{-1}: \gS^{D} \to \R^{D-1}\) is

\[\rvy_{i} = \log \frac{\rvx_{i}}{\rvx_{D}} \ \text{for } i \in \{1, \ldots, D-1\}. \]
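The transformation and its inverse are straightforward to implement and round-trip exactly (function names are our own):

```python
import math

def additive_logistic(y):
    # sigma: R^{D-1} -> S^D, the additive logistic transformation above.
    denom = 1.0 + sum(math.exp(v) for v in y)
    return [math.exp(v) / denom for v in y] + [1.0 / denom]

def additive_logistic_inverse(x):
    # sigma^{-1}: S^D -> R^{D-1}, with y_i = log(x_i / x_D).
    return [math.log(xi / x[-1]) for xi in x[:-1]]

y = [0.3, -1.2, 2.0]            # a point in R^{D-1}, here with D = 4
x = additive_logistic(y)        # lies on the simplex (sums to one)
y_back = additive_logistic_inverse(x)
```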

### Ornstein-Uhlenbeck Process

To construct the forward process, Floto *et al.* (2023)
make use of an Ornstein-Uhlenbeck (OU) process. This is a real-valued,
multidimensional stochastic process which can be described using the following
SDE:

\[\begin{equation} \label{eq:ou_sde} d\rvy_{t} = -\theta\rvy_{t} dt + \sigma d\rmW_{t}, \end{equation} \]

for \(\theta > 0\) and \(\sigma > 0\), and \((\rmW_{t})_{t\ge 0}\) is once again a Wiener process. Taking \(\sigma = 1\), the distribution of the OU process indexed at \(t\) is the following Gaussian:

\[\begin{equation} \label{eq:ou_t} \rvy_{t} \sim \gN\left(\rvy_{0}\exp(-\theta t), \frac{1}{2\theta}\big(1-\exp(-2\theta t)\big)\rmI\right), \end{equation} \]

and the limiting distribution of the OU process as \(t \to \infty\) is \(\gN(0, \frac{1}{2\theta}\rmI)\). Note that \(\theta\) uniquely determines this limiting distribution; it is independent of the data \(\rvy_{0}\).
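A short simulation confirms the closed-form marginal (with \(\sigma = 1\), which the stated variance assumes; parameter values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(5)

def simulate_ou(y0, theta, T, steps, n):
    # Euler-Maruyama for dy = -theta * y dt + dW (taking sigma = 1).
    dt = T / steps
    y = np.full(n, y0, dtype=float)
    for _ in range(steps):
        y += -theta * y * dt + np.sqrt(dt) * rng.standard_normal(n)
    return y

theta, T = 0.8, 1.5
y_T = simulate_ou(y0=2.0, theta=theta, T=T, steps=1000, n=20_000)

mean_exact = 2.0 * np.exp(-theta * T)                   # y0 e^{-theta t}
var_exact = (1 - np.exp(-2 * theta * T)) / (2 * theta)  # closed-form variance
```

The empirical mean and variance of `y_T` match the closed-form marginal, and letting \(T\) grow recovers the \(\gN(0, \frac{1}{2\theta})\) limit.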

### Using the Transformed OU Process for Diffusion on the Simplex

The method proposed by Floto *et al.* (2023) defines a
diffusion process that operates on the probability simplex \(\gS^{D}\) as
follows: the forward process uses the additive logistic transformation to map
an OU process from \(\R^{D-1}\) to \(\gS^{D}\):

\[\begin{equation} \label{eq:logistic_ou_process} \rvx_{t} = \sigma(\rvy_{t}). \end{equation} \]

Note that the marginal at timestep \(t\) is given by

\[\begin{equation} \label{eq:logistic_ou_process_marginal} \rvx_{t} \sim \sigma\bigg(\gN\Big(\rvy_{0}\exp(-\theta t), \frac{1}{2\theta}\big(1 - \exp(-2\theta t)\big)\rmI\Big)\bigg). \end{equation} \]

In other words, the transition density \(p_{t\vert 0}(\rvx_{t} \vert \rvx_{0})\) for the forward process is a logistic-Gaussian distribution from which we can easily sample.

We can obtain the SDE for \(\rvx_{t}\) defined in Equation \(\ref{eq:logistic_ou_process}\) by applying Ito’s lemma to the SDE for \(\rvy_{t}\), which yields

\[\begin{equation} \label{eq:logistic_ou_sde} d\rvx_{t} = \rvf(\rvx_{t}, t)dt + \rmG(\rvx_{t}, t)d\rmW_{t}, \end{equation} \]

where the diffusion coefficient matrix is

\[\begin{equation} \label{eq:logistic_ou_sde_diffusion_coef} \rmG_{ij}(\rvx, t) = \begin{cases} \rvx_{i}(1 - \rvx_{i}), &i = j \\ -\rvx_{i}\rvx_{j}, &i \ne j \end{cases} \end{equation} \]

and the drift term is

\[\begin{equation} \label{eq:logistic_ou_sde_drift_coef} \rvf_{i}(\rvx, t) = -\theta \rvx_{i} \left[(1-\rvx_{i})\rva_{i} + \sum_{j\ne i}\rvx_{j}\rva_{j}\right], \end{equation} \]

for \(\rva_{j} = \rvx_{j} + \frac{1}{2}(1-2\rvx_{j})\).

### Reversing the Simplex Diffusion

From the previous section, we have identified the terms of the general vector form of the SDE of Equation \(\ref{eq:forward_sde}\). We can hence obtain the time-reversal of the forward process, \((\rvx_{t})_{t\in[0, T]}\), where once again \(\rvx_{t} = \rvy_{T-t}\), by substituting these drift and diffusion coefficients into the expression for the time-reversed SDE (Equation \(\ref{eq:time_reversed_sde}\)), which we re-state here:

\[\begin{align} d\rvx_{t} &= \Big[\rvf(\rvx_{t}, t) - \frac{1}{2}\nabla_{\rvx} [\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^\top] \nonumber \\ &\hspace{1.2em}-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\Big]dt + \rmG(\rvx_{t}, t)d\overline{\rmW} \end{align} \]

Note that the term \(\nabla_{\rvx}[\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}]\) can be
found in closed form (see Appendix 3 in
Floto *et al.* (2023)) and we can also obtain the score
of the logistic-Gaussian in closed form as

\[\begin{align} \nabla_{\rvx}\log p(\rvx)_{i} = &-\frac{1}{v} \left(\frac{1}{\ervx_{D}} \sum_{k=1}^{D-1}\sigma^{\mu}_{k}(\rvx) + \frac{1}{\ervx_{D}}\sigma^{\mu}_{i}(\rvx)\right) \nonumber \\ &+ \frac{\ervx_{i} - \ervx_{D}}{\ervx_{i}\ervx_{D}} \end{align} \]

where \(\sigma_{k}^{\mu}(\rvx) \doteq \log \frac{\ervx_{k}}{\ervx_{D}} - \mu_{k}\) and \(v\) is the (isotropic) variance, \(\mSigma = v\rmI\).

### Training

We may insert the above terms into a standard weighted Fisher divergence objective to obtain

\[\begin{equation} \gL(\vtheta) = \E_{t\sim U[0, T]} \mathbb{E}_{\substack{ \rvy_{0} \sim \gD \\ \rvy_{t} \sim p_{t\vert 0}(\rvy_{t}\vert \rvy_{0}) \\ }} \Big[\lambda(t)\Vert \nabla_{\rvy}\log p_{t\vert 0}(\rvy_{t} \vert \rvy_{0}) - s_{\vtheta}(\rvy_{t}, t)\Vert_{2}^{2}\Big]. \end{equation} \]

While in theory, for a dataset of \(D\) different categories, each datapoint would correspond to a one-hot vector at \(t=0\), in practice this leads to instabilities, and the authors instead map samples to \(\rvy = [\alpha, \beta, \ldots, \beta]^{\top}\), where \(\beta \doteq \frac{1-\alpha}{D-1}\) (so that the components sum to one) and \(\alpha = 1-\epsilon\) for some small \(\epsilon > 0\).

The authors also report that the score suffers from numerical instabilities near the boundary of the simplex. To rectify this, they train the network to directly predict \(-\frac{1}{2}\rmG(\rvx_{t}, t)\rmG(\rvx_{t}, t)^{\top}\nabla_{\rvx} \log p_{t}(\rvx_{t})\), which is bounded at the border of the interval \([0, 1]\) and empirically varies less quickly.

## Conclusion

While we have seen a couple of methods for performing diffusion over the probability simplex, and hence for using diffusion models with discrete data, these methods aren't without their limitations. They often require ad-hoc \(\epsilon\) stabilisation parameters to avoid discretisation errors, among other engineering tricks such as carefully selected warmup schedules and controlling the noise to respect the geometric constraints of the simplex.

Further, these methods may not scale elegantly to very high-dimensional discrete data, where the inputs become extremely sparse owing to the one-hot representation of the discrete data. This is crucial for applications such as language modelling, where tokenisation schemes often result in vocabularies in excess of \(50,000\) tokens.

Nonetheless these methods show promise for a range of generative modelling tasks in discrete data settings.

## Appendix A: Score Network Objective

The Fisher divergence

\[\frac{1}{2}\E_{p_{\text{data}}(\rvx)}\Big[\big\Vert \nabla_{\rvx} \log p_{\text{data}}(\rvx) - s_{\vtheta}(\rvx)\big\Vert_{2}^{2}\Big], \]

is not computable due to the presence of the \(\nabla_{\rvx} \log p_{\text{data}}(\rvx)\) term.

To eliminate this term, we may use score matching and integration by parts. Taking the 1D setting for simplicity, we begin by expanding out the brackets in the Fisher divergence

\[\begin{align*} &\hspace{1.2em} \frac{1}{2}\E_{p_{\text{data}}(x)}\Big[\big(\nabla_{x}\log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^2\Big] \\ &= \frac{1}{2} \int p_{\text{data}}(x) \big(\nabla_{x} \log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^{2}dx \\ &= \underbrace{\frac{1}{2} \int p_{\text{data}}(x) \big(\nabla_{x} \log p_{\text{data}}(x)\big)^{2}dx}_{\text{const wrt. }\vtheta} + \frac{1}{2} \int p_{\text{data}}(x)\big(s_{\vtheta}(x)\big)^{2}dx \\ &\hspace{1.2em}-\int p_{\text{data}}(x) s_{\vtheta}(x) \nabla_{x}\log p_{\text{data}}(x)dx. \end{align*} \]

We note from the chain rule that \(p(x)\nabla_{x}\log p(x) = \nabla_{x}p(x)\)
(sometimes called the ‘log-derivative trick’ in machine learning), and hence
the final term above becomes

\[- \int s_{\vtheta}(x)\nabla_{x}p_{\text{data}}(x) dx. \]

Integrating the above by parts is straightforward; selecting \(u = s_{\vtheta}(x)\) and \(dv = \nabla_{x}p_{\text{data}}(x)dx\), we get \(du = \nabla_{x}s_{\vtheta}(x)dx\) and \(v = p_{\text{data}}(x)\), and thus

\[\begin{align*} -\int s_{\vtheta}(x)\nabla_{x}p_{\text{data}}(x)dx &= -p_{\text{data}}(x) s_{\vtheta}(x)\bigg\vert_{-\infty}^{\infty} + \int p_{\text{data}}(x) \nabla_{x}s_{\vtheta}(x)dx \\ &= \E_{p_{\text{data}}(x)}\big[\nabla_{x}s_{\vtheta}(x)\big], \end{align*} \]

where the last line above holds if we assume that \(p_{\text{data}}(x)s_{\vtheta}(x) \to 0\) as \(\vert x \vert \to \infty\), which is reasonable for normalised densities provided \(s_{\vtheta}\) does not grow too quickly.

Substituting this back into the Fisher divergence,

\[\begin{align*} &\hspace{1.2em}\frac{1}{2}\E_{p_{\text{data}}(x)}\big[\big(\nabla_{x}\log p_{\text{data}}(x) - s_{\vtheta}(x)\big)^2\big] \\ &= \E_{p_{\text{data}}(x)}\left[\frac{1}{2}s_{\vtheta}(x)^{2} +\nabla_{x}s_{\vtheta}(x)\right] + \text{const}. \end{align*} \]

we see that we have removed the dependence on the score of the data distribution in the 1D case.

Hyvärinen (2005) shows how we can extend the above to the more general multidimensional setting:

\[\E_{p_{\text{data}}(\rvx)}\left[\frac{1}{2}\Vert s_{\vtheta}(\rvx)\Vert^{2}_{2} + \Tr\big(\nabla_{\rvx}s_{\vtheta}(\rvx)\big)\right] + \text{const}. \]
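We can also verify this equivalence numerically: for \(p_{\text{data}} = \gN(0, 1)\) and a hypothetical score model \(s_{\vtheta}(x) = -x/v\), the two objectives should differ by the same constant for every \(v\):

```python
import math

def gauss_pdf(x):
    # p_data = N(0, 1) for this check.
    return math.exp(-x * x / 2) / math.sqrt(2 * math.pi)

def quad(f, a=-10.0, b=10.0, n=4000):
    # Simple trapezoidal quadrature; Gaussian tails beyond |x| = 10 are
    # negligible.
    h = (b - a) / n
    return h * (0.5 * (f(a) + f(b)) + sum(f(a + i * h) for i in range(1, n)))

def fisher_divergence(v):
    # (1/2) E[(d/dx log p(x) - s(x))^2] with the model score s(x) = -x / v.
    return quad(lambda x: gauss_pdf(x) * 0.5 * (-x + x / v) ** 2)

def sm_objective(v):
    # E[(1/2) s(x)^2 + s'(x)], the integration-by-parts form.
    return quad(lambda x: gauss_pdf(x) * (0.5 * (x / v) ** 2 - 1 / v))

# The gap should equal (1/2) E[(d/dx log p)^2] = 1/2 for every v.
gaps = [fisher_divergence(v) - sm_objective(v) for v in (0.5, 1.0, 3.0)]
```

The constant gap is exactly the "const" term above: half the expected squared data score, which does not depend on \(\vtheta\).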