Reparameterization Trick

Motivation

In a variational autoencoder and many other latent-variable models, the training objective contains an expectation under a parameterized distribution:

\[ J(\phi) = \mathbb{E}_{z \sim q_\phi(z \mid x)}[f(z)], \]

where \(f\) is some function (the reconstruction log-likelihood, in the VAE case). To train by gradient ascent, we need \(\nabla_\phi J(\phi)\) — but the expectation itself depends on \(\phi\), so the gradient is not just \(\mathbb{E}[\nabla_\phi f(z)]\).
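
A one-line example makes the problem concrete: take \(q_\mu = \mathcal{N}(\mu, 1)\) and \(f(z) = z\). Then

\[ J(\mu) = \mathbb{E}_{z \sim \mathcal{N}(\mu, 1)}[z] = \mu, \qquad \nabla_\mu J(\mu) = 1, \qquad \text{yet} \qquad \mathbb{E}[\nabla_\mu f(z)] = \mathbb{E}[0] = 0, \]

since \(f(z) = z\) has no pointwise dependence on \(\mu\). The naive "gradient inside the expectation" estimator is simply wrong here.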

The reparameterization trick (Kingma and Welling 2013) rewrites \(z\) as a deterministic function of \(\phi\) and a noise variable with a fixed (parameter-free) distribution. The expectation is then over the noise, the gradient passes through the deterministic function via the chain rule, and the Monte Carlo estimator typically has dramatically lower variance than the score-function alternative. This is the optimization step that makes VAE training practical.

The Trick

Rewrite \(z \sim q_\phi(z \mid x)\) as \(z = h_\phi(x, \varepsilon)\) where \(\varepsilon \sim p(\varepsilon)\) does not depend on \(\phi\). Then

\[ J(\phi) = \mathbb{E}_{\varepsilon \sim p(\varepsilon)}[f(h_\phi(x, \varepsilon))], \]

and the gradient can be moved inside the expectation:

\[ \nabla_\phi J(\phi) = \mathbb{E}_{\varepsilon \sim p(\varepsilon)}[\nabla_\phi f(h_\phi(x, \varepsilon))] = \mathbb{E}_{\varepsilon}[f'(h_\phi(x, \varepsilon)) \cdot \nabla_\phi h_\phi(x, \varepsilon)]. \]

A single sample of \(\varepsilon\) gives an unbiased estimate of the gradient. The estimator’s variance is determined by how the gradient \(\nabla_\phi f(h_\phi(x, \varepsilon))\) varies with \(\varepsilon\) at fixed \(\phi\), which for smooth \(f\) is typically small. This is the key empirical observation: the reparameterized estimator typically has much lower variance than the score-function estimator.
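
As an illustrative check (a minimal NumPy sketch; the quadratic \(f\) and the constants are chosen for the demo, not taken from the original), take \(f(z) = z^2\) with \(z \sim \mathcal{N}(\mu, \sigma^2)\), so \(J(\mu, \sigma) = \mu^2 + \sigma^2\) and the exact gradient is \((2\mu, 2\sigma)\). Averaging many single-sample pathwise estimates recovers it:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.5            # illustrative parameter values
n = 200_000

eps = rng.standard_normal(n)    # fixed-distribution noise, N(0, 1)
z = mu + sigma * eps            # reparameterized samples

# f(z) = z**2, so f'(z) = 2*z; chain rule through z = mu + sigma*eps:
grad_mu = 2 * z                 # dz/dmu = 1
grad_sigma = 2 * z * eps        # dz/dsigma = eps

print(grad_mu.mean(), grad_sigma.mean())   # approx 2*mu = 3.0 and 2*sigma = 1.0
```

Each individual `grad_mu[i]` is a valid single-sample estimate; averaging only shrinks the Monte Carlo error around the exact values.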

The Gaussian Case

The standard VAE encoder is a diagonal Gaussian: \(q_\phi(z \mid x) = \mathcal{N}(\mu_\phi(x), \operatorname{diag}(\sigma_\phi(x)^2))\). The reparameterization is

\[ z = \mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I). \]

This expresses any diagonal Gaussian as a deterministic function of a fixed standard-normal noise, with the dependence on \(\phi\) entirely in \(\mu_\phi\) and \(\sigma_\phi\). Gradients flow through \(z\) to \(\mu_\phi\) and \(\sigma_\phi\) via the chain rule with no special handling.

In code (PyTorch, schematically):

import torch

mu, log_sigma = encoder(x)       # encoder predicts mean and log std-dev
sigma = log_sigma.exp()
eps = torch.randn_like(mu)       # fixed noise, no dependence on phi
z = mu + sigma * eps             # reparameterized sample
recon = decoder(z)
loss = recon_loss(recon, x) + kl(mu, sigma)
loss.backward()  # gradient flows through mu, sigma; eps is a constant

Note: it is standard to predict \(\log \sigma\) rather than \(\sigma\) directly, both for numerical stability (avoid negative \(\sigma\)) and to keep gradients well-scaled.
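
The kl term in the snippet above has a closed form when the prior is a standard normal: \(\mathrm{KL}(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)) = \tfrac{1}{2} \sum_i (\mu_i^2 + \sigma_i^2 - 1 - 2 \log \sigma_i)\). A NumPy sketch (the function name is my own, not from the original):

```python
import numpy as np

def kl_diag_gaussian(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(mu**2 + sigma**2 - 1.0 - 2.0 * np.log(sigma))

# Sanity check: the KL from the prior to itself is zero.
print(kl_diag_gaussian(np.zeros(4), np.ones(4)))   # 0.0
```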

Comparison: Score-Function Estimator

The score-function (REINFORCE-style) estimator uses the log-derivative trick:

\[ \nabla_\phi \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{q_\phi}[f(z) \nabla_\phi \log q_\phi(z \mid x)]. \]

This is unbiased but typically high-variance: the term \(f(z) \nabla_\phi \log q_\phi(z \mid x)\) scales with the magnitude of \(f(z)\) itself, even when the expectation is nearly insensitive to \(\phi\). Practical use usually requires variance-reduction techniques (control variates, baselines).

The reparameterization estimator has variance that depends on the gradient of \(f\) rather than on its value. For smooth \(f\), this is much smaller; the reduction in practice is often orders of magnitude.
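
To make the variance gap concrete (again a made-up toy, not from the original: \(f(z) = z^2\), \(z \sim \mathcal{N}(\mu, \sigma^2)\), gradient with respect to \(\mu\) only), compare the per-sample estimators side by side:

```python
import numpy as np

rng = np.random.default_rng(1)
mu, sigma = 1.5, 0.5
eps = rng.standard_normal(500_000)
z = mu + sigma * eps

# Pathwise (reparameterized): d/dmu f(mu + sigma*eps) = 2*z
grad_rp = 2 * z

# Score function: f(z) * d/dmu log N(z; mu, sigma^2) = z**2 * (z - mu) / sigma**2
grad_sf = z**2 * (z - mu) / sigma**2

print(grad_rp.mean(), grad_sf.mean())   # both approx 2*mu = 3.0 (both unbiased)
print(grad_rp.var(), grad_sf.var())     # score-function variance is far larger
```

Both estimators agree in expectation; the difference is entirely in the spread of the individual samples.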

The catch: the reparameterization trick requires a differentiable path from \(\phi\) to \(z\). Continuous distributions usually have such a path; discrete distributions do not. For discrete latents, alternatives include the Gumbel-softmax (a continuous relaxation), straight-through estimators, or sticking with score-function estimators.

When Reparameterization Works

Required: the distribution \(q_\phi\) can be written as \(z = h_\phi(x, \varepsilon)\) for noise \(\varepsilon\) drawn from a fixed (parameter-free) distribution, with \(h_\phi\) differentiable in \(\phi\).

Examples:

- Gaussian \(\mathcal{N}(\mu, \sigma^2)\): \(z = \mu + \sigma \varepsilon\), \(\varepsilon \sim \mathcal{N}(0, 1)\).
- Uniform \(U(a, b)\): \(z = a + (b - a) \varepsilon\), \(\varepsilon \sim U(0, 1)\).
- Laplace, Cauchy, exponential, Gamma (with effort): any continuous distribution that can be sampled by transforming a fixed-distribution noise variable through a differentiable function, i.e., any distribution with a differentiable inverse CDF, or any location-scale family.
- Mixture distributions are not natively reparameterizable (the component choice is discrete), but Gumbel-softmax provides a relaxation.
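
The inverse-CDF route can be sketched for the exponential distribution with rate \(\lambda\) (a toy NumPy example; the target \(\mathbb{E}[z] = 1/\lambda\) and its gradient \(-1/\lambda^2\) are known in closed form, so the estimate can be checked):

```python
import numpy as np

rng = np.random.default_rng(2)
lam = 2.0
n = 400_000

eps = rng.uniform(size=n)        # fixed-distribution noise, U(0, 1)
z = -np.log1p(-eps) / lam        # inverse CDF of Exp(lam), differentiable in lam

# Pathwise gradient of J(lam) = E[z] = 1/lam: dz/dlam = log(1 - eps)/lam**2 = -z/lam
grad = -z / lam
print(z.mean(), grad.mean())     # approx 1/lam = 0.5 and -1/lam**2 = -0.25
```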

Cannot directly reparameterize:

- Categorical / Bernoulli / Poisson: discrete. Use Gumbel-softmax relaxations or score-function estimators.
- Implicitly defined distributions that lack a differentiable sampling path.

The Pathwise vs. Score-Function Distinction

The reparameterization trick is one example of a pathwise gradient estimator: gradients flow through the path by which the random variable is computed. The score-function estimator is the alternative for cases where such a path does not exist.

Modern frameworks for stochastic computation graphs (PyTorch, TensorFlow, JAX) handle reparameterization automatically when distributions are sampled through their reparameterization-aware sampling functions (e.g., rsample in PyTorch's torch.distributions). As long as the encoder predicts \(\mu\) and \(\sigma\) and you sample by mu + sigma * eps, the framework’s autodiff handles the rest.

Why It Matters

Beyond the VAE: the reparameterization trick made amortized variational inference practical and is one of the key ingredients in modern probabilistic deep learning. It is also used in stochastic policies (continuous-action RL), in Bayesian neural networks (Bayes by backprop), and as a foundation for the diffusion model training objective, which can be viewed as a denoising-style ELBO computed with a particular reparameterization of the forward noising process.

References

Kingma, Diederik P., and Max Welling. 2013. “Auto-Encoding Variational Bayes.” arXiv Preprint arXiv:1312.6114. https://arxiv.org/abs/1312.6114.