✏️Gumbel Softmax
Motivation. In many models we need to select a discrete option inside the computation graph (e.g., pick one branch of a network). A hard argmax is non-differentiable, so gradients can’t flow through it. Gumbel-Softmax provides a continuous, differentiable approximation to this discrete sampling step.
Gumbel-Max Trick
Assume we have discrete distribution
| $X$ | 1 | 2 | 3 |
|---|---|---|---|
| $p$ | 0.2 | 0.3 | 0.5 |
We want to draw samples of $X$ that follow this distribution. If we sample directly, the resulting $X$ is not computed as a differentiable function of $p$, so $X$ cannot be differentiated w.r.t. $p$ and we cannot backpropagate through the sampling step.
The Gumbel-Max trick handles this situation: the sampling process can be re-parameterized as
$$ y=\operatorname{argmax}_k(\log p_k+g_k),\quad g_k\sim \text{Gumbel}(0,1) $$

Here $y$ is represented as a one-hot vector over the categories. The Gumbel noise can equivalently be generated as $g_k=-\log(-\log(u_k)),\quad u_k\sim\text{Uniform}(0,1)$.
Now the only remaining non-differentiable operation is the $\operatorname{argmax}$.
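A quick sanity-check sketch: drawing many Gumbel-Max samples for the table above should reproduce the target probabilities (the sample count here is an arbitrary choice):

```python
import torch

p = torch.tensor([0.2, 0.3, 0.5])
n = 100_000

u = torch.rand(n, 3)                       # Uniform(0, 1)
g = -torch.log(-torch.log(u))              # Gumbel(0, 1) noise
samples = torch.argmax(torch.log(p) + g, dim=-1)

print(torch.bincount(samples, minlength=3) / n)  # ~ tensor([0.2, 0.3, 0.5])
```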
Gumbel Softmax
Replace the hard $\operatorname{argmax}$ with a softmax, and add a temperature parameter $\tau$ to control how sharp the result is:
$$ y_k=\frac{\exp((\log p_k+g_k)/\tau)}{\sum_j\exp((\log p_j+g_j)/\tau)} $$
- $\tau \to 0$: almost one-hot (but gradients vanish).
- $\tau \to \infty$: near-uniform (too smooth).
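A small sketch of how $\tau$ trades off between these two extremes, using one shared draw of Gumbel noise (the probabilities and temperatures are illustrative choices):

```python
import torch
import torch.nn.functional as F

logits = torch.log(torch.tensor([0.2, 0.3, 0.5]))
g = -torch.log(-torch.log(torch.rand(3)))   # one draw of Gumbel(0, 1) noise

for tau in (0.1, 1.0, 10.0):
    y = F.softmax((logits + g) / tau, dim=-1)
    print(tau, y)   # small tau -> nearly one-hot, large tau -> nearly uniform
```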
Code trick
This is the straight-through (ST) estimator. Forward pass: take the hard one-hot vector (argmax). Backward pass: use the gradients of the soft sample $\tilde{y}$ (biased but practical).
```python
import torch

# logits, tau, and hard are assumed to be defined by the caller;
# gumbel_softmax_sample(logits, tau) adds Gumbel(0, 1) noise to the
# logits and applies a softmax with temperature tau.
y_soft = gumbel_softmax_sample(logits, tau)
if hard:
    # forward: use the hard one-hot vector
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
    # backward: use y_soft (gradients flow through the soft sample)
    y = y_hard.detach() - y_soft.detach() + y_soft
else:
    y = y_soft
```
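The same trick is available out of the box as `torch.nn.functional.gumbel_softmax(logits, tau, hard)`. A minimal usage sketch (the batch size, number of categories, temperature, and toy loss are arbitrary choices):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)     # 4 samples, 3 categories (assumed shape)
y = F.gumbel_softmax(logits, tau=0.5, hard=True)   # forward: one-hot rows

loss = (y * torch.arange(3.0)).sum()               # any downstream loss
loss.backward()

print(y)            # each row is exactly one-hot
print(logits.grad)  # gradients still reach the logits via the soft sample
```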
Difference from “Vanilla” Softmax
| Aspect | Softmax | Gumbel-Softmax |
|---|---|---|
| Purpose | Get a probability distribution from logits | Differentiably approximate sampling a discrete category |
| Randomness | Deterministic (given logits) | Stochastic (via Gumbel noise) |
| One-hot? | Not during training (probability vector) | Can be nearly/actually one-hot (with ST) |
| Gradients | Exact | Biased if ST, but low variance |
| Use cases | Classification outputs | Discrete latent vars, NAS, RL action sampling, VAE with categorical codes |
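To make the “Randomness” row concrete: with the same logits, softmax returns the same vector on every call, while Gumbel-Softmax returns a different (near-)one-hot sample each time (the logits below are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 3.0])

print(F.softmax(logits, dim=-1))                     # deterministic: same every call
print(F.gumbel_softmax(logits, tau=0.5, hard=True))  # stochastic: varies per call
print(F.gumbel_softmax(logits, tau=0.5, hard=True))
```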