# Attend, Infer, Repeat

19 February 2018

## Generative network

Here is pseudocode for the generative network $$p_\theta(x \given z)$$, where the observation $$x \in \mathbb R^{D \times D}$$ is an image and $$z$$ denotes all the latent variables in the execution trace. $$\theta$$ collects the parameters of the various neural nets in the generative network.

• Initialize the image mean $$\mu = 0$$,
• While $$\mathrm{sample}\left(\mathrm{Bernoulli}(\rho)\right)$$:
  • $$z^{\text{where}} = \mathrm{sample}\left(\mathrm{Normal}(0, I)\right)$$,
  • $$z^{\text{what}} = \mathrm{sample}\left(\mathrm{Normal}(0, I)\right)$$,
  • $$\hat g = D_\theta(z^{\text{what}})$$,
  • $$\mu = \mu + \mathrm{STN}^{-1}(\hat g, z^{\text{where}})$$,
• $$\mathrm{observe}\left(x, \prod_{\text{pixel } i} \mathrm{Normal}(\mu_i, \sigma_x^2)\right)$$.

where $$\mathrm{STN}^{-1}$$ is an inverse Spatial Transformer Network, which places the decoded patch $$\hat g$$ onto the image canvas according to $$z^{\text{where}}$$, and $$D_\theta$$ is a parametric decoder network.
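The generative trace above can be sketched in NumPy. Everything concrete here is a hypothetical stand-in: the decoder is a fixed random linear map, the inverse STN is reduced to pasting the patch at a location (ignoring scale and rotation), and the constants (image size, patch size, $$\rho$$, $$\sigma_x$$) are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

D = 28          # image side length (assumed)
PATCH = 8       # side length of the decoded object patch (assumed)
RHO = 0.5       # Bernoulli continuation probability rho (assumed)
SIGMA_X = 0.3   # pixel observation noise std dev sigma_x (assumed)

# Hypothetical decoder D_theta: a fixed random linear map from z_what
# to a PATCH x PATCH glyph, squashed to (0, 1) with a sigmoid.
W_dec = rng.normal(scale=0.1, size=(PATCH * PATCH, 4))

def decoder(z_what):
    """Stand-in for D_theta."""
    return (1 / (1 + np.exp(-W_dec @ z_what))).reshape(PATCH, PATCH)

def inverse_stn(patch, z_where):
    """Crude stand-in for STN^{-1}: paste the patch onto a blank canvas.
    The first two coordinates of z_where pick the top-left corner;
    scale and rotation are ignored in this sketch."""
    canvas = np.zeros((D, D))
    r = int((np.tanh(z_where[0]) + 1) / 2 * (D - PATCH))
    c = int((np.tanh(z_where[1]) + 1) / 2 * (D - PATCH))
    canvas[r:r + PATCH, c:c + PATCH] = patch
    return canvas

def generate():
    """Run the generative trace and return one sampled image."""
    mu = np.zeros((D, D))
    while rng.random() < RHO:            # sample(Bernoulli(rho))
        z_where = rng.normal(size=2)     # sample(Normal(0, I))
        z_what = rng.normal(size=4)      # sample(Normal(0, I))
        mu += inverse_stn(decoder(z_what), z_where)
    # observe: each pixel x_i ~ Normal(mu_i, sigma_x^2)
    return mu + SIGMA_X * rng.normal(size=(D, D))

x = generate()
```

The number of loop iterations is itself random, so the trace has variable length; this is what makes the model "repeat" until the Bernoulli says stop.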

## Inference Network

Here is pseudocode for the inference network $$q_\phi(z \given x)$$. $$\phi$$ collects the parameters of the various neural nets in the inference network.

• Initialize the hidden state $$h = 0$$ for the LSTM cell $$R_\phi$$,
• $$w, h = R_\phi(\mathrm{concat}(x, 0, 0), h)$$,
• While $$\mathrm{sample}\left(\mathrm{Bernoulli}(f_\phi(w))\right)$$:
  • $$z^{\text{where}} = \mathrm{sample}\left(\mathrm{Normal}(\mu_\phi^{\text{where}}(w), \sigma_\phi^{\text{where}}(w)^2)\right)$$,
  • $$g = \mathrm{STN}(x, z^{\text{where}})$$,
  • $$z^{\text{what}} = \mathrm{sample}\left(\mathrm{Normal}(\mu_\phi^{\text{what}}(g), \sigma_\phi^{\text{what}}(g)^2)\right)$$,
  • $$w, h = R_\phi(\mathrm{concat}(x, z^{\text{where}}, z^{\text{what}}), h)$$.

where $$\mathrm{STN}$$ is a Spatial Transformer Network, which extracts the attention window $$g$$ from $$x$$ according to $$z^{\text{where}}$$, and the LSTM cell $$R_\phi$$ takes an (input, hidden state) pair and returns an (output, next hidden state) pair. Note that the RNN update is the last step of each iteration, so that $$w$$ summarizes everything attended to so far when the next presence variable is sampled.
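The inference trace can be sketched the same way. Again every concrete piece is a hypothetical stand-in: a plain tanh RNN cell replaces the LSTM, the STN is reduced to a crop, the variances are parameterized as $$\exp$$ of a linear output, and all weights are random rather than learned.

```python
import numpy as np

rng = np.random.default_rng(1)

D = 28           # image side length (assumed)
PATCH = 8        # attention window side length (assumed)
IN_DIM = D * D + 2 + 4   # concat(x, z_where, z_what)
H_DIM = 32               # RNN hidden size (assumed)

# Hypothetical parameters phi; a tanh RNN cell stands in for the LSTM.
W_in = rng.normal(scale=0.05, size=(H_DIM, IN_DIM))
W_h = rng.normal(scale=0.05, size=(H_DIM, H_DIM))
W_pres = rng.normal(scale=0.05, size=H_DIM)              # f_phi
W_where = rng.normal(scale=0.05, size=(2 * 2, H_DIM))    # mu, log sigma
W_what = rng.normal(scale=0.05, size=(2 * 4, PATCH * PATCH))

def rnn_cell(inp, h):
    """Stand-in for R_phi: (input, hidden) -> (output, next hidden)."""
    h_next = np.tanh(W_in @ inp + W_h @ h)
    return h_next, h_next

def stn_crop(x, z_where):
    """Crude stand-in for STN: crop a PATCH x PATCH window from x."""
    r = int((np.tanh(z_where[0]) + 1) / 2 * (D - PATCH))
    c = int((np.tanh(z_where[1]) + 1) / 2 * (D - PATCH))
    return x[r:r + PATCH, c:c + PATCH]

def infer(x, max_steps=10):
    """Run the inference trace; return the sampled (z_where, z_what) pairs."""
    h = np.zeros(H_DIM)
    w, h = rnn_cell(np.concatenate([x.ravel(), np.zeros(2), np.zeros(4)]), h)
    zs = []
    for _ in range(max_steps):
        pres_prob = 1 / (1 + np.exp(-W_pres @ w))        # f_phi(w)
        if rng.random() >= pres_prob:                    # Bernoulli says stop
            break
        s = W_where @ w                                  # where-statistics
        z_where = s[:2] + np.exp(s[2:]) * rng.normal(size=2)
        g = stn_crop(x, z_where)                         # attend
        s = W_what @ g.ravel()                           # what-statistics
        z_what = s[:4] + np.exp(s[4:]) * rng.normal(size=4)
        zs.append((z_where, z_what))
        # update the RNN last, so w is fresh for the next presence decision
        w, h = rnn_cell(np.concatenate([x.ravel(), z_where, z_what]), h)
    return zs

zs = infer(rng.normal(size=(D, D)))
```

In the real model the loop is unrolled to a fixed maximum number of steps during training, and $$\theta, \phi$$ are optimized jointly on the evidence lower bound; `max_steps` above plays the role of that truncation.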
