*19 December 2017*

Consider a probabilistic model \(p(x, y)\) of \(\mathcal X\)-valued latents \(x\) and \(\mathcal Y\)-valued observations \(y\). Amortized inference is the problem of finding a mapping \(g_\phi: \mathcal Y \to \mathcal P(\mathcal X)\) from an observation \(y\) to a distribution \(q_{\phi}(x \given y)\) that is close to the posterior \(p(x \given y)\). Let’s go with the following objective that is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}

Here,

- \(\mathcal L\) is the objective function,
- \(\pi(y)\) is some distribution over the \(y\) values for which we want to perform good inference at test time,
- \(w(y)\) is a weighting for each \(y\),
- \(f(y) := w(y)\pi(y)\) is just grouping the two together, and
- \(\mathrm{divergence}\) measures a distance between two probability distributions (see Wikipedia).

Choosing \(f(y) = p(y)\), the divergence to be \(\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)}\), and adding the \(\phi\)-independent term \(-\int p(y) \log p(y) \,\mathrm dy\), the objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the \(qp\) loss. Since the bracketed term is the negative ELBO, this objective is also suitable for simultaneously learning the model parameters \(\theta\).

Choosing \(f(y) = p(y)\) and the divergence to be \(\KL{p(\cdot \given y)}{q_\phi(\cdot \given y)}\), the objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the \(pq\) loss.
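Note that the \(pq\) loss can be rewritten as an expectation under the joint: \begin{align} \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy = \underbrace{\mathbb E_{p(x, y)}\left[\log p(x \given y)\right]}_{\text{independent of } \phi} - \mathbb E_{p(x, y)}\left[\log q_\phi(x \given y)\right], \end{align} so minimizing it amounts to maximum-likelihood training of \(q_\phi\) on samples \((x, y) \sim p(x, y)\) drawn from the generative model.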

Let’s compare these two losses on a sequence of increasingly difficult generative models.

The generative model:
\begin{align}
p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(\mu_0 = 0\), \(\sigma_0 = 1\), \(\sigma = 1\).

The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where \(\phi = (a, b, c)\) consists of a multiplier \(a\), offset \(b\) and standard deviation \(c\).

We can obtain the true values for \(\phi\) from the conjugate posterior: \(a^* = \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}\), \(b^* = \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}\), and, since \(c\) is a standard deviation, \(c^* = \left(\frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right)^{1/2}\).

Amortizing inference using both the \(pq\) and the \(qp\) loss gets it spot on.
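This is easy to check numerically. A minimal sketch (not the post’s actual training code): since the \(pq\) loss is, up to a \(\phi\)-independent constant, \(\mathbb E_{p(x,y)}[-\log q_\phi(x \given y)]\), fitting the linear-Gaussian \(q_\phi\) amounts to linear regression of \(x\) on \(y\) over joint samples, which should recover \(a^*\), \(b^*\), and the posterior standard deviation \(c^*\):

```python
import numpy as np

# Minimizing the pq loss for a linear-Gaussian q_phi is, up to a constant,
# maximum likelihood on joint samples -- i.e. linear regression of x on y.
rng = np.random.default_rng(1)
mu0, sigma0, sigma = 0.0, 1.0, 1.0

# Closed-form optimum from the conjugate posterior.
precision = 1.0 / sigma0**2 + 1.0 / sigma**2
a_star = (1.0 / sigma**2) / precision
b_star = (mu0 / sigma0**2) / precision
c_star = (1.0 / precision) ** 0.5  # posterior standard deviation

# Joint samples (x, y) ~ p(x, y) from the generative model.
n = 200_000
x = rng.normal(mu0, sigma0, size=n)
y = rng.normal(x, sigma)

# Regress x on y: recovers the multiplier a, offset b, and residual std c.
A = np.stack([y, np.ones(n)], axis=1)
(a_hat, b_hat), *_ = np.linalg.lstsq(A, x, rcond=None)
c_hat = np.std(x - (a_hat * y + b_hat))

print(a_hat - a_star, b_hat - b_star, c_hat - c_star)  # all close to zero
```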

The generative model:
\begin{align}
p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).

The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) are neural networks parameterized by \(\phi = (\phi_1, \phi_2)\).

Amortizing inference using the \(pq\) loss results in mass-covering/mean-seeking behavior, whereas using the \(qp\) loss results in zero-forcing/mode-seeking behavior (for various test observations \(y\)). The zero-forcing/mode-seeking behavior is very clear in the second plot below: \(\eta_{\phi_1}^1\) always maps to either \(\mu_1\) or \(\mu_2\), depending on which peak is larger, and \(\eta_{\phi_2}^2\) maps to more or less a constant. It is also interesting to look at \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) when the \(pq\) loss is used to amortize inference. It would actually make more sense if \(\eta_{\phi_2}^2\) dropped to the same value as in the \(qp\) case.
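The two behaviors can be reproduced in a toy setting (a numerical sketch, not the post’s neural setup): fit a single Gaussian to a fixed bimodal target by grid search, once under each KL direction. The forward (\(pq\)) direction spreads mass over both modes; the reverse (\(qp\)) direction locks onto one mode.

```python
import numpy as np

# Toy illustration of mass-covering vs mode-seeking: fit a single Gaussian q
# to a bimodal target by grid search, once under each KL direction.
xs = np.linspace(-15.0, 15.0, 3001)
dx = xs[1] - xs[0]

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

# Bimodal target: equal-weight mixture of Normal(-5, 1) and Normal(5, 1).
p = 0.5 * normal_pdf(xs, -5.0, 1.0) + 0.5 * normal_pdf(xs, 5.0, 1.0)

def kl(f, g):
    mask = f > 1e-12               # ignore regions with negligible f-mass
    return np.sum(f[mask] * np.log(f[mask] / g[mask])) * dx

best_pq = best_qp = None
for mu in np.linspace(-6.0, 6.0, 61):
    for sigma in np.linspace(0.5, 8.0, 76):
        q = normal_pdf(xs, mu, sigma)
        d_pq, d_qp = kl(p, q), kl(q, p)
        if best_pq is None or d_pq < best_pq[0]:
            best_pq = (d_pq, mu, sigma)   # forward KL: mass-covering
        if best_qp is None or d_qp < best_qp[0]:
            best_qp = (d_qp, mu, sigma)   # reverse KL: mode-seeking

# Forward KL lands near mu = 0 with a large sigma (covering both modes);
# reverse KL lands on one of the modes with sigma close to 1.
print("pq fit (mu, sigma):", best_pq[1:], "qp fit (mu, sigma):", best_qp[1:])
```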

The generative model:
\begin{align}
p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\
p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).

The inference network:
\begin{align}
q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\
q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)),
\end{align}
where \(\eta_{\phi_z}^{z \given y}\), \(\eta_{\phi_\mu}^{\mu \given y, z}\), and \(\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}\) are neural networks parameterized by \(\phi = (\phi_z, \phi_\mu, \phi_{\sigma^2})\).

In the first plot below, we show the marginal posteriors \(p(z \given y)\) and \(p(x \given y)\) and the corresponding marginals of the inference network. The posterior density is approximated by kernel density estimation on resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs.
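The resampled-importance-sampling baseline can be sketched as follows, with the parameter values from the experiments above (proposing from the prior is an assumption; the post does not say which proposal it uses):

```python
import numpy as np

# Resampled importance sampling for the Discrete-then-Normal model:
# propose (z, x) from the prior, weight by the likelihood p(y | x),
# and resample in proportion to the weights.
rng = np.random.default_rng(0)
mus, sigmas = np.array([-5.0, 5.0]), np.array([1.0, 1.0])
pis, sigma_obs = np.array([0.5, 0.5]), 10.0

def posterior_samples(y, num_proposals=100_000, num_samples=10_000):
    z = rng.choice(len(pis), size=num_proposals, p=pis)   # z ~ p(z)
    x = rng.normal(mus[z], sigmas[z])                     # x ~ p(x | z)
    log_w = -0.5 * ((y - x) / sigma_obs) ** 2             # log p(y | x), up to a constant
    w = np.exp(log_w - log_w.max())
    idx = rng.choice(num_proposals, size=num_samples, p=w / w.sum())
    return z[idx], x[idx]                                 # approximate draws from p(z, x | y)

z_post, x_post = posterior_samples(y=20.0)
# For an observation far to the right, most posterior mass should sit
# on the second mixture component.
print((z_post == 1).mean(), x_post.mean())
```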

The generative model:

- \(K \sim \mathrm{Discrete}(\alpha_1, \alpha_2) + 1\).
- if \(K = 1\):
    - \(x \sim \mathrm{Normal}(\mu_{1, 1}, \sigma_{1, 1}^2)\).
- else:
    - \(z \sim \mathrm{Discrete}(\pi_1, \pi_2)\).
    - if \(z = 0\):
        - \(x \sim \mathrm{Normal}(\mu_{2, 1}, \sigma_{2, 1}^2)\).
    - else:
        - \(x \sim \mathrm{Normal}(\mu_{2, 2}, \sigma_{2, 2}^2)\).
- observe \(y\) under \(\mathrm{Normal}(x, \sigma^2)\).
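The branching model above can be written as an ordinary sampler (the numeric parameter values below are illustrative assumptions; the post does not state them):

```python
import random

# The branching generative model, written as a plain sampler.
# All numeric parameter values here are illustrative assumptions.
alphas = [0.5, 0.5]                      # (alpha_1, alpha_2)
pis = [0.5, 0.5]                         # (pi_1, pi_2)
mu = {(1, 1): 0.0, (2, 1): -5.0, (2, 2): 5.0}
sd = {(1, 1): 1.0, (2, 1): 1.0, (2, 2): 1.0}
sigma = 10.0                             # observation noise

def sample():
    K = random.choices([1, 2], weights=alphas)[0]
    if K == 1:
        z = None
        x = random.gauss(mu[(1, 1)], sd[(1, 1)])
    else:
        z = random.choices([0, 1], weights=pis)[0]
        x = random.gauss(mu[(2, z + 1)], sd[(2, z + 1)])
    y = random.gauss(x, sigma)           # observe y under Normal(x, sigma^2)
    return K, z, x, y
```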

The inference network:

- \(K \sim \mathrm{Discrete}(\eta_{\phi_1}^1(y)) + 1\).
- if \(K = 1\):
    - \(x \sim \mathrm{Normal}(\eta_{\phi_2}^2(y, K), \eta_{\phi_3}^3(y, K))\).
- else:
    - \(z \sim \mathrm{Discrete}(\eta_{\phi_4}^4(y, K))\).
    - if \(z = 0\):
        - \(x \sim \mathrm{Normal}(\eta_{\phi_5}^5(y, K, z), \eta_{\phi_6}^6(y, K, z))\).
    - else:
        - \(x \sim \mathrm{Normal}(\eta_{\phi_5}^5(y, K, z), \eta_{\phi_6}^6(y, K, z))\).

In the plot below, we can see that the inference network learns a good mapping to the posterior (tested on various test observations).
