Tuan Anh Le

Alternative proof for tighter bounds [WIP]

17 December 2019

This is an alternative proof that the asymptotic signal-to-noise ratio of the importance weighted autoencoder (IWAE)-based inference gradient estimator is \(O(1 / \sqrt{K})\) in the number of particles \(K\), as stated in Theorem 1 of our paper. The proof is due to Finke and Thiery (Remark 1, second bullet point).

Consider a generative network \(p_\theta(z, x)\) and an inference network \(q_\phi(z \given x)\) on latents \(z\) and observations \(x\). Given observations \((x^{(n)})_{n = 1}^N\) sampled iid from the true generative model \(p(x)\), we want to learn \(\theta\) to maximize \(\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})\), and to learn \(\phi\) so that \(q_\phi(z \given x)\) is close to \(p_\theta(z \given x)\) for the learned \(\theta\).

Evidence Lower Bound

The \(K\)-particle IWAE-based evidence lower bound (ELBO) is defined as \begin{align} \mathrm{ELBO}_{\text{IWAE}}^K(\theta, \phi, x) = \E_{\prod_{k = 1}^K q_\phi(z_k \given x)}\left[\log \left(\frac{1}{K} \sum_{k = 1}^K w_k \right)\right], \end{align} where \begin{align} w_k = \frac{p_\theta(z_k, x)}{q_\phi(z_k \given x)} \end{align} is the unnormalized importance weight.
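As a minimal sketch of this bound, assume a hypothetical one-dimensional toy model with \(p_\theta(z) = \mathcal N(z; 0, 1)\) and \(p_\theta(x \given z) = \mathcal N(x; z, 1)\), so that \(p_\theta(x) = \mathcal N(x; 0, 2)\) is available in closed form, and an inference network \(q_\phi(z \given x) = \mathcal N(z; \mu, \sigma^2)\) that ignores \(x\). The bound can then be estimated by averaging \(\log \frac{1}{K} \sum_k w_k\) over independent replicates. The model and the helper names are illustrative assumptions, not from the paper.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (v - mean) ** 2 / var)


def log_joint(z, x):
    """Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def log_mean_exp(a):
    """Numerically stable log((1/K) * sum_k exp(a_k)) over the last axis."""
    m = np.max(a, axis=-1, keepdims=True)
    return m[..., 0] + np.log(np.mean(np.exp(a - m), axis=-1))


def iwae_elbo(x, mu, sigma, K, num_reps, rng):
    """Monte Carlo estimate of ELBO_IWAE^K with q(z | x) = N(mu, sigma^2)."""
    z = mu + sigma * rng.standard_normal((num_reps, K))       # z_k ~ q(z | x)
    log_w = log_joint(z, x) - log_normal(z, mu, sigma ** 2)   # log w_k
    return np.mean(log_mean_exp(log_w))                       # average over replicates


rng = np.random.default_rng(0)
x = 1.0
log_px = log_normal(x, 0.0, 2.0)  # exact log p(x) in this toy model
elbos = [iwae_elbo(x, 0.0, 1.0, K, 20_000, rng) for K in (1, 10, 100)]
print("log p(x) =", log_px)
print("ELBO_IWAE^K for K = 1, 10, 100:", elbos)
```

With \(q\) set to the prior here, the estimates should increase with \(K\) towards \(\log p_\theta(x)\), up to Monte Carlo error.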

Inference Gradient Estimator

Given a differentiable reparameterization function \(r_\phi\) and a noise distribution \(s(\epsilon)\) such that \(r_\phi(\epsilon, x)\) for \(\epsilon \sim s(\epsilon)\) has the same distribution as \(z \sim q_\phi(z \given x)\), the gradient estimator for \(\nabla_\phi \mathrm{ELBO}_{\text{IWAE}}^K(\theta, \phi, x)\) can be written as \begin{align} \hat g_K = \sum_{k = 1}^K \bar w_k \nabla_\phi \log w_k, \end{align} where \(w_k = {p_\theta(r_\phi(\epsilon_k, x), x)} / {q_\phi(r_\phi(\epsilon_k, x) \given x)}\), \(\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell\) and \(\epsilon_k \sim s(\epsilon)\) for \(k = 1, \dotsc, K\).
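Under the same hypothetical toy model, with \(\phi = (\mu, \log \sigma)\) and the location-scale reparameterization \(r_\phi(\epsilon, x) = \mu + \sigma \epsilon\), the estimator can be sketched as follows. For this model \(\log q_\phi(r_\phi(\epsilon, x) \given x) = -\log \sigma - \epsilon^2 / 2 + \text{const}\) and \(\partial \log p_\theta(z, x) / \partial z = x - 2z\), which gives \(\nabla_\phi \log w_k\) in closed form. Since \(\hat g_K = \nabla_\phi \log \frac{1}{K} \sum_k w_k\) for fixed \(\epsilon_{1:K}\), a finite-difference check of the log-mean-exp of the weights is a convenient sanity test. All function names are illustrative.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (v - mean) ** 2 / var)


def log_weights(eps, x, mu, log_sigma):
    """log w_k with z_k = r_phi(eps_k) = mu + sigma * eps_k (hypothetical toy model)."""
    sigma = np.exp(log_sigma)
    z = mu + sigma * eps
    log_p = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)  # log p(z, x)
    return log_p - log_normal(z, mu, sigma ** 2)              # minus log q(z | x)


def grad_estimator(eps, x, mu, log_sigma):
    """g_K = sum_k wbar_k * grad_phi log w_k for phi = (mu, log_sigma)."""
    sigma = np.exp(log_sigma)
    z = mu + sigma * eps
    log_w = log_weights(eps, x, mu, log_sigma)
    w_bar = np.exp(log_w - np.max(log_w))
    w_bar /= np.sum(w_bar)  # normalized weights wbar_k
    # d log w_k / d mu: (x - 2 z) * dz/dmu with dz/dmu = 1; log q(r_phi(eps)) has no mu term.
    d_mu = x - 2.0 * z
    # d log w_k / d log_sigma: (x - 2 z) * sigma * eps, plus 1 from -log q = log sigma + ...
    d_log_sigma = (x - 2.0 * z) * sigma * eps + 1.0
    return np.array([np.sum(w_bar * d_mu), np.sum(w_bar * d_log_sigma)])


rng = np.random.default_rng(1)
eps = rng.standard_normal(10)  # K = 10 noise draws, held fixed below
x, mu, log_sigma = 1.0, 0.2, -0.3
g = grad_estimator(eps, x, mu, log_sigma)


def log_mean_w(mu, log_sigma):
    """log((1/K) sum_k w_k) with the noise eps held fixed."""
    log_w = log_weights(eps, x, mu, log_sigma)
    m = np.max(log_w)
    return m + np.log(np.mean(np.exp(log_w - m)))


h = 1e-6
fd = np.array([
    (log_mean_w(mu + h, log_sigma) - log_mean_w(mu - h, log_sigma)) / (2 * h),
    (log_mean_w(mu, log_sigma + h) - log_mean_w(mu, log_sigma - h)) / (2 * h),
])
print("g_K:", g, "finite differences:", fd)
```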

Decreasing Signal-to-Noise Ratio

To prove that the signal-to-noise ratio (SNR) of \(\hat g_K\), defined elementwise as \(\mathrm{SNR}(\hat g_K) = \E[\hat g_K] / \mathrm{std}(\hat g_K)\), is \(O(1 / \sqrt{K})\), we show that \(\hat g_K\) is a \(K\)-particle self-normalized importance sampling (SNIS) estimate of a zero vector. Standard SNIS results then give that the bias, which here is \(\E[\hat g_K]\) itself, is \(O(1 / K)\), while \(\mathrm{std}(\hat g_K)\) is of order \(1 / \sqrt{K}\) (the asymptotic variance is \(\sigma^2 / K\) for some \(\sigma^2 > 0\) in non-degenerate cases). Hence \(\mathrm{SNR}(\hat g_K)\) is \(O(1 / \sqrt{K})\).

We prove this in two steps:

  1. \(\hat g_K\) is an SNIS estimator of \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right]\), where \(w = p_\theta(r_\phi(\boxed{r_\phi^{-1}(z, x)}, x), x) / q_\phi(r_\phi(\boxed{r_\phi^{-1}(z, x)}, x) \given x)\) and the boxed terms are treated as constants (i.e. they are detached, or wrapped in a stop-gradient);
  2. \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\).
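Before going through the proof, the claimed rate can be checked empirically. The sketch below again assumes the hypothetical toy model \(z \sim \mathcal N(0, 1)\), \(x \given z \sim \mathcal N(z, 1)\), with \(q_\phi(z \given x) = \mathcal N(z; 0, 1)\) and \(x = 1\), and estimates the SNR of the \(\mu\)-component of \(\hat g_K\) from independent replicates. The mean and standard deviation are plain Monte Carlo estimates, so the SNR at the largest \(K\) carries noticeable Monte Carlo error.

```python
import numpy as np


def snr_mu(K, num_reps, rng, x=1.0, mu=0.0, sigma=1.0):
    """Empirical SNR of the mu-component of g_K (hypothetical toy model)."""
    eps = rng.standard_normal((num_reps, K))
    z = mu + sigma * eps
    # log w_k up to a constant shared by all k (it cancels after normalization):
    # log p(z, x) = -z^2/2 - (x - z)^2/2 + c1, log q(z | x) = -eps^2/2 + c2.
    log_w = -0.5 * z ** 2 - 0.5 * (x - z) ** 2 + 0.5 * eps ** 2
    w_bar = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    w_bar /= w_bar.sum(axis=1, keepdims=True)
    g_mu = np.sum(w_bar * (x - 2.0 * z), axis=1)  # one g_K (mu-component) per replicate
    return g_mu.mean() / g_mu.std()


rng = np.random.default_rng(0)
Ks = [1, 10, 100]
snrs = [snr_mu(K, 10_000, rng) for K in Ks]
for K, s in zip(Ks, snrs):
    print(f"K = {K:4d}  SNR = {s:+.3f}  SNR * sqrt(K) = {s * np.sqrt(K):+.3f}")
```

If the SNR is \(O(1 / \sqrt{K})\), the last column should not grow with \(K\), up to Monte Carlo error.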

The first step

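In this step, we assume that the reparameterization function \(r_\phi(\cdot, x): \mathcal E \to \mathcal Z\) is a bijection for fixed \(\phi\) and \(x\), as is the case for location-scale reparameterizations such as \(r_\phi(\epsilon, x) = \mu_\phi(x) + \sigma_\phi(x) \epsilon\) with \(\sigma_\phi(x) > 0\). Then there exists an inverse \(r_\phi^{-1}(\cdot, x): \mathcal Z \to \mathcal E\) such that \(r_\phi^{-1}(r_\phi(\epsilon, x), x) = \epsilon\) for all \(\epsilon\) and \(r_\phi(r_\phi^{-1}(z, x), x) = z\) for all \(z\).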

Using this, we can rewrite \(\hat g_K\) as \begin{align} \hat g_K = \sum_{k = 1}^K \bar w_k \nabla_\phi \log w_k, \end{align} where \(w_k = {p_\theta(r_\phi(\boxed{r_\phi^{-1}(z_k, x)}, x), x)} / {q_\phi(r_\phi(\boxed{r_\phi^{-1}(z_k, x)}, x) \given x)}\), \(\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell\) and \(z_k = r_\phi(\epsilon_k, x)\), so that \(z_k \sim q_\phi(z \given x)\) independently for \(k = 1, \dotsc, K\).

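Since \(r_\phi(r_\phi^{-1}(z_k, x), x) = z_k\), the value of each \(w_k\) is simply \(p_\theta(z_k, x) / q_\phi(z_k \given x)\), the importance weight of a sample from the proposal \(q_\phi(z \given x)\) targeting \(p_\theta(z \given x) \propto p_\theta(z, x)\). The boxed terms only affect the gradient, and they make \(\nabla_\phi \log w_k\) a fixed function of \(z_k\). Hence \(\hat g_K\) is a \(K\)-particle SNIS estimate of \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right]\).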
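A minimal sketch of this reading, under the same hypothetical toy model and location-scale reparameterization with \(\phi = (\mu, \log \sigma)\): particles \(z_k\) are drawn from \(q_\phi(z \given x)\), the integrand \(\nabla_\phi \log w\) is evaluated at \(z_k\) with the boxed \(r_\phi^{-1}(z_k, x) = (z_k - \mu) / \sigma\) held fixed, and the generic SNIS formula \(\sum_k \bar w_k f(z_k)\) is applied. With many particles the result should approach the plain Monte Carlo average of the same integrand under exact posterior samples \(z \sim \mathcal N(x/2, 1/2)\), which are available only because the toy model is conjugate. Both should be close to zero, which is what the second step establishes. The functions `r`, `r_inv`, `grad_log_w` and `snis` are illustrative names.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (v - mean) ** 2 / var)


def r(eps, mu, sigma):
    """Hypothetical location-scale reparameterization z = r_phi(eps)."""
    return mu + sigma * eps


def r_inv(z, mu, sigma):
    """Its inverse eps = r_phi^{-1}(z)."""
    return (z - mu) / sigma


def grad_log_w(z, x, mu, sigma):
    """grad_phi log w at z for phi = (mu, log_sigma), with eps_box = r_inv(z) held fixed."""
    eps_box = r_inv(z, mu, sigma)
    d_mu = x - 2.0 * z
    d_log_sigma = (x - 2.0 * z) * sigma * eps_box + 1.0
    return np.stack([d_mu, d_log_sigma], axis=-1)  # shape (K, 2)


def snis(log_w, f_vals):
    """Self-normalized importance sampling estimate sum_k wbar_k f(z_k)."""
    w_bar = np.exp(log_w - np.max(log_w))
    w_bar /= np.sum(w_bar)
    return w_bar @ f_vals


rng = np.random.default_rng(0)
x, mu, sigma, K = 1.0, 0.2, 0.8, 100_000

eps = rng.standard_normal(K)
assert np.allclose(r_inv(r(eps, mu, sigma), mu, sigma), eps)  # bijection round trip

z = r(eps, mu, sigma)  # z_k ~ q(z | x)
log_w = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, mu, sigma ** 2)
snis_estimate = snis(log_w, grad_log_w(z, x, mu, sigma))

z_post = x / 2 + np.sqrt(0.5) * rng.standard_normal(K)  # exact posterior samples
posterior_average = grad_log_w(z_post, x, mu, sigma).mean(axis=0)
print("SNIS:", snis_estimate, "posterior average:", posterior_average)
```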

The second step

To prove \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\), we use the equivalence between the REINFORCE and reparameterization gradients from (Tucker et al. 2018, equation 5): for any function \(f(z)\), which may also depend on \(\phi\), \begin{align} \E_{q_\phi(z \given x)}\left[f(z) \frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right] = \E_{s(\epsilon)}\left[\frac{\partial f(z)}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right], \end{align} where \(z(\epsilon, \phi) = r_\phi(\epsilon, x)\).
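The identity can be sanity-checked by Monte Carlo. The sketch below assumes \(q_\phi(z \given x) = \mathcal N(z; \mu, \sigma^2)\) with \(z(\epsilon, \phi) = \mu + \sigma \epsilon\), \(\phi = (\mu, \log \sigma)\), and picks the illustrative test function \(f(z) = z^2\), for which both sides equal \(2\mu\) for the \(\mu\)-derivative and \(2\sigma^2\) for the \(\log \sigma\)-derivative.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.5, 0.8
eps = rng.standard_normal(2_000_000)  # eps ~ s(eps) = N(0, 1)
z = mu + sigma * eps                  # z(eps, phi) ~ q = N(mu, sigma^2)
f = z ** 2                            # illustrative test function f(z)
df_dz = 2.0 * z

# Score function of q with respect to phi = (mu, log_sigma), at fixed z.
score_mu = (z - mu) / sigma ** 2
score_log_sigma = (z - mu) ** 2 / sigma ** 2 - 1.0

# REINFORCE form (left-hand side) vs. reparameterization form (right-hand side).
lhs = np.array([np.mean(f * score_mu), np.mean(f * score_log_sigma)])
# dz/dmu = 1 and dz/dlog_sigma = sigma * eps.
rhs = np.array([np.mean(df_dz * 1.0), np.mean(df_dz * sigma * eps)])
exact = np.array([2.0 * mu, 2.0 * sigma ** 2])
print("LHS:", lhs, "RHS:", rhs, "exact:", exact)
```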

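By the chain rule, the gradient of \(\log w\) splits into a path term through \(z = r_\phi(\boxed{r_\phi^{-1}(z, x)}, x)\), whose derivative we write as \(\partial z / \partial \phi\), and a term from the explicit dependence on \(\phi\). Since \(p_\theta(z, x)\) does not depend on \(\phi\), the explicit term is \(\partial \log w / \partial \phi = -\partial \log q / \partial \phi\), where \(q\) is shorthand for \(q_\phi(z \given x)\): \begin{align} \E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] &= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi} + \frac{\partial \log w}{\partial \phi}\right] \\
&= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi} - \frac{\partial \log q}{\partial \phi}\right]. \end{align}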

It therefore remains to prove that \begin{align} \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi}\right] = \E_{p_\theta(z \given x)}\left[\frac{\partial \log q}{\partial \phi}\right] \end{align}
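In the hypothetical conjugate toy model (\(z \sim \mathcal N(0, 1)\), \(x \given z \sim \mathcal N(z, 1)\), so \(p_\theta(z \given x) = \mathcal N(z; x/2, 1/2)\)) with \(q_\phi(z \given x) = \mathcal N(z; \mu, \sigma^2)\) and \(\phi = (\mu, \log \sigma)\), both sides of this identity can be estimated from exact posterior samples. Here \(\partial \log w / \partial z = (x - 2z) + (z - \mu) / \sigma^2\), \(\partial z / \partial \mu = 1\) and \(\partial z / \partial \log \sigma = \sigma \epsilon = z - \mu\) with the boxed \(\epsilon = r_\phi^{-1}(z, x)\) held fixed. The parameter values are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
x, mu, sigma = 1.0, 0.2, 0.8  # illustrative values; phi = (mu, log_sigma)
z = x / 2 + np.sqrt(0.5) * rng.standard_normal(1_000_000)  # z ~ p(z | x) = N(x/2, 1/2)
eps_box = (z - mu) / sigma                                  # boxed r_phi^{-1}(z, x), held fixed

dlogw_dz = (x - 2.0 * z) + (z - mu) / sigma ** 2            # d log p(z, x)/dz - d log q/dz
dz_dphi = np.stack([np.ones_like(z), sigma * eps_box])      # dz/dmu, dz/dlog_sigma
dlogq_dphi = np.stack([(z - mu) / sigma ** 2,               # d log q/dmu at fixed z
                       (z - mu) ** 2 / sigma ** 2 - 1.0])   # d log q/dlog_sigma at fixed z

lhs = np.mean(dlogw_dz * dz_dphi, axis=1)  # E_p[(d log w/dz)(dz/dphi)]
rhs = np.mean(dlogq_dphi, axis=1)          # E_p[d log q/dphi]
exact = np.array([(x / 2 - mu) / sigma ** 2, (0.5 + (x / 2 - mu) ** 2) / sigma ** 2 - 1.0])
print("LHS:", lhs, "RHS:", rhs, "exact:", exact)
```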

To do so, substitute \(f = w\), viewed as a function of \(z\), into the identity above; since \(\partial w / \partial z = w \, \partial \log w / \partial z\), \begin{align} \E_{q_\phi(z \given x)}\left[w \frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right] &= \E_{s(\epsilon)}\left[\frac{\partial w}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right] \\
&= \E_{s(\epsilon)}\left[w \frac{\partial \log w}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right] \end{align}

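Writing the expectation over \(\epsilon \sim s(\epsilon)\) as one over \(z = r_\phi(\epsilon, x) \sim q_\phi(z \given x)\), with \(\partial z(\epsilon, \phi) / \partial \phi = \partial r_\phi(\boxed{r_\phi^{-1}(z, x)}, x) / \partial \phi\), and using \(\E_{q_\phi(z \given x)}\left[w \, g(z)\right] = \int p_\theta(z, x) g(z) \, \mathrm dz = p_\theta(x) \E_{p_\theta(z \given x)}\left[g(z)\right]\) for any function \(g\), the right-hand side is \begin{align} \text{RHS} &= \E_{q_\phi(z \given x)}\left[w \frac{\partial \log w}{\partial z} \frac{\partial r_\phi(\boxed{r_\phi^{-1}(z, x)}, x)}{\partial \phi}\right] \\
&= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z} \frac{\partial r_\phi(\boxed{r_\phi^{-1}(z, x)}, x)}{\partial \phi}\right] p_\theta(x) \end{align}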

Similarly, the left-hand side is \begin{align} \text{LHS} &= \E_{p_\theta(z \given x)}\left[\frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right] p_\theta(x) \end{align}

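Both sides are therefore \(p_\theta(x)\) times the corresponding sides of (8). Since \(p_\theta(x) > 0\), (8) holds, and hence \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\), which completes the second step.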
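The two sides of the substituted identity both reduce to \(p_\theta(x)\) times a posterior expectation via \(\E_{q_\phi(z \given x)}[w \, g(z)] = p_\theta(x) \E_{p_\theta(z \given x)}[g(z)]\). For the \(\mu\)-component in the hypothetical toy model (\(z \sim \mathcal N(0, 1)\), \(x \given z \sim \mathcal N(z, 1)\), \(q_\phi(z \given x) = \mathcal N(z; \mu, \sigma^2)\)), the sketch below estimates \(\E_{q}[w \, \partial \log q / \partial \mu]\), \(\E_{s}[w \, (\partial \log w / \partial z)(\partial z / \partial \mu)]\), and the closed-form value \(p_\theta(x) (x/2 - \mu) / \sigma^2\) that both should match. The parameter values are illustrative.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (v - mean) ** 2 / var)


rng = np.random.default_rng(0)
x, mu, sigma = 1.0, 0.3, 0.9
eps = rng.standard_normal(2_000_000)
z = mu + sigma * eps  # z = r_phi(eps) ~ q(z | x)
w = np.exp(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
           - log_normal(z, mu, sigma ** 2))  # unnormalized weights p(z, x) / q(z | x)

dlogq_dmu = (z - mu) / sigma ** 2
dlogw_dz = (x - 2.0 * z) + (z - mu) / sigma ** 2

lhs = np.mean(w * dlogq_dmu)       # E_q[w * d log q / d mu]
rhs = np.mean(w * dlogw_dz * 1.0)  # E_s[w * (d log w / dz) * dz/dmu], with dz/dmu = 1
p_x = np.exp(log_normal(x, 0.0, 2.0))  # exact p(x) = N(x; 0, 2) in this toy model
closed_form = p_x * (x / 2 - mu) / sigma ** 2
print("LHS:", lhs, "RHS:", rhs, "closed form:", closed_form)
```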