✏️Gumbel Softmax

Motivation. In many models we need to select a discrete option inside the computation graph (e.g., pick one branch of a network). A hard argmax is non-differentiable, so gradients can’t flow through it. Gumbel-Softmax provides a continuous, differentiable approximation to this discrete sampling step.

Gumbel-Max Trick

Assume we have a discrete distribution:

| $X$ | 1 | 2 | 3 |
|---|---|---|---|
| $p$ | 0.2 | 0.3 | 0.5 |

We want to draw samples of $X$ that follow this distribution. If we sample directly, the resulting $X$ is not expressed as a function of $p$, so $X$ cannot be differentiated with respect to $p$ and we cannot back-propagate through the sampling step.

The Gumbel-Max trick deals with this situation: the sampling process can be re-parameterized as

$$ y=\text{argmax}_k(\text{log} p_k+g_k),\quad g_k\sim \text{Gumbel}(0,1) $$

$y$ is a one-hot vector encoding the selected category. The Gumbel noise $g_k$ can also be generated as $g_k=-\text{log}(-\text{log}(u_k)),\quad u_k\sim\text{Uniform}(0,1)$.

With this re-parameterization the randomness is moved into $g_k$, and the only non-differentiable operation left is the $\text{argmax}$.
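
As a quick sanity check, here is a minimal sketch (variable names and sample count are illustrative) that draws Gumbel-Max samples for the toy distribution above and compares the empirical frequencies with $p$:

import torch

torch.manual_seed(0)
p = torch.tensor([0.2, 0.3, 0.5])            # the toy distribution above

# Gumbel(0, 1) noise via g = -log(-log(u)), u ~ Uniform(0, 1)
u = torch.rand(100_000, 3)
g = -torch.log(-torch.log(u))

# Gumbel-Max: argmax over log p_k + g_k is a sample from Categorical(p)
samples = torch.argmax(torch.log(p) + g, dim=-1)
print(torch.bincount(samples, minlength=3) / samples.numel())  # ~ [0.2, 0.3, 0.5]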

Gumbel Softmax

Replace the hard $\text{argmax}$ with a softmax, and add a temperature parameter $\tau$ to control how close the output is to one-hot:

$$ y_k=\frac{\text{exp}\big((\text{log}\,p_k+g_k)/\tau\big)}{\sum_j\text{exp}\big((\text{log}\,p_j+g_j)/\tau\big)} $$
  • $\tau \to 0$: almost one-hot (but gradients vanish).
  • $\tau \to \infty$: near-uniform (too smooth).
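
To make the effect of $\tau$ concrete, here is a small sketch (values are illustrative) that applies the formula above to one fixed draw of Gumbel noise at several temperatures:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_p = torch.log(torch.tensor([0.2, 0.3, 0.5]))   # log p_k for the toy distribution
g = -torch.log(-torch.log(torch.rand(3)))          # one draw of Gumbel(0, 1) noise

for tau in (0.1, 1.0, 10.0):
    y = F.softmax((log_p + g) / tau, dim=-1)
    print(tau, y)  # small tau -> nearly one-hot, large tau -> nearly uniform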

Code Trick

Straight-through (ST) trick. Forward pass: use the hard one-hot vector (argmax of the soft sample). Backward pass: use the gradients of the soft sample $y_{\text{soft}}$, which is biased but practical.

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    # minimal helper (assumed, not shown in the original snippet): Gumbel noise + tempered softmax
    gumbels = -torch.empty_like(logits).exponential_().log()   # ~ Gumbel(0, 1)
    return F.softmax((logits + gumbels) / tau, dim=-1)

y_soft = gumbel_softmax_sample(logits, tau)

if hard:
    # forward: hard one-hot vector built from the argmax of the soft sample
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
    # backward: gradients flow only through y_soft (straight-through estimator)
    y = y_hard.detach() - y_soft.detach() + y_soft
else:
    y = y_soft
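
In practice, PyTorch already provides this straight-through behavior as torch.nn.functional.gumbel_softmax, so the snippet above collapses to a single call:

import torch.nn.functional as F

y = F.gumbel_softmax(logits, tau=0.5, hard=True)  # hard one-hot in forward, soft gradients in backward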

Difference from “Vanilla” Softmax

| Aspect | Softmax | Gumbel-Softmax |
|---|---|---|
| Purpose | Get a probability distribution from logits | Differentiably approximate sampling a discrete category |
| Randomness | Deterministic (given logits) | Stochastic (via Gumbel noise) |
| One-hot? | Not during training (probability vector) | Can be nearly / actually one-hot (with ST) |
| Gradients | Exact | Biased if ST, but low variance |
| Use cases | Classification outputs | Discrete latent variables, NAS, RL action sampling, VAEs with categorical codes |
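
A tiny illustration of the “Randomness” row (the logits here are made up): repeated softmax calls on the same logits return identical outputs, while repeated Gumbel-Softmax calls return a different (possibly one-hot) sample each time:

import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 3.0])               # illustrative logits
print(F.softmax(logits, dim=-1))                      # identical on every call
print(F.gumbel_softmax(logits, tau=1.0))              # changes on every call
print(F.gumbel_softmax(logits, tau=1.0, hard=True))   # stochastic and exactly one-hot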