Decomposing the ELBO
Rob Zinkov
2018-11-02
When performing Variational Inference, we are minimizing the KL divergence between some distribution we care about \(p(\v{z} \mid \v{x})\) and some distribution that is easier to work with \(q_\phi(\v{z} \mid \v{x})\).
\[ \begin{align} \phi^* &= \underset{\phi}{\mathrm{argmin}}\, \text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z} \mid \v{x})) \\ &= \underset{\phi}{\mathrm{argmin}}\, \mathbb{E}_{q_\phi(\v{z} \mid \v{x})} \big[\log q_\phi(\v{z} \mid \v{x}) - \log p(\v{z} \mid \v{x}) \big]\\ \end{align} \]
Now because the density of \(p(\v{z} \mid \v{x})\) usually isn't tractable, we use a property of the log model evidence \(\log\, p(\v{x})\), namely that it does not depend on \(\phi\), to define a different objective to optimize. Subtracting it leaves the minimizer unchanged:
\[ \begin{align} \phi^* &= \underset{\phi}{\mathrm{argmin}}\, \Expect_{q_\phi(\v{z} \mid \v{x})} \big[\log q_\phi(\v{z} \mid \v{x}) - \log p(\v{z} \mid \v{x})\big] - \log p(\v{x}) \\ &= \underset{\phi}{\mathrm{argmin}}\, \Expect _{q_\phi(\v{z} \mid \v{x})} \big[\log q_\phi(\v{z} \mid \v{x}) - \log p(\v{z} \mid \v{x}) - \log p(\v{x})\big] \\ &= \underset{\phi}{\mathrm{argmin}}\, \Expect _{q_\phi(\v{z} \mid \v{x})} \big[\log q_\phi(\v{z} \mid \v{x}) - \log p(\v{x}, \v{z})\big]\\ &= \underset{\phi}{\mathrm{argmin}}\, -\mathcal{L}(\phi) \end{align} \]
Since \(\mathcal{L}(\phi) = \log p(\v{x}) - \text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z} \mid \v{x}))\) and \(\log p(\v{x})\) does not depend on \(\phi\), maximizing \(\mathcal{L}(\phi)\) minimizes our original KL.
The term \(\mathcal{L}(\phi)\) is called the evidence lower bound, or ELBO: because the KL divergence is always greater than or equal to zero, \(\mathcal{L}(\phi) \leq \log p(\v{x})\), so it is a lower bound on the log evidence.
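A quick numerical check of the identity \(\mathcal{L}(\phi) = \log p(\v{x}) - \text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z} \mid \v{x}))\) and of the bound can make them concrete. The sketch below is a minimal example, assuming a made-up toy model (not from this post) with a discrete latent \(z \in \{0, 1, 2\}\), a binary observation \(x\), and a hand-picked \(q\), so every expectation is an exact finite sum.

```python
import math

# Hypothetical toy model: discrete latent z in {0, 1, 2}, binary observation x.
prior = [0.5, 0.3, 0.2]                               # p(z)
likelihood = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]     # likelihood[z][x] = p(x | z)
x = 1                                                 # the observed value
q = [0.2, 0.3, 0.5]                                   # q_phi(z | x), picked by hand

joint = [prior[z] * likelihood[z][x] for z in range(3)]   # p(x, z)
evidence = sum(joint)                                     # p(x)
posterior = [j / evidence for j in joint]                 # p(z | x)

# ELBO written as E_q[log p(x, z) - log q(z | x)]
elbo = sum(q[z] * (math.log(joint[z]) - math.log(q[z])) for z in range(3))
# KL(q(z | x) || p(z | x))
kl = sum(q[z] * (math.log(q[z]) - math.log(posterior[z])) for z in range(3))

print(f"ELBO          = {elbo:.6f}")
print(f"log p(x) - KL = {math.log(evidence) - kl:.6f}")
print(f"log p(x)      = {math.log(evidence):.6f}")
```

Setting `q = posterior` drives the KL to zero, at which point the bound is tight.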
Thanks to the linearity of expectation, the ELBO can be rearranged into many different forms, and each form gives some intuition for what might be going wrong when you learn \(q_\phi(\v{z} \mid \v{x})\).
Now why does this matter? Couldn't I just optimize this loss with SGD and be done? Well you can, but when something goes wrong it often shows up as one or more terms taking unusual values. Making these tradeoffs in the loss function explicit also means you can reweight the terms, either by hand or automatically, to favor different properties of your learned representation.
Entropy form
The classic form is written in terms of an energy term and an entropy term. The first term encourages \(q\) to put high probability mass wherever \(p\) does. The second term encourages \(q\) to maximize its entropy as much as possible and spread probability mass everywhere it can.
\[ \mathcal{L}(\phi) = \Expect_{q_\phi(\v{z} \mid \v{x})}[\log p(\v{x}, \v{z})] + H(q_\phi(\v{z} \mid \v{x})) \]
where
\[ H(q_\phi(\v{z} \mid \v{x})) \triangleq - \Expect_{q_\phi(\v{z} \mid \v{x})}[\log q_\phi(\v{z} \mid \v{x})] \]
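The entropy form is handy when \(H(q_\phi)\) has a closed form, since then only the energy term needs Monte Carlo sampling. A minimal sketch, assuming a toy conjugate model chosen for illustration (\(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\)) and a Gaussian \(q_\phi(z \mid x) = \mathcal{N}(m, s^2)\) with hand-picked \(m\) and \(s\); the helper names are hypothetical. Because the model is conjugate, the exact ELBO and \(\log p(x) = \log \mathcal{N}(x; 0, 2)\) are also available to check the estimate against.

```python
import math
import random

def log_normal(v, mean, var):
    """Log density of N(mean, var) evaluated at v."""
    return -0.5 * math.log(2 * math.pi * var) - (v - mean) ** 2 / (2 * var)

def log_joint(x, z):
    """log p(x, z) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)

def gaussian_entropy(s):
    """Closed-form entropy of N(m, s^2): 0.5 * log(2 * pi * e * s^2)."""
    return 0.5 * math.log(2 * math.pi * math.e * s ** 2)

def elbo_entropy_form(x, m, s, num_samples=100_000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, z)] plus the exact entropy H(q)."""
    rng = random.Random(seed)
    energy = sum(log_joint(x, rng.gauss(m, s)) for _ in range(num_samples)) / num_samples
    return energy + gaussian_entropy(s)

x, m, s = 1.0, 0.2, 1.0
estimate = elbo_entropy_form(x, m, s)

# For this model the energy term is also available in closed form.
exact_energy = -math.log(2 * math.pi) - (m ** 2 + s ** 2) / 2 - ((x - m) ** 2 + s ** 2) / 2
exact = exact_energy + gaussian_entropy(s)
log_evidence = log_normal(x, 0.0, 2.0)   # p(x) = N(x; 0, 2) for this model

print(f"estimate = {estimate:.4f}, exact = {exact:.4f}, log p(x) = {log_evidence:.4f}")
```

The gap between `exact` and `log_evidence` is the KL from \(q\) to the true posterior, which for this model is \(\mathcal{N}(x/2, 1/2)\).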
Reconstruction error minus KL on the prior
More often these days, we describe \(\mathcal{L}\) in terms of a reconstruction term and a KL divergence to the prior \(p(\v{z})\). Here the first term says we should put mass on latent codes \(\v{z}\) from which \(p\) is likely to generate our observation \(\v{x}\). The second term trades this off against keeping \(q\) close to the prior.
\[ \mathcal{L}(\phi) = \Expect_{q_\phi(\v{z} \mid \v{x})}[\log p(\v{x} \mid \v{z})] - \text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z}))\]
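Since \(\log p(\v{x}, \v{z}) = \log p(\v{x} \mid \v{z}) + \log p(\v{z})\), this form follows from the entropy form by grouping \(\log p(\v{z})\) with the entropy term. A minimal sketch checking that the two forms agree, assuming a hand-picked discrete toy model as a stand-in (nothing here comes from the post):

```python
import math

# Hypothetical toy model: discrete latent z in {0, 1, 2}, binary observation x.
prior = [0.5, 0.3, 0.2]                               # p(z)
likelihood = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]     # likelihood[z][x] = p(x | z)
x = 0
q = [0.7, 0.2, 0.1]                                   # q_phi(z | x), picked by hand
zs = range(len(prior))

# Reconstruction term: E_q[log p(x | z)]
reconstruction = sum(q[z] * math.log(likelihood[z][x]) for z in zs)
# KL(q(z | x) || p(z)): how far the approximate posterior strays from the prior
kl_prior = sum(q[z] * math.log(q[z] / prior[z]) for z in zs)

# Entropy form for comparison: E_q[log p(x, z)] + H(q)
energy = sum(q[z] * math.log(prior[z] * likelihood[z][x]) for z in zs)
entropy = -sum(q[z] * math.log(q[z]) for z in zs)

print(f"reconstruction - KL = {reconstruction - kl_prior:.6f}")
print(f"energy + entropy    = {energy + entropy:.6f}")
```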
ELBO surgery
But there are other ways to think about this decomposition. Because we frequently use amortized inference, learning a single \(\phi\) that produces a \(q\) distribution for whichever observation \(\v{x}\) we choose, we can talk about the average distribution we learn over our observed data, with \(p_d\) being the empirical distribution of our observations.
\[ \overline{q}_\phi(\v{z}) = \Expect_{p_d(\v{x})} \big[ q_\phi(\v{z} \mid \v{x}) \big] \]
This is sometimes called the aggregate posterior.
With that, we can decompose our KL to the prior into a mutual information term, which encourages each \(q_\phi(\v{z} \mid \v{x})\) we create to be near the average one \(\overline{q}_\phi(\v{z})\), and a KL between this average distribution and the prior. This second term encourages the representation generated for \(\v{z}\) to be useful: when \(\overline{q}_\phi(\v{z})\) matches the prior, codes sampled from \(p(\v{z})\) look like the codes the encoder actually produces.
\[ \mathcal{L}(\phi) = \Expect_{q_\phi(\v{z} \mid \v{x})}[\log p(\v{x} \mid \v{z})] - \mathbb{I}_q(\v{x},\v{z}) - \text{KL}(\overline{q}_\phi(\v{z}) \;\|\; p(\v{z})) \]
where
\[ \mathbb{I}_q(\v{x},\v{z}) \triangleq \Expect_{p_d}\big[\Expect_{q_\phi(\v{z} \mid \v{x})} \big[\log q_\phi(\v{z} \mid \v{x})\big] \big] - \Expect_{\overline{q}_\phi(\v{z})} \big[\log \overline{q}_\phi(\v{z})\big] \]
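The decomposition rests on an identity that is easy to check: adding and subtracting \(\log \overline{q}_\phi(\v{z})\) inside the averaged KL gives

\[ \Expect_{p_d(\v{x})}\big[\text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z}))\big] = \mathbb{I}_q(\v{x},\v{z}) + \text{KL}(\overline{q}_\phi(\v{z}) \;\|\; p(\v{z})) \]

Note that this identity, and hence the decomposition, holds for the ELBO averaged over \(p_d\). A minimal numerical sketch, assuming a hypothetical toy setup with a binary \(x\), a three-state \(z\), an empirical \(p_d\) built from a made-up dataset, and hand-picked amortized \(q_\phi(z \mid x)\) tables:

```python
import math
from collections import Counter

# Hypothetical toy setup: binary observation x, latent z in {0, 1, 2}.
prior = [0.5, 0.3, 0.2]                               # p(z)
likelihood = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]     # likelihood[z][x] = p(x | z)
encoder = {0: [0.7, 0.2, 0.1], 1: [0.2, 0.3, 0.5]}    # amortized q_phi(z | x) per x
data = [0, 1, 1, 0, 1]                                # made-up observations
zs = range(len(prior))

# Empirical data distribution p_d(x)
p_d = {x: n / len(data) for x, n in Counter(data).items()}

# Aggregate posterior: q_bar(z) = E_{p_d(x)}[q_phi(z | x)]
q_bar = [sum(p_d[x] * encoder[x][z] for x in p_d) for z in zs]

def kl(p, r):
    """KL(p || r) between two discrete distributions given as lists."""
    return sum(pi * math.log(pi / ri) for pi, ri in zip(p, r) if pi > 0)

# Left-hand side: KL to the prior, averaged over the data
avg_kl_prior = sum(p_d[x] * kl(encoder[x], prior) for x in p_d)

# I_q(x, z) = E_{p_d} E_q[log q(z | x)] - E_{q_bar}[log q_bar(z)]
mutual_info = (sum(p_d[x] * sum(encoder[x][z] * math.log(encoder[x][z]) for z in zs) for x in p_d)
               - sum(q_bar[z] * math.log(q_bar[z]) for z in zs))
marginal_kl = kl(q_bar, prior)

# Averaged reconstruction term E_{p_d} E_q[log p(x | z)]
reconstruction = sum(p_d[x] * sum(encoder[x][z] * math.log(likelihood[z][x]) for z in zs)
                     for x in p_d)
avg_elbo = reconstruction - avg_kl_prior

print(f"avg KL to prior      = {avg_kl_prior:.6f}")
print(f"I_q + KL(q_bar || p) = {mutual_info + marginal_kl:.6f}")
print(f"avg ELBO via surgery = {reconstruction - mutual_info - marginal_kl:.6f}")
```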
Difference of two KL divergences
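With something like \(p_d\) around, it is also possible to pull out the relationship between \(p\) and \(p_d\). This is particularly relevant if you intend to learn \(p\). Strictly, this form treats the KL to the posterior as averaged over \(p_d\), and it equals the \(p_d\)-averaged ELBO plus the entropy of \(p_d\), a constant that depends on neither \(\phi\) nor \(p\).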
\[ \mathcal{L}(\phi) = - \text{KL}(q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{z} \mid \v{x})) - \text{KL}(p_d(\v{x}) \;\|\; p(\v{x})) \]
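A minimal sketch verifying this form on a hypothetical discrete toy model with a made-up dataset. It computes the two KL terms, averaging the first over \(p_d\), and compares their negated sum with the \(p_d\)-averaged ELBO plus the entropy \(H(p_d)\), which is a constant with respect to both \(\phi\) and \(p\).

```python
import math
from collections import Counter

# Hypothetical toy model: binary observation x, latent z in {0, 1, 2}.
prior = [0.5, 0.3, 0.2]                               # p(z)
likelihood = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]     # likelihood[z][x] = p(x | z)
encoder = {0: [0.7, 0.2, 0.1], 1: [0.2, 0.3, 0.5]}    # q_phi(z | x) per x
data = [0, 1, 1, 0, 1]                                # made-up observations
zs = range(len(prior))
p_d = {x: n / len(data) for x, n in Counter(data).items()}

def joint(x, z):
    return prior[z] * likelihood[z][x]                # p(x, z)

def marginal(x):
    return sum(joint(x, z) for z in zs)               # p(x)

def elbo(x):
    q = encoder[x]
    return sum(q[z] * math.log(joint(x, z) / q[z]) for z in zs)

def kl_to_posterior(x):
    # KL(q(z | x) || p(z | x)), using p(z | x) = p(x, z) / p(x)
    q = encoder[x]
    return sum(q[z] * math.log(q[z] * marginal(x) / joint(x, z)) for z in zs)

avg_elbo = sum(p_d[x] * elbo(x) for x in p_d)
avg_kl_posterior = sum(p_d[x] * kl_to_posterior(x) for x in p_d)
kl_data = sum(p_d[x] * math.log(p_d[x] / marginal(x)) for x in p_d)   # KL(p_d || p(x))
entropy_data = -sum(p_d[x] * math.log(p_d[x]) for x in p_d)           # H(p_d)

print(f"-KL(q || p(z|x)) - KL(p_d || p(x)) = {-avg_kl_posterior - kl_data:.6f}")
print(f"avg ELBO + H(p_d)                  = {avg_elbo + entropy_data:.6f}")
```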
Full decomposition
Of course, with more aggressive rearranging, we can isolate a single term that encourages learning better latent representations. In a setting where you aren't learning \(p\), some of these terms are constant and can generally be ignored. I provide them here for completeness.
\[ \mathcal{L}(\phi) = \Expect_{q_\phi(\v{z} \mid \v{x})}\left[ \log\frac{p(\v{x} \mid \v{z})}{p(\v{x})} - \log\frac{q_\phi(\v{z} \mid \v{x})}{\overline{q}_\phi(\v{z})} \right] - \text{KL}(p_d(\v{x}) \;\|\; p(\v{x})) - \text{KL}(\overline{q}_\phi(\v{z}) \;\|\; p(\v{z}))\]
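A minimal sketch that evaluates the four terms separately on a hypothetical discrete toy model, reading \(q_\phi(\v{z})\) as the aggregate posterior \(\overline{q}_\phi(\v{z})\) and taking the outer expectation over \(p_d(\v{x})\,q_\phi(\v{z} \mid \v{x})\). Under those assumptions the terms should sum to \(-\text{KL}(p_d(\v{x})\,q_\phi(\v{z} \mid \v{x}) \;\|\; p(\v{x}, \v{z}))\), which the code also computes directly.

```python
import math
from collections import Counter

# Hypothetical toy model: binary observation x, latent z in {0, 1, 2}.
prior = [0.5, 0.3, 0.2]                               # p(z)
likelihood = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]     # likelihood[z][x] = p(x | z)
encoder = {0: [0.7, 0.2, 0.1], 1: [0.2, 0.3, 0.5]}    # q_phi(z | x) per x
data = [0, 1, 1, 0, 1]                                # made-up observations
zs = range(len(prior))
p_d = {x: n / len(data) for x, n in Counter(data).items()}

p_x = {x: sum(prior[z] * likelihood[z][x] for z in zs) for x in p_d}   # p(x)
q_bar = [sum(p_d[x] * encoder[x][z] for x in p_d) for z in zs]        # aggregate posterior

def expect(f):
    """E_{p_d(x) q_phi(z | x)}[f(x, z)] as an exact finite sum."""
    return sum(p_d[x] * encoder[x][z] * f(x, z) for x in p_d for z in zs)

# First term: log p(x|z)/p(x) - log q(z|x)/q_bar(z), averaged over p_d(x) q(z|x)
latent_term = expect(lambda x, z: math.log(likelihood[z][x] / p_x[x])
                     - math.log(encoder[x][z] / q_bar[z]))
kl_data = sum(p_d[x] * math.log(p_d[x] / p_x[x]) for x in p_d)          # KL(p_d || p(x))
kl_marginal = sum(q_bar[z] * math.log(q_bar[z] / prior[z]) for z in zs)  # KL(q_bar || p(z))
total = latent_term - kl_data - kl_marginal

# Direct computation: -KL(p_d(x) q(z|x) || p(x, z))
direct = -expect(lambda x, z: math.log(p_d[x] * encoder[x][z] / (prior[z] * likelihood[z][x])))

print(f"sum of terms = {total:.6f}, direct = {direct:.6f}")
```

Printing `latent_term`, `kl_data` and `kl_marginal` individually is the point of this form: each one can be monitored or reweighted on its own.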
I highly encourage checking out the Appendix of the Structured Disentangled Representations paper to see how much further this can be pushed.
Final notes
Of course, all of the above still holds in the VAE setting, where \(p\) becomes \(p_\theta\), but I felt the notation was cluttered enough already. It's kind of amazing how much insight can be gained by expanding and collapsing one loss function.