import argparse
from numbers import Number

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import torch
import torch.distributions
import torch.nn as nn
import torch.optim as optim


def logsumexp(values, dim=0, keepdim=False):
    """Numerically stable logsumexp of a Tensor.

    See https://en.wikipedia.org/wiki/LogSumExp.

    input:
        values: Tensor [dim_1, ..., dim_N]
        dim: n

    output: Tensor [dim_1, ..., dim_{n - 1}, dim_{n + 1}, ..., dim_N] where

        result[i_1, ..., i_{n - 1}, i_{n + 1}, ..., i_N] =
            log(sum_{i_n = 1}^{dim_n} exp(values[i_1, ..., i_N]))
    """
    values_max, _ = torch.max(values, dim=dim, keepdim=True)
    result = values_max + torch.log(torch.sum(
        torch.exp(values - values_max), dim=dim, keepdim=True
    ))
    return result if keepdim else result.squeeze(dim)


class NormalMixture1D(torch.distributions.Distribution):
    """Mixture of 1D Normals: p(x) = sum_k mixture_probs[k] * N(x | means[k], stds[k])."""

    def __init__(self, mixture_probs, means, stds):
        self.num_mixtures = len(mixture_probs)
        self.mixture_probs = torch.tensor(mixture_probs, dtype=torch.float)
        self.categorical = torch.distributions.Categorical(self.mixture_probs)
        self.normals = torch.distributions.Normal(
            loc=torch.tensor(means, dtype=torch.float),
            scale=torch.tensor(stds, dtype=torch.float)
        )

    def sample(self):
        # Pick a mixture component, then return that component's sample.
        return self.normals.sample()[self.categorical.sample()]

    def sample_n(self, n):
        # Draw n samples from every component, then select one component per draw.
        components = torch.multinomial(self.mixture_probs, n, replacement=True)
        return torch.gather(
            self.normals.sample((n,)), 1, components.unsqueeze(-1)
        ).squeeze(-1)

    def log_prob(self, value):
        # log p(x) = logsumexp_k [log mixture_probs[k] + log N(x | means[k], stds[k])]
        if isinstance(value, Number):
            value = torch.tensor(float(value))
        if value.dim() == 0:
            return logsumexp(self.normals.log_prob(value) + torch.log(self.mixture_probs))
        else:
            return logsumexp(
                self.normals.log_prob(value.unsqueeze(-1).expand(-1, self.num_mixtures)) +
                torch.log(self.mixture_probs),
                dim=1
            )


def get_posterior_log_prob(xs, obs, mixture_probs, means, stds, obs_std):
    """Analytic log posterior log p(x | obs) evaluated at each point of xs.

    The prior is a mixture of Normals and the likelihood is N(obs | x, obs_std^2),
    so the evidence p(obs) is a mixture of Normals with stds sqrt(std_k^2 + obs_std^2).
    """
    xs_tensor = torch.from_numpy(xs).float()
    obs_tensor = torch.tensor(float(obs))
    log_prior = NormalMixture1D(mixture_probs, means, stds).log_prob(xs_tensor).numpy()
    log_likelihood = torch.distributions.Normal(
        xs_tensor, float(obs_std)
    ).log_prob(obs_tensor).numpy()
    log_evidence = NormalMixture1D(
        mixture_probs, means, [np.sqrt(std**2 + obs_std**2) for std in stds]
    ).log_prob(obs_tensor).numpy()
    return log_prior + log_likelihood - log_evidence
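# A minimal sanity check for the analytic posterior above (not part of the
# original script; the helper name and grid bounds are illustrative
# assumptions): the posterior density should integrate to roughly 1.
def _check_posterior_normalization(mixture_probs, means, stds, obs_std, obs=0.0):
    xs = np.linspace(-50, 50, 10000)
    pdf = np.exp(get_posterior_log_prob(xs, obs, mixture_probs, means, stds, obs_std))
    # Trapezoidal integration over a grid wide enough to capture the mass.
    return np.trapz(pdf, xs)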
def generate_traces(num_traces, mixture_probs, means, stds, obs_std):
    """Sample (x, obs) pairs from the generative model p(x) * p(obs | x)."""
    nm = NormalMixture1D(mixture_probs, means, stds)
    x = nm.sample_n(num_traces).numpy()
    obs = np.random.normal(loc=x, scale=obs_std)
    return x, obs


class InferenceNetworkPQ(nn.Module):
    """Trained on the pq loss E_{p(x, obs)}[-log q(x | obs)]."""

    def __init__(self):
        super(InferenceNetworkPQ, self).__init__()
        self.mean_lin_1 = nn.Linear(1, 4)
        self.mean_lin_2 = nn.Linear(4, 4)
        self.mean_lin_3 = nn.Linear(4, 1)
        self.std_lin_1 = nn.Linear(1, 4)
        self.std_lin_2 = nn.Linear(4, 4)
        self.std_lin_3 = nn.Linear(4, 1)

    def get_q_params(self, obs):
        obs = obs.unsqueeze(-1)
        q_mean = self.mean_lin_3(torch.tanh(self.mean_lin_2(torch.tanh(self.mean_lin_1(obs))))).squeeze(-1)
        # Exponentiate so that the standard deviation is always positive.
        q_std = torch.exp(self.std_lin_3(torch.tanh(self.std_lin_2(torch.tanh(self.std_lin_1(obs)))))).squeeze(-1)
        return q_mean, q_std

    def forward(self, x, obs):
        q_mean, q_std = self.get_q_params(obs)
        return -torch.mean(torch.distributions.Normal(q_mean, q_std).log_prob(x))


class InferenceNetworkQP(nn.Module):
    """Trained on the qp loss, the negative ELBO
    -E_{q(x | obs)}[log p(x, obs) - log q(x | obs)]."""

    def __init__(self, mixture_probs, means, stds, obs_std):
        super(InferenceNetworkQP, self).__init__()
        self.mean_lin_1 = nn.Linear(1, 4)
        self.mean_lin_2 = nn.Linear(4, 4)
        self.mean_lin_3 = nn.Linear(4, 1)
        self.std_lin_1 = nn.Linear(1, 4)
        self.std_lin_2 = nn.Linear(4, 4)
        self.std_lin_3 = nn.Linear(4, 1)
        self.mixture_probs = mixture_probs
        self.means = means
        self.stds = stds
        self.obs_std = obs_std

    def get_q_params(self, obs):
        obs = obs.unsqueeze(-1)
        q_mean = self.mean_lin_3(torch.tanh(self.mean_lin_2(torch.tanh(self.mean_lin_1(obs))))).squeeze(-1)
        q_std = torch.exp(self.std_lin_3(torch.tanh(self.std_lin_2(torch.tanh(self.std_lin_1(obs)))))).squeeze(-1)
        return q_mean, q_std

    def forward(self, _, obs):
        # The x sampled from the prior is unused; this objective samples from q.
        q_mean, q_std = self.get_q_params(obs)
        q = torch.distributions.Normal(q_mean, q_std)
        # Reparameterized sample so that gradients flow through q_mean and q_std.
        x = q.rsample()
        prior = NormalMixture1D(self.mixture_probs, self.means, self.stds)
        likelihood = torch.distributions.Normal(x, float(self.obs_std))
        # Single-sample Monte Carlo estimate of the negative ELBO.
        return -torch.mean(prior.log_prob(x) + likelihood.log_prob(obs) - q.log_prob(x))


def amortize_inference(mixture_probs, means, stds, obs_std, num_iterations,
                       num_traces, learning_rate, loss_type):
    """Train an inference network on freshly generated traces at every iteration."""
    loss_history = np.zeros([num_iterations])
    if loss_type == 'qp':
        inference_network = InferenceNetworkQP(mixture_probs, means, stds, obs_std)
    else:
        inference_network = InferenceNetworkPQ()
    optimizer = optim.Adam(inference_network.parameters(), lr=learning_rate)
    for i in range(num_iterations):
        x, obs = generate_traces(num_traces, mixture_probs, means, stds, obs_std)
        x_tensor = torch.from_numpy(x).float()
        obs_tensor = torch.from_numpy(obs).float()
        optimizer.zero_grad()
        loss = inference_network(x_tensor, obs_tensor)
        loss.backward()
        optimizer.step()
        loss_history[i] = loss.item()
        if i % 100 == 0:
            print('Iteration {}: Loss = {}'.format(i, loss_history[i]))
    return loss_history, inference_network
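# Commentary on the two objectives (standard variational-inference facts,
# added as explanation): the pq loss equals the inclusive
# E_{p(obs)}[KL(p(x | obs) || q(x | obs))] up to an additive constant and is
# mass-covering; the qp loss is the negative ELBO, which exceeds the
# exclusive KL(q(x | obs) || p(x | obs)) by -log p(obs), a term independent
# of q, and is mode-seeking.

# A hypothetical smoke test (not in the original script): run a handful of
# iterations to check that both losses decrease before committing to the
# full 100k-iteration training run.
def _smoke_test():
    for loss_type in ['pq', 'qp']:
        loss_history, _ = amortize_inference(
            [0.5, 0.5], [-5, 5], [1, 1], 10,
            num_iterations=200, num_traces=100,
            learning_rate=0.001, loss_type=loss_type
        )
        print('{}: loss {:.3f} -> {:.3f}'.format(loss_type, loss_history[0], loss_history[-1]))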
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', action='store_true')
    args = parser.parse_args()

    mixture_probs = [0.5, 0.5]
    means = [-5, 5]
    stds = [1, 1]
    obs_std = 10
    num_iterations = 100000
    num_traces = 1000
    learning_rate = 0.001
    inference_network_pq_filename = 'inference_network_pq.pt'
    inference_network_qp_filename = 'inference_network_qp.pt'

    if not args.checkpoint:
        # Amortize inference with both losses and save the artifacts
        loss_history_pq, inference_network_pq = amortize_inference(
            mixture_probs, means, stds, obs_std,
            num_iterations, num_traces, learning_rate, loss_type='pq'
        )
        np.save('loss_history_pq.npy', loss_history_pq)
        print('Saved to loss_history_pq.npy')
        torch.save(inference_network_pq.state_dict(), inference_network_pq_filename)
        print('Saved to {}'.format(inference_network_pq_filename))

        loss_history_qp, inference_network_qp = amortize_inference(
            mixture_probs, means, stds, obs_std,
            num_iterations, num_traces, learning_rate, loss_type='qp'
        )
        np.save('loss_history_qp.npy', loss_history_qp)
        print('Saved to loss_history_qp.npy')
        torch.save(inference_network_qp.state_dict(), inference_network_qp_filename)
        print('Saved to {}'.format(inference_network_qp_filename))
    else:
        # Load previously saved inference artifacts
        loss_history_qp = np.load('loss_history_qp.npy')
        print('Loaded from loss_history_qp.npy')
        loss_history_pq = np.load('loss_history_pq.npy')
        print('Loaded from loss_history_pq.npy')

        inference_network_qp = InferenceNetworkQP(mixture_probs, means, stds, obs_std)
        inference_network_qp.load_state_dict(torch.load(inference_network_qp_filename))
        print('Loaded from {}'.format(inference_network_qp_filename))

        inference_network_pq = InferenceNetworkPQ()
        inference_network_pq.load_state_dict(torch.load(inference_network_pq_filename))
        print('Loaded from {}'.format(inference_network_pq_filename))

    # Plot loss histories
    iterations = np.arange(num_iterations)
    fig, axs = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(3, 3)
    axs[0].plot(iterations, loss_history_qp, label='qp loss', color='black')
    axs[0].set_ylabel('$qp$ loss')
    axs[0].spines['right'].set_visible(False)
    axs[0].spines['top'].set_visible(False)
    axs[1].plot(iterations, loss_history_pq, label='pq loss', color='black')
    axs[1].set_ylabel('$pq$ loss')
    axs[1].spines['right'].set_visible(False)
    axs[1].spines['top'].set_visible(False)
    for filename in ['gaussian_mixture_1.pdf', 'gaussian_mixture_1.png']:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))

    # Plot the true and learned posteriors for a few test observations
    num_test_obs = 5
    test_obs_min = -40
    test_obs_max = 40
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)
    num_xs = 100
    min_xs = -40
    max_xs = 40
    xs = np.linspace(min_xs, max_xs, num_xs)
    fig, axs = plt.subplots(1, num_test_obs)
    fig.set_size_inches(8, 1.5)
    for test_obs_idx, test_obs in enumerate(test_obss):
        prior_pdf = torch.exp(
            NormalMixture1D(mixture_probs, means, stds).log_prob(torch.from_numpy(xs).float())
        ).numpy()
        posterior_pdf = np.exp(get_posterior_log_prob(xs, test_obs, mixture_probs, means, stds, obs_std))
        q_mean_qp_t, q_std_qp_t = inference_network_qp.get_q_params(torch.tensor([float(test_obs)]))
        learned_posterior_pdf_qp = scipy.stats.norm.pdf(xs, loc=q_mean_qp_t.item(), scale=q_std_qp_t.item())
        q_mean_pq_t, q_std_pq_t = inference_network_pq.get_q_params(torch.tensor([float(test_obs)]))
        learned_posterior_pdf_pq = scipy.stats.norm.pdf(xs, loc=q_mean_pq_t.item(), scale=q_std_pq_t.item())
        axs[test_obs_idx].scatter(x=test_obs, y=0, color='black', label='test obs', marker='x')
        axs[test_obs_idx].plot(xs, prior_pdf, label='prior', color='lightgray')
        axs[test_obs_idx].plot(xs, posterior_pdf, label='posterior', color='black', linestyle='solid', alpha=0.8)
        axs[test_obs_idx].plot(xs, learned_posterior_pdf_qp, label='inference network qp', color='black', linestyle='dashed', alpha=0.8)
        axs[test_obs_idx].plot(xs, learned_posterior_pdf_pq, label='inference network pq', color='black', linestyle='dotted', alpha=0.8)
        axs[test_obs_idx].spines['right'].set_visible(False)
        axs[test_obs_idx].spines['top'].set_visible(False)
        axs[test_obs_idx].set_yticks([])
        axs[test_obs_idx].set_xticks([])
    axs[num_test_obs // 2].legend(ncol=5, loc='upper center', bbox_to_anchor=(0.5, 0), fontsize='small')
    for filename in ['gaussian_mixture_2.pdf', 'gaussian_mixture_2.png']:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))
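    # Expected qualitative picture in gaussian_mixture_2 (hedged, since it
    # depends on how training went): near obs = 0 the true posterior is
    # bimodal, so the mode-seeking qp network tends to commit to a single
    # mode while the mass-covering pq network spreads over both.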
    # Plot the inference networks' mappings from obs to the parameters of q
    num_test_obs = 1000
    test_obs_min = -200
    test_obs_max = 200
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)
    q_mean_qp = np.zeros([num_test_obs])
    q_std_qp = np.zeros([num_test_obs])
    q_mean_pq = np.zeros([num_test_obs])
    q_std_pq = np.zeros([num_test_obs])
    for test_obs_idx, test_obs in enumerate(test_obss):
        q_mean_qp_t, q_std_qp_t = inference_network_qp.get_q_params(torch.tensor([float(test_obs)]))
        q_mean_pq_t, q_std_pq_t = inference_network_pq.get_q_params(torch.tensor([float(test_obs)]))
        q_mean_qp[test_obs_idx] = q_mean_qp_t.item()
        q_std_qp[test_obs_idx] = q_std_qp_t.item()
        q_mean_pq[test_obs_idx] = q_mean_pq_t.item()
        q_std_pq[test_obs_idx] = q_std_pq_t.item()

    fig, axs = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(5.5, 2.5)
    axs[0].plot(test_obss, q_mean_qp, color='black', linestyle='dashed', label='qp')
    axs[0].plot(test_obss, q_mean_pq, color='black', linestyle='dotted', label='pq')
    axs[0].spines['top'].set_visible(False)
    axs[0].spines['right'].set_visible(False)
    axs[0].set_ylabel('mean')
    axs[1].plot(test_obss, q_std_qp, color='black', linestyle='dashed', label='qp')
    axs[1].plot(test_obss, q_std_pq, color='black', linestyle='dotted', label='pq')
    axs[1].spines['top'].set_visible(False)
    axs[1].spines['right'].set_visible(False)
    axs[1].set_ylabel('std')
    axs[1].set_xlabel('test obs')
    axs[1].legend(ncol=2, fontsize='small', loc='upper center', bbox_to_anchor=(0.5, -0.45))
    fig.suptitle('Parameters of $q$')
    for filename in ['gaussian_mixture_3.pdf', 'gaussian_mixture_3.png']:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()
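# Example usage (the script filename is assumed; use whatever this file is
# saved as):
#   python gaussian_mixture.py               # train both networks, save artifacts, plot
#   python gaussian_mixture.py --checkpoint  # reuse saved artifacts and just re-plot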