

On Variational Objectives for Discrete Latents

Or, “how to sample discrete things with gradients”. A comparison of Gumbel-softmax, straight-through estimators, REINFORCE, VIMCO, NVIL, REBAR, CONCRETE and other mystifying acronyms.

February 6, 2024

London, UK


We often have to optimise an objective of the following form

\[\begin{equation} \label{eq:sensitivity_analysis} \gL(\vphi) = \E_{\Q_{\vphi}}\left[f(\rvx)\right], \end{equation} \]

where the parameter that we are optimising, \(\vphi\), appears in the distribution \(\Q_{\vphi}\) over which we are taking the expectation.

Some examples are:

  • variational inference, where \(\Q_{\vphi}(\rvx)\) approximates a true posterior that is otherwise intractable to obtain, and \(f(\rvx)\) is a log likelihood or log joint.
  • reinforcement learning, where \(\Q_{\vphi}(\rvx)\) is the distribution over trajectories induced by the policy, and \(f(\rvx)\) is the return of a trajectory.

However, when the latent variable \(\rvx\) is discrete, solving these problems becomes tricky. This article gives an overview of the main approaches for getting around this.
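To see where the difficulty lies, consider estimating Equation \(\ref{eq:sensitivity_analysis}\) by Monte Carlo. The sketch below (in PyTorch, with a toy Bernoulli \(\Q_{\vphi}\) and an arbitrary \(f\), both chosen purely for illustration) shows that the estimate itself is easy, but that sampling a discrete \(\rvx\) severs the computation graph, so \(\nabla_{\vphi}\gL\) is not available by backpropagation:

```python
import torch

# Toy setup, assumed purely for illustration:
# Q_phi = Bernoulli(sigmoid(phi)), f an arbitrary function of x in {0, 1}.
phi = torch.tensor(0.0, requires_grad=True)
f = lambda x: (x - 0.5) ** 2

q = torch.distributions.Bernoulli(logits=phi)
x = q.sample((1000,))          # Monte Carlo samples x ~ Q_phi
estimate = f(x).mean()         # unbiased estimate of L(phi) = E_Q[f(x)]

print(estimate)                # fine as an estimate of L(phi)...
print(estimate.requires_grad)  # False: .sample() is not differentiable,
                               # so there is no gradient path back to phi
```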

To introduce the problem of optimising a variational objective with discrete latent variables, we will briefly look at (deep) latent variable models. The remainder of the article will cover some key ideas and methods for dealing with this.

A latent variable model (LVM) is a joint distribution \(\P(\rvx, \rvz)\) over some observed data \(\rvx\) and some latent variables \(\rvz\). We can factorise this in two ways: either working with \(\P(\rvz \vert \rvx)\) to, perhaps stochastically, map (or encode) a datapoint \(\rvx\) to its latent representation \(\rvz\), or with the conditional \(\P(\rvx \vert \rvz)\) to generate (or decode) latents \(\rvz\) into new plausible \(\rvx\) samples.

In the second factorisation, \(\P(\rvx, \rvz) = \P(\rvx \vert \rvz) \P(\rvz)\), the (deep) LVM takes the form of

\[\begin{align} \P_{\vtheta_{z}}(\rvz) &= f_{z}(\rvz; \vtheta_{z}) \\ \P_{\vtheta_{y}}(\rvx \vert \rvz) &= f_{y}(\rvx; \rvz, \vtheta_{y}), \end{align} \]

where \(f_{z}\) and \(f_{y}\) are valid density functions, and \(\vtheta = \{\vtheta_{z}, \vtheta_{y}\}\) parametrises the generative process.

To draw a sample from the model, we first sample from the prior over latent variables, \(\hat{\rvz} \sim \P_{\vtheta_{z}}(\rvz)\), and condition on this value when we sample the data, \(\hat{\rvx} \sim \P_{\vtheta_{y}}(\rvx \vert \hat{\rvz})\).
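As a minimal sketch of this ancestral sampling, assuming (for illustration only) a categorical prior over \(K\) latent states and a small decoder network for the likelihood:

```python
import torch
import torch.nn as nn

# Illustrative choices, not prescribed by the text:
# P_{theta_z}(z)   = Categorical(K), parametrised by logits theta_z
# P_{theta_y}(x|z) = Normal(decoder(z), 1), decoder weights playing the role of theta_y
K, D = 10, 5
theta_z = torch.zeros(K)
decoder = nn.Sequential(nn.Linear(K, 32), nn.ReLU(), nn.Linear(32, D))

prior = torch.distributions.Categorical(logits=theta_z)
z_hat = prior.sample()                              # z_hat ~ P(z)
z_onehot = nn.functional.one_hot(z_hat, K).float()  # encode z for the network
likelihood = torch.distributions.Normal(decoder(z_onehot), 1.0)
x_hat = likelihood.sample()                         # x_hat ~ P(x | z_hat)
```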

To learn these parameters \(\vtheta\), the intuitive thing to do is to maximise the marginal likelihood:

\[\begin{align} \vtheta = \argmax_{\hat{\vtheta} \in \Theta}\log \P_{\hat{\vtheta}}(\rvx) = \argmax_{\hat{\vtheta} \in \Theta}\log \int_{\gZ} \P_{\hat{\vtheta}}(\rvx, \rvz)\,d\rvz. \end{align} \]

When the likelihood and/or prior is a deep parametric function approximator, this integral is almost always intractable. The variational approach is to introduce an approximate posterior distribution \(\Q_{\vphi}(\rvz \vert \rvx)\) to approximate \(\P_{\vtheta}(\rvz \vert \rvx)\), which has a simpler form and is easier to work with. Then, using Jensen’s inequality, we can form a variational objective:

\[\begin{align} \log \int_{\gZ}\P_{\vtheta}(\rvx, \rvz) d\rvz &\ge \int_{\gZ}\Q_{\vphi}(\rvz \vert \rvx) \log \frac{\P_{\vtheta}(\rvx, \rvz)}{\Q_{\vphi}(\rvz \vert \rvx)}d\rvz \\ &= \E_{\Q_{\vphi}}\left[\log \P_{\vtheta}(\rvx, \rvz) - \log \Q_{\vphi}(\rvz \vert \rvx)\right] \label{eq:elbo_expectation} \\ &\doteq \gL(\rvx, \vtheta, \vphi). \label{eq:elbo} \end{align} \]

Maximising Equation \(\ref{eq:elbo}\), the lower bound to the log marginal likelihood, allows us to train the likelihood and prior’s parameters, \(\vtheta\), as well as the variational parameters, \(\vphi\).
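For concreteness, here is a one-sample Monte Carlo estimate of this bound, reusing the illustrative model above together with an assumed encoder network for \(\Q_{\vphi}(\rvz \vert \rvx)\):

```python
# Continues the sketch above; the encoder architecture is again an assumption.
encoder = nn.Sequential(nn.Linear(D, 32), nn.ReLU(), nn.Linear(32, K))

def elbo_estimate(x):
    q = torch.distributions.Categorical(logits=encoder(x))
    z = q.sample()                                   # z ~ Q_phi(z | x)
    z_onehot = nn.functional.one_hot(z, K).float()
    log_px_z = torch.distributions.Normal(decoder(z_onehot), 1.0).log_prob(x).sum()
    log_pz = prior.log_prob(z)                       # log P(z)
    log_qz = q.log_prob(z)                           # log Q(z | x)
    return log_px_z + log_pz - log_qz                # log P(x,z) - log Q(z|x)
```

Evaluating this bound is easy; the catch, discussed next, is that naively differentiating this one-sample estimate with respect to \(\vphi\) does not give an unbiased estimate of the gradient of the bound, because the sampled \(\rvz\) itself depends on \(\vphi\).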

So far, we have left the space of latents \(\gZ\) unspecified. However, looking at Equation \(\ref{eq:elbo_expectation}\), we see that to optimise the variational parameters \(\vphi\), we must draw latent samples from the variational distribution in a way that preserves gradient information, so that we can train the neural density estimators \(\P_{\vtheta}\) and \(\Q_{\vphi}\) by backpropagation.

For real-valued latents (\(\gZ = \R\)), modelled with distributions that are amenable to reparametrised sampling, this sampling-with-gradients is relatively straightforward. However, discrete latent variables are not so easily sampled with gradients; doing so is the subject of the techniques and methods discussed below.
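As a reminder of why the continuous case is easy, here is the standard Gaussian reparametrisation trick in PyTorch (with toy values): the sample is a deterministic, differentiable function of the variational parameters and parameter-free noise.

```python
import torch

# Gaussian reparametrisation: z = mu + sigma * eps with eps ~ N(0, 1),
# so z is differentiable in the variational parameters (mu, log_sigma).
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)

eps = torch.randn(())                  # parameter-free noise
z = mu + log_sigma.exp() * eps         # pathwise-differentiable sample
loss = (z - 1.0) ** 2                  # any downstream objective f(z)
loss.backward()
print(mu.grad, log_sigma.grad)         # well-defined gradients

# No such continuous path exists through e.g.
# torch.distributions.Categorical(logits=...).sample(), which is exactly
# the gap the methods below aim to fill.
```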

Work in Progress

The rest of this article will go over the key ideas set out in the list of papers below. Please check back in the next few days!