Tuan Anh Le

Amortized Inference

19 December 2017

Consider a probabilistic model \(p(x, y)\) of \(\mathcal X\)-valued latents \(x\) and \(\mathcal Y\)-valued observes \(y\). Amortized inference is the problem of finding a mapping \(g_\phi: \mathcal Y \to \mathcal P(\mathcal X)\) from an observation \(y\) to a distribution \(q_{\phi}(x \given y)\) that is close to the posterior \(p(x \given y)\). Let's work with the following objective, which is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}

Here, \(\pi(y)\) is a distribution over observations, \(w(y)\) is a weighting function (together they form \(f(y) = w(y)\pi(y)\)), and \(\mathrm{divergence}\) measures the discrepancy between the posterior \(p(\cdot \given y)\) and the approximation \(g_\phi(y) = q_\phi(\cdot \given y)\). Different choices of \(f\) and of the divergence lead to different amortized inference schemes.

Variational Inference

The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the \(qp\) loss. This objective is also suitable for simultaneous model learning.
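To see why the bracketed term is the negative ELBO (writing \(p_\theta\) for the model to make the dependence on \(\theta\) explicit), note the standard identity \begin{align} \KL{q_{\phi}(\cdot \given y)}{p_\theta(\cdot \given y)} - \log p_\theta(y) &= \int q_{\phi}(x \given y) \left[\log q_{\phi}(x \given y) - \log p_\theta(x \given y)\right] \mathrm dx - \log p_\theta(y) \\ &= -\int q_{\phi}(x \given y) \left[\log p_\theta(x, y) - \log q_{\phi}(x \given y)\right] \mathrm dx \\ &= -\mathrm{ELBO}(\phi, \theta, y). \end{align} Since \(\mathrm{ELBO}(\phi, \theta, y) \leq \log p_\theta(y)\), the same objective can also be maximized with respect to \(\theta\), which is what makes it suitable for simultaneous model learning.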

Inference Compilation

The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the \(pq\) loss.
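To see why this loss is convenient to optimize, expand the KL divergence and drop the term that does not depend on \(\phi\): \begin{align} \mathcal L(\phi) &= \int p(y) \int p(x \given y) \left[\log p(x \given y) - \log q_\phi(x \given y)\right] \mathrm dx \,\mathrm dy \\ &= -\int p(x, y) \log q_\phi(x \given y) \,\mathrm dx \,\mathrm dy + \mathrm{const}. \end{align} Minimizing the \(pq\) loss is therefore equivalent to maximizing the expected log density \(\int p(x, y) \log q_\phi(x \given y) \,\mathrm dx \,\mathrm dy\), which can be estimated by ancestral sampling \((x, y) \sim p(x, y)\) from the generative model; no posterior samples are required.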

Examples

Let’s compare these two losses on a sequence of increasingly difficult generative models.

Gaussian Unknown Mean

The generative model: \begin{align} p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set \(\mu_0 = 0\), \(\sigma_0 = 1\), \(\sigma = 1\).

The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where \(\phi = (a, b, c)\) consists of a multiplier \(a\), offset \(b\) and standard deviation \(c\).

We can obtain the optimal values for \(\phi\) from the conjugate-Gaussian posterior \(p(x \given y) = \mathrm{Normal}\left(x; \frac{\mu_0/\sigma_0^2 + y/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}, \frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right)\): \(a^* = \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}\), \(b^* = \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}\), \(c^* = \left(\frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right)^{1/2}\).

Amortizing inference using either the \(pq\) or the \(qp\) loss gets it spot on.
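A minimal PyTorch sketch of the \(pq\)-loss training for this model (this is not the linked script; the optimizer, learning rate, and number of iterations are my own choices):

```python
import torch

mu_0, sigma_0, sigma = 0.0, 1.0, 1.0

# Closed-form optimum for reference.
precision = 1 / sigma_0**2 + 1 / sigma**2
a_star = (1 / sigma**2) / precision
b_star = (mu_0 / sigma_0**2) / precision
c_star = (1 / precision)**0.5

# Inference network parameters phi = (a, b, log c).
a = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
log_c = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.Adam([a, b, log_c], lr=1e-2)

for _ in range(5000):
    # Ancestral samples (x, y) ~ p(x, y); minimizing the pq loss is
    # equivalent to maximizing E_{p(x, y)}[log q_phi(x | y)].
    x = torch.distributions.Normal(mu_0, sigma_0).sample((256,))
    y = torch.distributions.Normal(x, sigma).sample()

    q = torch.distributions.Normal(a * y + b, log_c.exp())
    loss = -q.log_prob(x).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(a.item(), a_star)            # both approximately 0.5
print(b.item(), b_star)            # both approximately 0.0
print(log_c.exp().item(), c_star)  # both approximately 0.707
```

The \(qp\)-loss version replaces the inner loss with a single-sample reparameterized estimate of the negative ELBO, \(\log q_\phi(x \given y) - \log p(x, y)\), with \(x\) drawn from \(q_\phi(\cdot \given y)\) via rsample.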

Python script for generating these figures.

Gaussian Mixtures

The generative model: \begin{align} p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).

The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) are neural networks parameterized by \(\phi = (\phi_1, \phi_2)\).
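One possible parameterization of the inference network (a minimal sketch; the hidden sizes, tanh activations, and softplus link for the variance are my own assumptions, not necessarily what the linked script uses):

```python
import torch
import torch.nn as nn

class InferenceNetwork(nn.Module):
    """q_phi(x | y) = Normal(x; eta_1(y), eta_2(y)), with eta_2(y) > 0."""
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.eta_1 = nn.Sequential(  # mean network eta^1_{phi_1}
            nn.Linear(1, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1))
        self.eta_2 = nn.Sequential(  # variance network eta^2_{phi_2}
            nn.Linear(1, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1),
            nn.Softplus())           # keeps the variance positive

    def forward(self, y):
        # y: tensor of shape [batch_size, 1]
        mean, var = self.eta_1(y), self.eta_2(y)
        return torch.distributions.Normal(mean, var.sqrt())
```

Training with the \(pq\) loss then proceeds as in the previous example: sample \((x, y)\) from the mixture model and maximize the mean log density of \(x\) under the returned distribution; the \(qp\) loss instead uses rsample and the ELBO.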

Amortizing inference using the \(pq\) loss results in mass-covering/mean-seeking behavior, whereas using the \(qp\) loss results in zero-forcing/mode-seeking behavior (for various test observations \(y\)). The zero-forcing/mode-seeking behavior is very clear in the second plot below: \(\eta_{\phi_1}^1\) always maps to either \(\mu_1\) or \(\mu_2\), depending on which posterior peak is larger, and \(\eta_{\phi_2}^2\) maps to more or less a constant. It is also interesting to look at \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) when the \(pq\) loss is used to amortize inference: it would arguably make more sense if \(\eta_{\phi_2}^2\) under the \(pq\) loss dropped to the same value that \(\eta_{\phi_2}^2\) takes in the \(qp\) case.

Python script for generating these figures.

Gaussian Mixtures (Non-Marginalized)

The generative model: \begin{align} p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\
p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).

The inference network: \begin{align} q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\
q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)), \end{align} where \(\eta_{\phi_z}^{z \given y}\), \(\eta_{\phi_\mu}^{\mu \given y, z}\), and \(\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}\) are neural networks parameterized by \(\phi = (\phi_z, \phi_\mu, \phi_{\sigma^2})\).
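A sketch of one way to implement this factorized inference network (again, the architecture and the one-hot encoding of \(z\) are my own assumptions):

```python
import torch
import torch.nn as nn

class FactorizedInferenceNetwork(nn.Module):
    """q_phi(z | y) q_phi(x | y, z) for the K-component, non-marginalized mixture."""
    def __init__(self, K=2, hidden_dim=16):
        super().__init__()
        self.K = K
        self.eta_z = nn.Sequential(   # eta^{z | y}: logits of q(z | y)
            nn.Linear(1, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, K))
        self.eta_x = nn.Sequential(   # outputs (mean, pre-variance) of q(x | y, z)
            nn.Linear(1 + K, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 2))

    def log_prob(self, y, z, x):
        # y, x: tensors of shape [batch, 1]; z: integer labels of shape [batch]
        q_z = torch.distributions.Categorical(logits=self.eta_z(y))
        z_one_hot = nn.functional.one_hot(z, self.K).float()
        out = self.eta_x(torch.cat([y, z_one_hot], dim=-1))
        mean, var = out[:, :1], nn.functional.softplus(out[:, 1:])
        q_x = torch.distributions.Normal(mean, var.sqrt())
        return q_z.log_prob(z) + q_x.log_prob(x).squeeze(-1)
```

For the \(pq\) loss, training amounts to maximizing the mean of log_prob(y, z, x) over ancestral samples \((z, x, y) \sim p(z, x, y)\).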

In the first plot below, we show the marginal posteriors \(p(z \given y)\) and \(p(x \given y)\) and the corresponding marginals of the inference network. The posterior density is approximated by a kernel density estimate of resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs.
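A minimal sketch of that posterior approximation for \(p(x \given y)\), using the prior as the proposal, self-normalized importance weights, multinomial resampling, and SciPy's Gaussian KDE (these particular choices are assumptions on my part, not necessarily those of the linked script):

```python
import numpy as np
from scipy.stats import gaussian_kde, norm

def posterior_x_kde(y, num_samples=100_000,
                    mus=(-5.0, 5.0), sigma_k=1.0, sigma=10.0, seed=0):
    rng = np.random.default_rng(seed)
    # Proposal = prior: z ~ Discrete(1/K, ..., 1/K), x ~ Normal(mu_z, sigma_k^2).
    z = rng.integers(len(mus), size=num_samples)
    x = rng.normal(np.asarray(mus)[z], sigma_k)
    # Self-normalized importance weights w proportional to p(y | x).
    log_w = norm.logpdf(y, loc=x, scale=sigma)
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Resample according to the weights, then smooth with a Gaussian KDE.
    resampled = rng.choice(x, size=num_samples, replace=True, p=w)
    return gaussian_kde(resampled)

kde = posterior_x_kde(y=3.0)
print(kde(np.linspace(-10.0, 10.0, 5)))  # approximate posterior density values
```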

Python script for generating these figures.

Gaussian Mixtures (Open Universe)

The generative model:

The inference network:

In the plot below, we can see that the inference network learns a good mapping to the posterior (for various test observations).

Python script for generating these figures.
