Tuan Anh Le

Variational Inference for Monte Carlo Objectives (VIMCO)

25 February 2018

These are notes on the VIMCO paper.

The goal is to reduce the variance of gradient estimators for importance weighted autoencoders (IWAEs), particularly when the latent variables are discrete and the REINFORCE gradient estimator usually has to be used.

Let \(p_\theta(z, x)\) be a generative network of latent variables \(z\) and observations \(x\). Let \(q_\phi(z \given x)\) be the inference network. Given a dataset \((x^{(n)})_{n = 1}^N\), we want to maximize \(\sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)})\) where: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right) \log \left( \frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}. \label{eq:elbo} \end{align}
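For a single draw of \(z^{1:K}\), the integrand in \eqref{eq:elbo} is the log of the mean of the importance weights, and it is safest to compute it in log space from the log weights \(\log p_\theta(z^k, x) - \log q_\phi(z^k \given x)\). The following is a minimal Python sketch, assuming the log densities have already been evaluated; the helper names are hypothetical.

```python
import math
from typing import Sequence


def log_mean_exp(log_values: Sequence[float]) -> float:
    """Numerically stable log((1 / K) * sum_k exp(log_values[k]))."""
    m = max(log_values)
    if m == -math.inf:
        return -math.inf  # every weight is exactly zero
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


def elbo_is_estimate(log_p: Sequence[float], log_q: Sequence[float]) -> float:
    """Single-draw Monte Carlo estimate of ELBO_IS for one data point x.

    log_p[k] = log p_theta(z^k, x) and log_q[k] = log q_phi(z^k | x) for
    K particles z^1, ..., z^K drawn independently from q_phi(. | x).
    """
    log_weights = [lp - lq for lp, lq in zip(log_p, log_q)]
    return log_mean_exp(log_weights)


# K = 2 particles with weights 0.1 / 0.1 = 1 and 0.3 / 0.1 = 3.
print(elbo_is_estimate([math.log(0.1), math.log(0.3)], [math.log(0.1), math.log(0.1)]))
# prints approximately 0.6931, i.e. log((1 + 3) / 2) = log 2
```

Subtracting the maximum before exponentiating keeps the sum finite even when individual log weights are very large or very negative.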

If \(z^{1:K}\) is not reparameterizable, we must use the REINFORCE gradient estimator to estimate gradients of \eqref{eq:elbo} with respect to \(\phi\): \begin{align} \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right) + \nabla_\phi \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right), \label{eq:reinforce} \end{align} where \(z^{1:K} \sim \prod_{k = 1}^K q_\phi(\mathrm dz^k \given x)\) and \(f_{\theta, \phi}(z, x) := \frac{p_\theta(z, x)}{q_\phi(z \given x)}\). Although this estimator is unbiased, it has high variance because of the first term: a single learning signal, \(\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right)\), multiplies the score of every sample, so the estimator cannot tell which of the \(K\) samples are responsible for a good or bad value of the objective.
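As a sketch of \eqref{eq:reinforce} for a scalar \(\phi\), the Python function below takes the log weights \(\log f_{\theta, \phi}(z^k, x)\) and the scores \(s_k = \nabla_\phi \log q_\phi(z^k \given x)\) of the \(K\) samples. It assumes for simplicity that \(p_\theta\) does not depend on \(\phi\), so that \(\nabla_\phi \log f_{\theta, \phi}(z^k, x) = -s_k\), and it writes the second term as \(\sum_k \bar w_k \nabla_\phi \log f_{\theta, \phi}(z^k, x)\) with self-normalized weights \(\bar w_k = f_{\theta, \phi}(z^k, x) / \sum_j f_{\theta, \phi}(z^j, x)\). The function name is hypothetical; in practice an automatic differentiation library would handle the second term.

```python
import math
from typing import Sequence


def reinforce_grad(log_w: Sequence[float], scores: Sequence[float]) -> float:
    """REINFORCE estimate of d ELBO_IS / d phi for a scalar phi.

    log_w[k]  = log f(z^k, x) = log p(z^k, x) - log q_phi(z^k | x)
    scores[k] = d/dphi log q_phi(z^k | x)
    Assumes p does not depend on phi, so d/dphi log f(z^k, x) = -scores[k].
    """
    m = max(log_w)
    unnormalized = [math.exp(v - m) for v in log_w]
    total = sum(unnormalized)
    log_mean_f = m + math.log(total) - math.log(len(log_w))
    # First term: one shared learning signal times the scores of all K samples.
    first = log_mean_f * sum(scores)
    # Second term: d/dphi log mean_k f_k = sum_k (f_k / sum_j f_j) * d/dphi log f_k.
    second = sum((u / total) * -s for u, s in zip(unnormalized, scores))
    return first + second
```

For example, with a Bernoulli inference network \(q_\phi(z = 1 \given x) = \sigma(\phi)\) (an illustrative assumption), the score of a sample is \(z - \sigma(\phi)\).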

First, let’s rewrite the first term of the estimator in \eqref{eq:reinforce} as: \begin{align} \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right) &= \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \left( \sum_{\ell = 1}^K \nabla_\phi \log q_\phi(z^\ell \given x) \right) \\
&= \sum_{\ell = 1}^K \left( \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right). \label{eq:reinforce-2} \end{align}

Given a function \(\hat f(z^{-\ell}, x)\) that does not depend on \(z^\ell\), where \(z^{-\ell} := (z^1, \dotsc, z^{\ell - 1}, z^{\ell + 1}, \dotsc, z^K)\), we subtract a baseline from the learning signal of each term in \eqref{eq:reinforce-2}, which gives the estimator: \begin{align} \sum_{\ell = 1}^K \left( \left(\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) - \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right). \label{eq:reinforce-3} \end{align}
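The bracketed factor in \eqref{eq:reinforce-3} is a per-sample learning signal: the log mean of the weights minus the log mean obtained after swapping \(f_{\theta, \phi}(z^\ell, x)\) for \(\hat f(z^{-\ell}, x)\). Below is a Python sketch that works with log weights and takes \(\log \hat f\) as a user-supplied function of the other \(K - 1\) log weights; this interface and the helper names are assumptions made for illustration.

```python
import math
from typing import Callable, List, Sequence


def log_mean_exp(log_values: Sequence[float]) -> float:
    """Numerically stable log((1 / K) * sum_k exp(log_values[k]))."""
    m = max(log_values)
    if m == -math.inf:
        return -math.inf  # every weight is exactly zero
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


def learning_signals(
    log_w: Sequence[float],
    log_f_hat: Callable[[List[float]], float],
) -> List[float]:
    """Per-sample learning signals (the bracketed factor in eq:reinforce-3).

    log_w[k] is log f(z^k, x); log_f_hat(others) returns log f_hat(z^{-ell}, x)
    and only ever sees the other K - 1 log weights.
    """
    full = log_mean_exp(log_w)
    signals = []
    for ell in range(len(log_w)):
        others = [v for k, v in enumerate(log_w) if k != ell]
        replaced = list(log_w)
        replaced[ell] = log_f_hat(others)  # swap f(z^ell, x) for f_hat(z^{-ell}, x)
        signals.append(full - log_mean_exp(replaced))
    return signals


# Example: the baseline is the geometric mean of the other weights (mean of their logs).
print(learning_signals([0.0, math.log(3.0), -1.0], lambda others: sum(others) / len(others)))
```

Each signal then multiplies only its own score \(\nabla_\phi \log q_\phi(z^\ell \given x)\), so samples whose weights are larger than the baseline predicts get a positive signal and the others a negative one.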

The authors experimented with the arithmetic mean of the other weights, \(\hat f(z^{-\ell}, x) = \frac{1}{K - 1} \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x)\), but found the geometric mean \begin{align} \hat f(z^{-\ell}, x) := \exp\left( \frac{1}{K - 1} \sum_{k \neq \ell} \log f_{\theta, \phi}(z^k, x) \right) \end{align} to work better.
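In log space, the arithmetic-mean baseline is a log-mean-exp of the other \(K - 1\) log weights, while the geometric-mean baseline is simply their average, \(\log \hat f(z^{-\ell}, x) = \frac{1}{K - 1} \sum_{k \neq \ell} \log f_{\theta, \phi}(z^k, x)\). The Python sketch below (hypothetical function names, assuming \(K \geq 2\) and finite log weights) computes both for every \(\ell\). The geometric version costs \(O(K)\) in total because each leave-one-out sum equals the full sum of log weights minus one term, and by the AM-GM inequality it never exceeds the arithmetic version.

```python
import math
from typing import List, Sequence


def log_mean_exp(log_values: Sequence[float]) -> float:
    """Numerically stable log((1 / K) * sum_k exp(log_values[k]))."""
    m = max(log_values)
    if m == -math.inf:
        return -math.inf  # every weight is exactly zero
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


def log_baselines_arithmetic(log_w: Sequence[float]) -> List[float]:
    """log f_hat for the arithmetic mean of the other K - 1 weights (O(K^2))."""
    return [
        log_mean_exp([v for k, v in enumerate(log_w) if k != ell])
        for ell in range(len(log_w))
    ]


def log_baselines_geometric(log_w: Sequence[float]) -> List[float]:
    """log f_hat for the geometric mean of the other K - 1 weights (O(K)).

    Assumes K >= 2 and finite log weights (an infinite sum would give nan).
    """
    total = sum(log_w)
    k_minus_1 = len(log_w) - 1
    # Leave-one-out mean of log weights: (sum_k log f_k - log f_ell) / (K - 1).
    return [(total - v) / k_minus_1 for v in log_w]


log_w = [0.0, math.log(4.0), math.log(2.0)]
print(log_baselines_arithmetic(log_w))  # entry ell = 2 is log((1 + 4) / 2) = log 2.5
print(log_baselines_geometric(log_w))   # entry ell = 2 is log(sqrt(1 * 4)) = log 2
```

The arithmetic version is computed directly on each leave-one-out set rather than as \(\log(\sum_k f_k - f_\ell)\), since that subtraction can lose precision when one weight dominates.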

Since the term \begin{align} g(z^{-\ell}, x) := \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right) \end{align} is independent of \(z^\ell\), we can verify that: \begin{align} \E\left[\sum_{\ell = 1}^K g(z^{-\ell}, x) \nabla_\phi \log q_\phi(z^\ell \given x) \right] &= \sum_{\ell = 1}^K \E\left[ g(z^{-\ell}, x) \nabla_\phi \log q_\phi(z^\ell \given x) \right] \\
&= \sum_{\ell = 1}^K \E\left[ g(z^{-\ell}, x) \right] \E\left[ \nabla_\phi \log q_\phi(z^\ell \given x) \right] && \text{(since } z^{-\ell} \text{ and } z^\ell \text{ are independent)} \\
&= 0, \end{align} where we use the fact that \(\E\left[ \nabla_\phi \log q_\phi(z^\ell \given x) \right] = \int \nabla_\phi q_\phi(z^\ell \given x) \,\mathrm dz^\ell = \nabla_\phi \int q_\phi(z^\ell \given x) \,\mathrm dz^\ell = \nabla_\phi 1 = 0\). Hence \eqref{eq:reinforce-2} and \eqref{eq:reinforce-3} have the same expectation.
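The zero-mean argument can be checked numerically by exact enumeration on a small example. The Python sketch below assumes a hypothetical toy model that is not from the paper: binary latents, an inference network \(q_\phi(z = 1 \given x) = \sigma(\phi)\) with score \(z - \sigma(\phi)\), and a fixed joint \(p(z, x)\) given by a lookup table. It sums \(g(z^{-\ell}, x) \nabla_\phi \log q_\phi(z^\ell \given x)\) over all \(2^K\) configurations, weighted by their probabilities, using the geometric-mean \(\hat f\).

```python
import itertools
import math
from typing import Sequence, Tuple

# Hypothetical toy model, chosen only for illustration.
P_JOINT = {0: 0.3, 1: 0.1}  # p(z, x) for the observed x


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def q_prob(z: int, phi: float) -> float:
    """q_phi(z | x) with q_phi(z = 1 | x) = sigmoid(phi)."""
    return sigmoid(phi) if z == 1 else 1.0 - sigmoid(phi)


def score(z: int, phi: float) -> float:
    """d/dphi log q_phi(z | x) for the Bernoulli-logit parameterization."""
    return z - sigmoid(phi)


def log_mean_exp(log_values: Sequence[float]) -> float:
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


def g(zs: Tuple[int, ...], ell: int, phi: float) -> float:
    """Baseline g(z^{-ell}, x) with the geometric-mean f_hat; ignores zs[ell]."""
    log_f = [math.log(P_JOINT[z] / q_prob(z, phi)) for z in zs]
    others = [v for k, v in enumerate(log_f) if k != ell]
    replaced = list(log_f)
    replaced[ell] = sum(others) / len(others)  # log of the geometric mean
    return log_mean_exp(replaced)


def expected_baseline_term(num_samples: int, ell: int, phi: float) -> float:
    """Exact E[g(z^{-ell}, x) * d/dphi log q_phi(z^ell | x)] by enumeration."""
    total = 0.0
    for zs in itertools.product((0, 1), repeat=num_samples):
        prob = math.prod(q_prob(z, phi) for z in zs)
        total += prob * g(zs, ell, phi) * score(zs[ell], phi)
    return total


for ell in range(3):
    # Each value is 0 up to floating-point rounding.
    print(ell, expected_baseline_term(num_samples=3, ell=ell, phi=0.4))
```

The cancellation happens because, for every fixed \(z^{-\ell}\), the inner sum \(\sum_{z^\ell} q_\phi(z^\ell \given x)(z^\ell - \sigma(\phi)) = \sigma(\phi) - \sigma(\phi)\) vanishes.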

Putting everything together, the VIMCO estimator of the gradient of \eqref{eq:elbo} with respect to \(\phi\) is: \begin{align} &\sum_{\ell = 1}^K \left( \left(\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) - \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right) \nonumber\\
&+ \nabla_\phi \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right). \end{align}
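The Python sketch below implements this estimator for a scalar \(\phi\) with the geometric-mean \(\hat f\), on a hypothetical toy model that is not from the paper: binary latents, \(q_\phi(z = 1 \given x) = \sigma(\phi)\), and a fixed \(p(z, x)\) table that does not depend on \(\phi\), so that the second term is \(\sum_k \bar w_k \cdot (-\nabla_\phi \log q_\phi(z^k \given x))\) with self-normalized weights \(\bar w_k\). Exact enumeration over \(\{0, 1\}^K\) compares the expectations of VIMCO and plain REINFORCE with a finite-difference derivative of the exact \(\mathrm{ELBO}_{\text{IS}}\).

```python
import itertools
import math
from typing import Callable, List, Sequence

# Hypothetical toy model, chosen only for illustration:
# binary latents, q_phi(z = 1 | x) = sigmoid(phi), and a fixed p(z, x) table.
P_JOINT = {0: 0.3, 1: 0.1}


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def q_prob(z: int, phi: float) -> float:
    return sigmoid(phi) if z == 1 else 1.0 - sigmoid(phi)


def score(z: int, phi: float) -> float:
    return z - sigmoid(phi)  # d/dphi log q_phi(z | x)


def log_mean_exp(log_values: Sequence[float]) -> float:
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


def log_weights(zs: Sequence[int], phi: float) -> List[float]:
    return [math.log(P_JOINT[z] / q_prob(z, phi)) for z in zs]


def second_term(log_f: Sequence[float], scores: Sequence[float]) -> float:
    """d/dphi log mean_k f_k = sum_k (f_k / sum_j f_j) * (-score_k), as p ignores phi."""
    m = max(log_f)
    unnorm = [math.exp(v - m) for v in log_f]
    total = sum(unnorm)
    return sum((u / total) * -s for u, s in zip(unnorm, scores))


def vimco_grad(zs: Sequence[int], phi: float) -> float:
    """VIMCO estimate of d ELBO_IS / d phi with the geometric-mean f_hat (K >= 2)."""
    log_f = log_weights(zs, phi)
    scores = [score(z, phi) for z in zs]
    full = log_mean_exp(log_f)
    sum_log_f = sum(log_f)
    first = 0.0
    for ell, s in enumerate(scores):
        replaced = list(log_f)
        replaced[ell] = (sum_log_f - log_f[ell]) / (len(zs) - 1)  # log f_hat(z^{-ell}, x)
        first += (full - log_mean_exp(replaced)) * s
    return first + second_term(log_f, scores)


def reinforce_grad(zs: Sequence[int], phi: float) -> float:
    """Plain REINFORCE estimate, for comparison."""
    log_f = log_weights(zs, phi)
    scores = [score(z, phi) for z in zs]
    return log_mean_exp(log_f) * sum(scores) + second_term(log_f, scores)


def exact_expectation(estimator: Callable[[Sequence[int], float], float], k: int, phi: float) -> float:
    """E[estimator(z^{1:K})] under prod_k q_phi(z^k | x), by enumerating {0, 1}^K."""
    return sum(
        math.prod(q_prob(z, phi) for z in zs) * estimator(zs, phi)
        for zs in itertools.product((0, 1), repeat=k)
    )


def elbo_is(k: int, phi: float) -> float:
    """Exact ELBO_IS for the toy model."""
    return exact_expectation(lambda zs, ph: log_mean_exp(log_weights(zs, ph)), k, phi)


phi, K, h = 0.4, 3, 1e-5
finite_diff = (elbo_is(K, phi + h) - elbo_is(K, phi - h)) / (2 * h)
print(exact_expectation(vimco_grad, K, phi), exact_expectation(reinforce_grad, K, phi), finite_diff)
```

The three printed numbers should agree up to finite-difference error. This check only confirms that VIMCO is unbiased on the toy model; it says nothing about how much variance it removes.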
