Tuan Anh Le

Semi-supervised model learning and amortized inference

27 September 2019

This is a note on Kingma et al.’s paper and Siddharth et al.’s paper.

Say we have unsupervised data \(x_u \sim p(x)\) and some supervised data \(y_s, x_s \sim p(y, x)\), where \(p(x)\) is the marginal of \(p(x, y)\) (this is not super important here).

Then let’s say we want to learn parameters \(\theta\) of a generative model \(p_\theta(z, y, x)\) with a latent variable \(z\) and sometimes-latent variable \(y\). We also want to learn an inference network \(q_\phi(z, y \given x)\).

To learn \(\theta, \phi\), we should maximize the following objective: \begin{align} \mathcal L(\theta, \phi) := \E_{p(x_u)}\left[\mathrm{ELBO}(x_u, \theta, \phi)\right] + \gamma \E_{p(x_s, y_s)}\left[\mathrm{ELBO}(x_s, y_s, \theta, \phi) + \alpha \log q_\phi(y_s \given x_s)\right], \label{eq:obj} \end{align} where \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &:= \E_{q_\phi(z, y \given x_u)} \left[\log \frac{p_\theta(z, y, x_u)}{q_\phi(z, y \given x_u)}\right], \text{and} \label{eq:elbo1}\\
\mathrm{ELBO}(x_s, y_s, \theta, \phi) &:= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z \given x_s, y_s)}\right]. \label{eq:elbo2} \end{align}

To see why maximizing \eqref{eq:obj} is a good thing to do, rewrite the ELBOs in the logp - KL form and rewrite the logq term as a negative expected KL minus an entropy: \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &= \log p_\theta(x_u) - \KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}, \\
\mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \log p_\theta(x_s, y_s) - \KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}, \\
\E_{p(x_s, y_s)}\left[\log q_\phi(y_s \given x_s)\right] &= -\E_{p(x_s)p(y_s \given x_s)}\left[\log p(y_s \given x_s) - \log q_\phi(y_s \given x_s) - \log p(y_s \given x_s)\right] \nonumber\\
&= -\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right] - H(p(y_s \given x_s)), \end{align} where \(H(p(y_s \given x_s)) := -\E_{p(x_s, y_s)}\left[\log p(y_s \given x_s)\right]\) is the conditional entropy of \(y_s\) given \(x_s\) under the data distribution.

This allows us to rewrite \eqref{eq:obj} as \begin{align} \mathcal L(\theta, \phi) = \color{blue}{\E_{p(x_u)}\left[\log p_\theta(x_u)\right]} \color{red}{-\E_{p(x_u)}\left[\KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}\right]} + \color{blue}{\gamma\E_{p(x_s, y_s)}\left[\log p_\theta(x_s, y_s)\right]} \color{red}{-\gamma\E_{p(x_s, y_s)}\left[\KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}\right]} - \color{red}{\gamma\alpha\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right]} - \gamma\alpha H(p(y_s \given x_s)). \end{align} Maximizing the blue terms leads to model learning, and maximizing the red terms (that is, minimizing the KL divergences) leads to amortized inference. The entropy term depends on neither \(\theta\) nor \(\phi\).
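As a sanity check of this decomposition, including the factor \(\gamma\alpha\) in front of the entropy term, here is a minimal NumPy sketch that evaluates both sides exactly for a made-up, fully discrete setup. The tables `p_zyx`, `q_y_x` and `q_z_yx` and the two data distributions are random and purely hypothetical. The inference network uses the factorization \(q_\phi(y \given x) q_\phi(z \given y, x)\), so every term can be enumerated.

```python
import numpy as np

rng = np.random.default_rng(0)
NZ, NY, NX = 2, 3, 4                 # sizes of the (hypothetical) supports of z, y, x
gamma, alpha = 0.7, 2.5

def random_dist(*shape):
    """Random strictly positive table, normalized over its last axis."""
    t = rng.random(shape) + 0.1
    return t / t.sum(axis=-1, keepdims=True)

p_zyx = random_dist(NZ * NY * NX).reshape(NZ, NY, NX)   # p_theta(z, y, x)
q_y_x = random_dist(NX, NY)                             # q_phi(y | x), indexed [x, y]
q_z_yx = random_dist(NY, NX, NZ)                        # q_phi(z | y, x), indexed [y, x, z]
p_xu = random_dist(NX)                                  # data distribution p(x_u)
p_xs_ys = random_dist(NX * NY).reshape(NX, NY)          # data distribution p(x_s, y_s), indexed [x, y]

p_x = p_zyx.sum(axis=(0, 1))      # model marginal p_theta(x)
p_yx = p_zyx.sum(axis=0)          # model marginal p_theta(y, x), indexed [y, x]
d_x = p_xs_ys.sum(axis=1)         # data marginal p(x_s)
d_y_x = p_xs_ys / d_x[:, None]    # data conditional p(y_s | x_s), indexed [x, y]

def q_joint(z, y, x):
    return q_y_x[x, y] * q_z_yx[y, x, z]

def elbo_u(x):      # ELBO(x_u), computed exactly by enumeration
    return sum(q_joint(z, y, x) * np.log(p_zyx[z, y, x] / q_joint(z, y, x))
               for y in range(NY) for z in range(NZ))

def elbo_s(x, y):   # ELBO(x_s, y_s)
    return sum(q_z_yx[y, x, z] * np.log(p_zyx[z, y, x] / q_z_yx[y, x, z]) for z in range(NZ))

def kl_u(x):        # KL(q(z, y | x) || p_theta(z, y | x))
    return sum(q_joint(z, y, x) * np.log(q_joint(z, y, x) / (p_zyx[z, y, x] / p_x[x]))
               for y in range(NY) for z in range(NZ))

def kl_s(x, y):     # KL(q(z | x, y) || p_theta(z | x, y))
    return sum(q_z_yx[y, x, z] * np.log(q_z_yx[y, x, z] / (p_zyx[z, y, x] / p_yx[y, x]))
               for z in range(NZ))

pairs = [(x, y) for x in range(NX) for y in range(NY)]

# Left-hand side: the objective as originally defined.
objective = (sum(p_xu[x] * elbo_u(x) for x in range(NX))
             + gamma * sum(p_xs_ys[x, y] * (elbo_s(x, y) + alpha * np.log(q_y_x[x, y]))
                           for x, y in pairs))

# Right-hand side: log-evidence terms, KL terms and the entropy term.
kl_labels = sum(d_x[x] * np.sum(d_y_x[x] * np.log(d_y_x[x] / q_y_x[x])) for x in range(NX))
entropy = -np.sum(p_xs_ys * np.log(d_y_x))
decomposed = (np.sum(p_xu * np.log(p_x)) - sum(p_xu[x] * kl_u(x) for x in range(NX))
              + gamma * sum(p_xs_ys[x, y] * np.log(p_yx[y, x]) for x, y in pairs)
              - gamma * sum(p_xs_ys[x, y] * kl_s(x, y) for x, y in pairs)
              - gamma * alpha * kl_labels - gamma * alpha * entropy)

print(objective, decomposed)   # the two values agree up to floating-point error
```

With \(\alpha \neq 1\), using only \(\gamma\) in front of the entropy term would break the equality.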

To estimate gradients of \eqref{eq:obj}, we can sample from \(p(x_u)\) and \(p(x_s, y_s)\) (which are our datasets) and “move the gradients inside the expectations.” How do we estimate the ELBOs and the logq term? If the inference network factorizes nicely, this is easy (Kingma et al.). Otherwise, we need self-normalized importance sampling (Siddharth et al.).

Nice factorization of the inference network

Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(y \given x) q_\phi(z \given y, x). \end{align} Both \(q_\phi(y \given x)\) and \(q_\phi(z \given x_s, y_s) = q_\phi(z \given y_s, x_s)\) are then direct outputs of the inference network, so gradients of both ELBOs in \eqref{eq:elbo1} and \eqref{eq:elbo2} are straightforward to estimate, as long as both \(z\) and \(y\) are reparameterizable. The logq term is also easy to evaluate. Kingma et al. use a model where \(y\) is discrete and therefore not reparameterizable; however, its support has just \(10\) elements, so we can replace the expectation over \(y\) with a sum over ten terms.
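Here is a minimal NumPy sketch of these estimators. It assumes a hypothetical toy model, not the one used by Kingma et al.: \(y\) uniform over ten classes, \(z \sim \mathcal N(0, 1)\) and \(x \given z, y \sim \mathcal N(z + y, 1)\). The inference network is given directly by class probabilities `q_y_probs` and per-class Gaussian parameters `q_z_mean` and `q_z_std`, all illustrative names. \(z\) is reparameterized, and the expectation over \(y\) is an exact sum.

```python
import numpy as np

NUM_CLASSES = 10

def norm_logpdf(v, mean, std):
    """Log density of N(mean, std**2), elementwise."""
    return -0.5 * ((v - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def log_joint(z, y, x):
    """Hypothetical toy model: y ~ Uniform{0..9}, z ~ N(0, 1), x | z, y ~ N(z + y, 1)."""
    return -np.log(NUM_CLASSES) + norm_logpdf(z, 0.0, 1.0) + norm_logpdf(x, z + y, 1.0)

def elbo_supervised(x, y, q_z_mean, q_z_std, K, rng):
    """Monte Carlo estimate of ELBO(x_s, y_s); z ~ q(z | y, x) via reparameterization."""
    z = q_z_mean[y] + q_z_std[y] * rng.standard_normal(K)
    return float(np.mean(log_joint(z, y, x) - norm_logpdf(z, q_z_mean[y], q_z_std[y])))

def elbo_unsupervised(x, q_y_probs, q_z_mean, q_z_std, K, rng):
    """Estimate of ELBO(x_u): the expectation over discrete y is an exact sum over classes."""
    return sum(q_y_probs[y] * (elbo_supervised(x, y, q_z_mean, q_z_std, K, rng)
                               - np.log(q_y_probs[y]))
               for y in range(NUM_CLASSES))

def supervised_term(x, y, q_y_probs, q_z_mean, q_z_std, alpha, K, rng):
    """ELBO(x_s, y_s) + alpha * log q(y_s | x_s); log q(y_s | x_s) is read off directly."""
    return elbo_supervised(x, y, q_z_mean, q_z_std, K, rng) + alpha * np.log(q_y_probs[y])

# Usage: plug in the exact posterior of the toy model as the "inference network".
rng = np.random.default_rng(0)
x = 3.3
ys = np.arange(NUM_CLASSES)
log_p_xy = -np.log(NUM_CLASSES) + norm_logpdf(x, ys, np.sqrt(2.0))  # z integrated out: x | y ~ N(y, 2)
log_p_x = np.logaddexp.reduce(log_p_xy)
q_y = np.exp(log_p_xy - log_p_x)                                    # exact p(y | x)
q_mean, q_std = (x - ys) / 2.0, np.full(NUM_CLASSES, np.sqrt(0.5))  # exact p(z | y, x)
elbo_exact = elbo_unsupervised(x, q_y, q_mean, q_std, 100, rng)
elbo_uniform_y = elbo_unsupervised(x, np.full(NUM_CLASSES, 0.1), q_mean, q_std, 100, rng)
print(elbo_exact, log_p_x, elbo_uniform_y)
```

With the exact posterior, every sample's log weight equals \(\log p_\theta(x)\), so `elbo_exact` matches `log_p_x`. Replacing \(q_\phi(y \given x)\) by a uniform distribution lowers the ELBO by the corresponding KL divergence.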

Unfavourable factorization of the inference network

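Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(z \given x) q_\phi(y \given z, x). \label{eq:factorization2} \end{align} Now the marginal \(q_\phi(y \given x) = \int q_\phi(z \given x) q_\phi(y \given z, x) \,\mathrm dz\) is generally intractable, and so is the conditional \(q_\phi(z \given x, y) = q_\phi(z \given x) q_\phi(y \given z, x) / q_\phi(y \given x)\). There are three problems: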

  1. the denominator of \eqref{eq:elbo2}, \(q_\phi(z \given x_s, y_s)\), is difficult to evaluate,
  2. it is difficult to sample from \(q_\phi(z \given x_s, y_s)\), the distribution under which the expectation in \eqref{eq:elbo2} is taken, and
  3. the term \(\log q_\phi(y_s \given x_s)\) is difficult to evaluate.

To solve problem 1, use the identity \(\log q_\phi(z, y \given x) = \log q_\phi(y \given x) + \log q_\phi(z \given y, x)\). The terms on the RHS are only implicitly defined through the factorization \eqref{eq:factorization2}, but the LHS is tractable: \(q_\phi(z, y \given x) = q_\phi(z \given x) q_\phi(y \given z, x)\). Substituting \(\log q_\phi(z \given x_s, y_s) = \log q_\phi(z, y_s \given x_s) - \log q_\phi(y_s \given x_s)\) into the denominator of \eqref{eq:elbo2} gives \begin{align} \mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] + \log q_\phi(y_s \given x_s). \label{eq:elbo3} \end{align} The extra logq term can be lumped together with the logq term in \eqref{eq:obj}, so that we have \((\alpha + 1) \log q_\phi(y_s \given x_s)\) instead of \(\alpha \log q_\phi(y_s \given x_s)\).
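To see the identity and the three problems concretely, here is a small NumPy sketch that uses a discrete \(z\) as a hypothetical stand-in. With continuous \(z\), the sums become integrals. The tables are random and illustrative. `q_joint` is a product of the two network outputs, whereas `q_y` and `q_post` require a sum over \(z\), which is exactly what becomes intractable for continuous \(z\). Both forms of the ELBO agree.

```python
import numpy as np

rng = np.random.default_rng(1)
NZ, NY = 5, 3      # toy supports; with continuous z the sums below become integrals

# Hypothetical inference network for one fixed x_s, with the unfavourable factorization.
q_z = rng.random(NZ) + 0.1
q_z /= q_z.sum()                                         # q(z | x_s)
q_y_given_z = rng.random((NZ, NY)) + 0.1
q_y_given_z /= q_y_given_z.sum(axis=1, keepdims=True)    # q(y | z, x_s), indexed [z, y]
log_p = np.log(rng.random(NZ) + 0.1)                     # hypothetical log p_theta(z, y_s, x_s) per z

y_s = 2
q_joint = q_z * q_y_given_z[:, y_s]    # q(z, y_s | x_s): a product of two network outputs
q_y = q_joint.sum()                    # q(y_s | x_s): needs a sum (integral) over z
q_post = q_joint / q_y                 # q(z | x_s, y_s): needs q(y_s | x_s)

elbo2 = np.sum(q_post * (log_p - np.log(q_post)))                   # form of eq:elbo2
elbo3 = np.sum(q_post * (log_p - np.log(q_joint))) + np.log(q_y)    # form of eq:elbo3
print(elbo2, elbo3)    # equal up to floating-point error
```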

To solve problem 2, we use self-normalized importance sampling where the proposal is \(q_\phi(z \given x)\) and, with \(y = y_s\) fixed, the unnormalized target distribution over \(z\) is \(q_\phi(z, y \given x) = q_\phi(z \given y, x) q_\phi(y \given x) \propto q_\phi(z \given y, x)\). This allows us to estimate the expectation in \eqref{eq:elbo3} as \begin{align} \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] \approx \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)}, \end{align} where \(z_k \sim q_\phi(z \given x_s)\), \(w_k = q_\phi(z_k, y_s \given x_s) / q_\phi(z_k \given x_s) = q_\phi(y_s \given z_k, x_s)\), and \(\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell\).
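Here is a minimal NumPy sketch of this estimator. It assumes a hypothetical inference network for a binary \(y\) with \(q_\phi(z \given x_s) = \mathcal N(m, s^2)\) and a probit head \(q_\phi(y = 1 \given z, x_s) = \Phi(a z + b)\). It also assumes a hypothetical toy generative model. The names `snis`, `log_probit` and `expectation_in_elbo3` are illustrative. Weights are handled in log space for numerical stability.

```python
import math
import numpy as np

def norm_logpdf(v, mean, std):
    """Log density of N(mean, std**2), elementwise."""
    return -0.5 * ((v - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def log_probit(y, t):
    """Hypothetical classifier head: log q(y | z, x) with q(y = 1 | z, x) = Phi(t), y in {0, 1}."""
    sign = 1.0 if y == 1 else -1.0
    return np.log([0.5 * math.erfc(-sign * v / math.sqrt(2.0)) for v in np.atleast_1d(t)])

def snis(f_vals, log_w):
    """Self-normalized importance sampling estimate sum_k wbar_k f_k, from log-weights."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - log_w.max())          # a common rescaling cancels after normalization
    return float(np.sum(w / w.sum() * np.asarray(f_vals)))

def expectation_in_elbo3(x_s, y_s, m, s, a, b, K, rng):
    """SNIS estimate of E_{q(z | x_s, y_s)}[log p(z, y_s, x_s) - log q(z, y_s | x_s)].

    (m, s, a, b) stand for hypothetical encoder outputs for x_s.
    """
    z = m + s * rng.standard_normal(K)       # z_k ~ q(z | x_s) = N(m, s^2), reparameterized
    log_w = log_probit(y_s, a * z + b)       # log w_k = log q(y_s | z_k, x_s)
    log_q_joint = norm_logpdf(z, m, s) + log_w
    # Hypothetical generative model: z ~ N(0, 1), y ~ Bernoulli(1/2), x | z, y ~ N(z + y, 1).
    log_p_joint = norm_logpdf(z, 0.0, 1.0) + np.log(0.5) + norm_logpdf(x_s, z + y_s, 1.0)
    return snis(log_p_joint - log_q_joint, log_w)

# Sanity check: with q(z | x) = N(0, 1) and q(y = 1 | z, x) = Phi(2 z), the target
# q(z | x, y = 1) = 2 N(z; 0, 1) Phi(2 z) is skew-normal with mean sqrt(2 / pi) * 2 / sqrt(5).
rng = np.random.default_rng(0)
z = rng.standard_normal(200_000)
mean_est = snis(z, log_probit(1, 2.0 * z))
print(mean_est, math.sqrt(2.0 / math.pi) * 2.0 / math.sqrt(5.0))
print(expectation_in_elbo3(0.5, 1, 0.2, 0.9, 2.0, 0.0, 1_000, rng))
```

The sanity check uses a case where the normalized target is known in closed form (a skew-normal with mean about \(0.714\)), so the SNIS estimate should approach it as \(K\) grows.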

To solve problem 3, we evaluate an IWAE-like lower bound on \(\log q_\phi(y_s \given x_s)\). We again use \(q_\phi(z \given x)\) as the proposal, and we treat \(q_\phi(z, y_s \given x_s)\), as a function of \(z\), as an unnormalized target density whose normalizing constant is \(q_\phi(y_s \given x_s)\). This allows us to reuse the previously sampled \(z_k\) and weights \(w_k\) in evaluating the stochastic lower bound \begin{align} \widehat{\log q} := \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right), \end{align} whose expectation is a lower bound to \(\log q_\phi(y_s \given x_s)\). Each \(w_k\) has expectation \(q_\phi(y_s \given x_s)\) under the proposal, so the bound follows from Jensen’s inequality.
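Here is a sketch of this bound. It assumes a hypothetical Gaussian \(q_\phi(z \given x_s) = \mathcal N(m, s^2)\) and a probit head \(q_\phi(y = 1 \given z, x_s) = \Phi(a z + b)\). This pair is chosen because \(q_\phi(y = 1 \given x) = \Phi\big((a m + b)/\sqrt{1 + a^2 s^2}\big)\) is then available in closed form, so the bound can be compared with the exact value. The names `log_mean_exp` and `log_q_y_bound` are illustrative.

```python
import math
import numpy as np

def log_probit(y, t):
    """Hypothetical classifier head: log q(y | z, x) with q(y = 1 | z, x) = Phi(t), y in {0, 1}."""
    sign = 1.0 if y == 1 else -1.0
    return np.log([0.5 * math.erfc(-sign * v / math.sqrt(2.0)) for v in np.atleast_1d(t)])

def log_mean_exp(log_w):
    """log((1/K) sum_k w_k), computed stably from log-weights."""
    log_w = np.asarray(log_w, dtype=float)
    mx = log_w.max()
    return float(mx + np.log(np.mean(np.exp(log_w - mx))))

def log_q_y_bound(y_s, m, s, a, b, K, rng):
    """Stochastic lower bound on log q(y_s | x_s); (m, s, a, b) are hypothetical encoder outputs."""
    z = m + s * rng.standard_normal(K)               # z_k ~ q(z | x_s) = N(m, s^2)
    return log_mean_exp(log_probit(y_s, a * z + b))  # log w_k = log q(y_s | z_k, x_s)

# For this probit/Gaussian pair the exact marginal is known:
# q(y = 1 | x) = Phi((a m + b) / sqrt(1 + a^2 s^2)).
m, s, a, b = 0.3, 1.0, 2.0, -0.5
t = (a * m + b) / math.sqrt(1.0 + a * a * s * s)
exact = math.log(0.5 * math.erfc(-t / math.sqrt(2.0)))
rng = np.random.default_rng(0)
avg_k1 = np.mean([log_q_y_bound(1, m, s, a, b, 1, rng) for _ in range(20_000)])
avg_k100 = np.mean([log_q_y_bound(1, m, s, a, b, 100, rng) for _ in range(2_000)])
print(avg_k1, avg_k100, exact)
```

Averaged over repetitions, the \(K = 1\) bound sits clearly below the exact value, and the \(K = 100\) bound is much closer to it.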

This allows us to estimate the gradient of \eqref{eq:obj} as \begin{align} \hat g = \nabla_{\theta, \phi} \left(\mathrm{ELBO}(x_u, \theta, \phi) + \gamma \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)} + \gamma(\alpha + 1) \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right)\right). \end{align} All sampling is reparameterized.
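Putting the pieces together, here is a NumPy sketch that only evaluates the expression inside \(\nabla_{\theta, \phi}\) for one unsupervised and one supervised example. Obtaining \(\hat g\) itself would require writing the same computation in an autodiff framework, for example JAX with `jax.grad`. The toy model, the probit head and the tuple `(m, s, a, b)` of encoder outputs are all hypothetical. Because \(y\) is binary here, it is summed out exactly in \(\mathrm{ELBO}(x_u, \theta, \phi)\) rather than sampled.

```python
import math
import numpy as np

def norm_logpdf(v, mean, std):
    """Log density of N(mean, std**2), elementwise."""
    return -0.5 * ((v - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def log_probit(y, t):
    """Hypothetical classifier head: log q(y | z, x) with q(y = 1 | z, x) = Phi(t), y in {0, 1}."""
    sign = 1.0 if y == 1 else -1.0
    return np.log([0.5 * math.erfc(-sign * v / math.sqrt(2.0)) for v in np.atleast_1d(t)])

def log_joint(z, y, x):
    """Hypothetical toy model: z ~ N(0, 1), y ~ Bernoulli(1/2), x | z, y ~ N(z + y, 1)."""
    return norm_logpdf(z, 0.0, 1.0) + np.log(0.5) + norm_logpdf(x, z + y, 1.0)

def surrogate(x_u, x_s, y_s, enc_u, enc_s, gamma, alpha, K, rng):
    """Value of the expression inside the gradient in hat g.

    enc_u, enc_s = (m, s, a, b) are hypothetical encoder outputs for x_u and x_s:
    q(z | x) = N(m, s^2) and q(y = 1 | z, x) = Phi(a z + b).
    """
    # ELBO(x_u): reparameterized z ~ q(z | x_u); the binary y is summed out exactly.
    m, s, a, b = enc_u
    z = m + s * rng.standard_normal(K)
    log_q_z = norm_logpdf(z, m, s)
    elbo_u = 0.0
    for y in (0, 1):
        log_q_y = log_probit(y, a * z + b)
        elbo_u += np.mean(np.exp(log_q_y) * (log_joint(z, y, x_u) - log_q_z - log_q_y))
    # Supervised terms share z_k ~ q(z | x_s) and log-weights log w_k = log q(y_s | z_k, x_s).
    m, s, a, b = enc_s
    z = m + s * rng.standard_normal(K)
    log_w = log_probit(y_s, a * z + b)
    w = np.exp(log_w - log_w.max())
    w_bar = w / w.sum()                                # normalized weights bar w_k
    snis_term = np.sum(w_bar * (log_joint(z, y_s, x_s) - norm_logpdf(z, m, s) - log_w))
    log_q_bound = log_w.max() + np.log(np.mean(w))     # log((1/K) sum_k w_k)
    return float(elbo_u + gamma * snis_term + gamma * (alpha + 1.0) * log_q_bound)

enc_u, enc_s = (0.2, 0.9, 1.5, 0.0), (0.6, 0.8, 1.5, 0.0)
value = surrogate(0.4, 1.2, 1, enc_u, enc_s, gamma=1.0, alpha=1.0, K=500,
                  rng=np.random.default_rng(0))
print(value)
```

Since every \(w_k = q_\phi(y_s \given z_k, x_s) \le 1\), the bound term is negative here, so with the same random seed a larger \(\alpha\) gives a smaller value.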

All of this can be generalized to other unfavourable factorizations of the inference network.