Tuan Anh Le

Conditional Variational Autoencoders

27 June 2018

This is a perspective on the conditional variational autoencoder.

Variational Autoencoders

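In a typical variational autoencoder (VAE), we have a generative model \(p_\theta(z, x)\), with marginal \(p_\theta(x)\) and posterior \(p_\theta(z \given x)\), and an inference network \(q_\phi(z \given x)\).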

Goal: Given a true data distribution \(p(x)\), we want to learn \((\theta, \phi)\) such that \(p_\theta(x)\) approximates \(p(x)\) and \(q_\phi(z \given x)\) approximates \(p_\theta(z \given x)\) for all \(x\).

Let the evidence lower bound (ELBO) be defined as \begin{align} \mathrm{ELBO}(x, \theta, \phi) = \log p_\theta(x) - \KL{q_\phi(z \given x)}{p_\theta(z \given x)}. \end{align} Since the KL divergence is nonnegative, this is a lower bound on \(\log p_\theta(x)\). Maximizing \(\E_{p(x)}[\mathrm{ELBO}(x, \theta, \phi)]\) with respect to \((\theta, \phi)\) achieves our goal: it is equivalent to minimizing \(\KL{p(x)}{p_\theta(x)} + \E_{p(x)}[\KL{q_\phi(z \given x)}{p_\theta(z \given x)}]\), since the objective equals the negative of this sum plus \(\E_{p(x)}[\log p(x)]\), a term that does not depend on \((\theta, \phi)\): \begin{align} \E_{p(x)}[\mathrm{ELBO}(x, \theta, \phi)] &= \E_{p(x)}[\log p_\theta(x) - \KL{q_\phi(z \given x)}{p_\theta(z \given x)}] \\
&= \E_{p(x)}[\log p_\theta(x) - \log p(x)] + \E_{p(x)}[\log p(x)] - \E_{p(x)}[\KL{q_\phi(z \given x)}{p_\theta(z \given x)}] \\
&= -\KL{p(x)}{p_\theta(x)} - \E_{p(x)}[\KL{q_\phi(z \given x)}{p_\theta(z \given x)}] + \E_{p(x)}[\log p(x)]. \end{align}
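To make the definition concrete, the following sketch assumes a hypothetical one-dimensional model \(p_\theta(z) = \mathrm{Normal}(z \given 0, 1)\), \(p_\theta(x \given z) = \mathrm{Normal}(x \given z, s^2)\) and an arbitrary Gaussian \(q(z \given x) = \mathrm{Normal}(z \given m, v)\), where the second argument of \(\mathrm{Normal}\) is a variance. It checks numerically that \(\log p_\theta(x) - \KL{q(z \given x)}{p_\theta(z \given x)}\) equals \(\E_{q(z \given x)}[\log p_\theta(x, z) - \log q(z \given x)]\), the standard form of the ELBO that can be estimated by sampling \(z\) without knowing \(p_\theta(x)\). The model, the function names and the numbers are illustrative assumptions.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of Normal(x | mean, var); the second parameter is a variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def kl_normal(m1, v1, m2, v2):
    """KL(Normal(m1, v1) || Normal(m2, v2)) for scalar Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def elbo_via_kl(x, s2, m, v):
    """ELBO = log p(x) - KL(q(z|x) || p(z|x)) for the toy model p(z)=N(0,1), p(x|z)=N(z,s2)."""
    log_px = log_normal(x, 0.0, 1.0 + s2)  # marginal of the toy model: Normal(0, 1 + s2)
    post_var = 1.0 / (1.0 + 1.0 / s2)      # posterior precision = prior + likelihood precision
    post_mean = post_var * x / s2
    return log_px - kl_normal(m, v, post_mean, post_var)


def elbo_closed_form(x, s2, m, v):
    """E_q[log p(z) + log p(x|z) - log q(z|x)], each expectation computed analytically."""
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + v)
    e_log_lik = -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    entropy_q = 0.5 * math.log(2 * math.pi * math.e * v)
    return e_log_prior + e_log_lik + entropy_q


def elbo_monte_carlo(x, s2, m, v, num_samples=100_000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, z) - log q(z|x)] using z ~ q(z|x)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(m, math.sqrt(v))
        total += log_normal(z, 0.0, 1.0) + log_normal(x, z, s2) - log_normal(z, m, v)
    return total / num_samples


x, s2, m, v = 1.3, 0.5, 0.2, 0.4
print(elbo_via_kl(x, s2, m, v), elbo_closed_form(x, s2, m, v), elbo_monte_carlo(x, s2, m, v))
```

The first two values agree to floating-point precision and the third agrees up to Monte Carlo error; only the sampling form carries over to models whose expectations have no closed form.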

Conditional Variational Autoencoders

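In a conditional VAE, we have a conditional generative model \(p_\theta(z, x \given c)\), with marginal \(p_\theta(x \given c)\) and posterior \(p_\theta(z \given x, c)\), and a conditional inference network \(q_\phi(z \given x, c)\).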

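Goal: Given a true conditional data distribution \(p(x \given c)\) for all \(c\), we want to learn \((\theta, \phi)\) such that \(p_\theta(x \given c)\) approximates \(p(x \given c)\) and \(q_\phi(z \given x, c)\) approximates \(p_\theta(z \given x, c)\) for all \(x\) and \(c\).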

Let the conditional ELBO be defined as \begin{align} \mathrm{ELBO}(x, \theta, \phi \given c) = \log p_\theta(x \given c) - \KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}. \end{align} Given a distribution \(p(c)\) whose support is the set of all \(c\), maximizing \(\E_{p(x \given c) p(c)}[\mathrm{ELBO}(x, \theta, \phi \given c)]\) with respect to \((\theta, \phi)\) achieves our goal: it is equivalent to minimizing \(\E_{p(c)}[\KL{p(x \given c)}{p_\theta(x \given c)}] + \E_{p(x \given c) p(c)}[\KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}]\), since the objective equals the negative of this sum plus \(\E_{p(x \given c) p(c)}[\log p(x \given c)]\), a term that does not depend on \((\theta, \phi)\): \begin{align} \E_{p(x \given c) p(c)}[\mathrm{ELBO}(x, \theta, \phi \given c)] &= \E_{p(x \given c) p(c)}[\log p_\theta(x \given c) - \KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}] \\
&= \E_{p(x \given c) p(c)}[\log p_\theta(x \given c) - \log p(x \given c)] + \E_{p(x \given c) p(c)}[\log p(x \given c)] - \E_{p(x \given c) p(c)}[\KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}] \\
&= -\E_{p(c)}[\KL{p(x \given c)}{p_\theta(x \given c)}] - \E_{p(x \given c) p(c)}[\KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}] + \E_{p(x \given c) p(c)}[\log p(x \given c)]. \end{align}
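In practice the conditional objective is estimated by sampling. The sketch below assumes the generative model factorizes as \(p_\theta(z \given c) p_\theta(x \given z, c)\) (as in the example below) and uses the standard identity \(\mathrm{ELBO}(x, \theta, \phi \given c) = \E_{q_\phi(z \given x, c)}[\log p_\theta(z \given c) + \log p_\theta(x \given z, c) - \log q_\phi(z \given x, c)]\): draw \(c \sim p(c)\), \(x \sim p(x \given c)\) and \(z \sim q_\phi(z \given x, c)\), then average. With real data one would instead draw \((x, c)\) pairs from the dataset. The function `estimate_objective` and the toy Gaussians in the usage lines are hypothetical.

```python
import math
import random
from typing import Callable


def log_normal(x: float, mean: float, var: float) -> float:
    """Log density of Normal(x | mean, var); the second parameter is a variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def estimate_objective(
    sample_c: Callable[[random.Random], float],
    sample_x: Callable[[random.Random, float], float],
    sample_z: Callable[[random.Random, float, float], float],
    log_prior: Callable[[float, float], float],       # log p_theta(z | c)
    log_lik: Callable[[float, float, float], float],  # log p_theta(x | z, c)
    log_q: Callable[[float, float, float], float],    # log q_phi(z | x, c)
    num_samples: int = 10_000,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of E_{p(x|c) p(c)}[ELBO(x, theta, phi | c)], one z per (x, c)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        c = sample_c(rng)        # c ~ p(c)
        x = sample_x(rng, c)     # x ~ p(x | c)
        z = sample_z(rng, x, c)  # z ~ q_phi(z | x, c)
        total += log_prior(z, c) + log_lik(x, z, c) - log_q(z, x, c)
    return total / num_samples


# Hypothetical toy setup: p(c) uniform on {-1, 0, 1}, p_theta(z | c) = Normal(c, 1),
# p_theta(x | z, c) = Normal(z, 1), and data drawn from p(x | c) = Normal(c, 2),
# which is exactly this model's marginal. q is the exact posterior Normal((x + c) / 2, 1/2).
exact_posterior = dict(
    sample_c=lambda rng: rng.choice([-1.0, 0.0, 1.0]),
    sample_x=lambda rng, c: rng.gauss(c, math.sqrt(2.0)),
    sample_z=lambda rng, x, c: rng.gauss((x + c) / 2, math.sqrt(0.5)),
    log_prior=lambda z, c: log_normal(z, c, 1.0),
    log_lik=lambda x, z, c: log_normal(x, z, 1.0),
    log_q=lambda z, x, c: log_normal(z, (x + c) / 2, 0.5),
)
print(estimate_objective(**exact_posterior))
```

When \(q_\phi\) equals the exact posterior, every sample contributes exactly \(\log p_\theta(x \given c)\), so the only remaining noise comes from \(x\) and \(c\); a worse \(q_\phi\) lowers the estimate by roughly \(\E_{p(x \given c) p(c)}[\KL{q_\phi(z \given x, c)}{p_\theta(z \given x, c)}]\).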

Gaussian Unknown Mean Example

Let the conditional generative model be \begin{align} p_\theta(z \given c) &= \mathrm{Normal}(z \given \theta_1 + \theta_2 c, \sigma_0^2) \\
p_\theta(x \given z, c) &= \mathrm{Normal}(x \given z, \exp(\theta_3)), \end{align} where \(\theta = (\theta_1, \theta_2, \theta_3)\), \(\sigma_0^2\) is a fixed, known variance, and the second argument of \(\mathrm{Normal}\) is a variance. Let the conditional inference network be \begin{align} q_\phi(z \given x, c) &= \mathrm{Normal}(z \given \phi_1 x + \phi_2 c + \phi_3, \exp(\phi_4)), \end{align} where \(\phi = (\phi_1, \phi_2, \phi_3, \phi_4)\).
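A minimal sketch of these two conditional densities as plain Python functions, assuming scalar \(x\), \(z\), \(c\) and a known \(\sigma_0^2\); all function names are hypothetical. Writing the variances as \(\exp(\theta_3)\) and \(\exp(\phi_4)\) keeps them positive for any real parameter values. The sample from \(q_\phi\) is drawn as its mean plus \(\sqrt{\exp(\phi_4)}\) times standard normal noise (the reparameterization trick), which is a common way to make a single-sample ELBO estimate differentiable in \(\phi\) once it is written in an autodiff framework.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of Normal(x | mean, var); the second parameter is a variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def log_prior(theta, z, c, sigma0_sq):
    """log p_theta(z | c) = log Normal(z | theta_1 + theta_2 c, sigma_0^2), sigma_0^2 fixed."""
    theta1, theta2, _ = theta
    return log_normal(z, theta1 + theta2 * c, sigma0_sq)


def log_likelihood(theta, x, z):
    """log p_theta(x | z, c) = log Normal(x | z, exp(theta_3)); this model does not use c here."""
    return log_normal(x, z, math.exp(theta[2]))


def q_mean_var(phi, x, c):
    """Mean and variance of q_phi(z | x, c) = Normal(phi_1 x + phi_2 c + phi_3, exp(phi_4))."""
    phi1, phi2, phi3, phi4 = phi
    return phi1 * x + phi2 * c + phi3, math.exp(phi4)


def single_sample_elbo(theta, phi, x, c, sigma0_sq, rng):
    """One-sample ELBO estimate: log p(z|c) + log p(x|z,c) - log q(z|x,c) at one z ~ q."""
    mean, var = q_mean_var(phi, x, c)
    eps = rng.gauss(0.0, 1.0)
    z = mean + math.sqrt(var) * eps  # reparameterized draw from q_phi(z | x, c)
    return (log_prior(theta, z, c, sigma0_sq)
            + log_likelihood(theta, x, z)
            - log_normal(z, mean, var))


rng = random.Random(0)
theta = (0.0, 1.0, math.log(0.5))
phi = (0.5, 0.5, 0.0, math.log(0.25))
print(single_sample_elbo(theta, phi, x=1.2, c=0.3, sigma0_sq=1.0, rng=rng))
```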

Let the true conditional data distribution \(p(x \given c)\) be the marginal of the joint distribution \(p(z \given c) p(x \given z, c)\), where \begin{align} p(z \given c) &= \mathrm{Normal}(z \given \mu_0 + c, \sigma_0^2) \\
p(x \given z, c) &= \mathrm{Normal}(x \given z, \sigma^2). \end{align} The posterior can be analytically derived as \begin{align} p(z \given x, c) &= \mathrm{Normal}\left(z \given \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2} x + \frac{1/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2} c + \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}, \frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right). \end{align}
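The posterior follows from conjugacy: prior and likelihood are both Gaussian in \(z\), so the posterior precision is \(1/\sigma_0^2 + 1/\sigma^2\) and the posterior mean is the precision-weighted combination of \(x\) and \(\mu_0 + c\). A quick numerical check via Bayes' rule, using the marginal \(p(x \given c) = \mathrm{Normal}(x \given \mu_0 + c, \sigma_0^2 + \sigma^2)\) of this linear-Gaussian model (a standard result); the function names and numerical values are illustrative assumptions.

```python
import math


def log_normal(x, mean, var):
    """Log density of Normal(x | mean, var); the second parameter is a variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def posterior_mean_var(x, c, mu0, sigma0_sq, sigma_sq):
    """Mean and variance of p(z | x, c) from the closed-form expression in the text."""
    precision = 1.0 / sigma0_sq + 1.0 / sigma_sq
    mean = ((1.0 / sigma_sq) * x + (1.0 / sigma0_sq) * c + mu0 / sigma0_sq) / precision
    return mean, 1.0 / precision


def log_posterior_by_bayes(z, x, c, mu0, sigma0_sq, sigma_sq):
    """log p(z | c) + log p(x | z, c) - log p(x | c), with p(x | c) = Normal(mu0 + c, sigma0^2 + sigma^2)."""
    return (log_normal(z, mu0 + c, sigma0_sq)
            + log_normal(x, z, sigma_sq)
            - log_normal(x, mu0 + c, sigma0_sq + sigma_sq))


mu0, sigma0_sq, sigma_sq = 0.5, 2.0, 0.5
x, c = 1.7, -0.3
mean, var = posterior_mean_var(x, c, mu0, sigma0_sq, sigma_sq)
for z in (-1.0, 0.0, 0.4, 2.5):
    # Both columns should agree for every z.
    print(z, log_normal(z, mean, var), log_posterior_by_bayes(z, x, c, mu0, sigma0_sq, sigma_sq))
```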

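Since the conditional generative model family contains the true one (both use the same fixed \(\sigma_0^2\)) and the inference network family contains the posterior above (its mean is linear in \(x\) and \(c\)), both KL terms can be driven to zero. Maximizing \(\E_{p(x \given c) p(c)}[\mathrm{ELBO}(x, \theta, \phi \given c)]\) with respect to \((\theta, \phi)\) should therefore yield: \begin{align} \theta_1^* &= \mu_0, \\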
\theta_2^* &= 1, \\
\theta_3^* &= \log(\sigma^2), \\
\phi_1^* &= \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}, \\
\phi_2^* &= \frac{1/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}, \\
\phi_3^* &= \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}, \\
\phi_4^* &= \log\left(\frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right).
\end{align}
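A minimal numerical check of these optima, assuming hypothetical values \(\mu_0 = 0.5\), \(\sigma_0^2 = 2\), \(\sigma^2 = 0.5\) and a hypothetical \(p(c)\) uniform on \(\{-1, 0, 2\}\). Because every distribution here is Gaussian and the mean of \(q_\phi\) is linear in \(x\), the expectation of the ELBO over \(x \sim p(x \given c) = \mathrm{Normal}(x \given \mu_0 + c, \sigma_0^2 + \sigma^2)\) has a closed form. At the optimum both KL terms vanish, so the objective should equal \(\E_{p(x \given c) p(c)}[\log p(x \given c)] = -\frac{1}{2}\log(2\pi e(\sigma_0^2 + \sigma^2))\), and nudging any single parameter should lower it. The function names are illustrative.

```python
import math


def expected_elbo_given_c(theta, phi, c, mu0, sigma0_sq, sigma_sq):
    """E_{p(x|c)}[ELBO(x, theta, phi | c)] in closed form, p(x|c) = Normal(mu0 + c, sigma0^2 + sigma^2)."""
    theta1, theta2, theta3 = theta
    phi1, phi2, phi3, phi4 = phi
    prior_mean = theta1 + theta2 * c  # mean of p_theta(z | c)
    lik_var = math.exp(theta3)        # variance of p_theta(x | z, c)
    q_var = math.exp(phi4)            # variance of q_phi(z | x, c)
    offset = phi2 * c + phi3          # q's mean is phi1 * x + offset
    mu_x, var_x = mu0 + c, sigma0_sq + sigma_sq
    # E_x E_q[(z - prior_mean)^2] and E_x E_q[(x - z)^2], using that q's mean is linear in x.
    sq_prior = (phi1 * mu_x + offset - prior_mean) ** 2 + phi1 ** 2 * var_x + q_var
    sq_lik = ((1 - phi1) * mu_x - offset) ** 2 + (1 - phi1) ** 2 * var_x + q_var
    return (-0.5 * math.log(2 * math.pi * sigma0_sq) - sq_prior / (2 * sigma0_sq)
            - 0.5 * math.log(2 * math.pi * lik_var) - sq_lik / (2 * lik_var)
            + 0.5 * math.log(2 * math.pi * math.e * q_var))  # entropy of q


def objective(theta, phi, cs, mu0, sigma0_sq, sigma_sq):
    """Average of the expected ELBO over a uniform p(c) on the values in cs."""
    return sum(expected_elbo_given_c(theta, phi, c, mu0, sigma0_sq, sigma_sq) for c in cs) / len(cs)


def optimal_params(mu0, sigma0_sq, sigma_sq):
    """theta* and phi* as listed in the text."""
    precision = 1.0 / sigma0_sq + 1.0 / sigma_sq
    theta = (mu0, 1.0, math.log(sigma_sq))
    phi = ((1.0 / sigma_sq) / precision,
           (1.0 / sigma0_sq) / precision,
           (mu0 / sigma0_sq) / precision,
           math.log(1.0 / precision))
    return theta, phi


mu0, sigma0_sq, sigma_sq = 0.5, 2.0, 0.5
cs = (-1.0, 0.0, 2.0)  # hypothetical p(c): uniform on these values
theta_star, phi_star = optimal_params(mu0, sigma0_sq, sigma_sq)
best = objective(theta_star, phi_star, cs, mu0, sigma0_sq, sigma_sq)
print(best, -0.5 * math.log(2 * math.pi * math.e * (sigma0_sq + sigma_sq)))
```

Using at least two distinct values of \(c\) matters here: with a single \(c\), \(\theta_1\) and \(\theta_2\) (and likewise \(\phi_2\) and \(\phi_3\)) could trade off against each other without changing the objective.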
