import numpy as np
import scipy.stats
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal


def get_posterior_params(obs, prior_mean, prior_std, obs_std):
    """Closed-form Gaussian posterior for the conjugate model
    x ~ N(prior_mean, prior_std), obs | x ~ N(x, obs_std).

    Args:
        obs: observation(s); scalar or np.ndarray (broadcasts).
        prior_mean, prior_std: prior parameters of x.
        obs_std: likelihood standard deviation.

    Returns:
        (posterior_mean, posterior_std) of p(x | obs).
    """
    # Precisions add for Gaussian conjugacy.
    posterior_var = 1 / (1 / prior_std**2 + 1 / obs_std**2)
    posterior_std = np.sqrt(posterior_var)
    posterior_mean = posterior_var * (prior_mean / prior_std**2 + obs / obs_std**2)
    return posterior_mean, posterior_std


def get_proposal_params(prior_mean, prior_std, obs_std):
    """Coefficients of the exact posterior as an affine function of obs.

    The true posterior mean is `multiplier * obs + offset` with a constant
    `posterior_std`, i.e. the optimum the inference networks should learn.

    Returns:
        (multiplier, offset, posterior_std).
    """
    posterior_var = 1 / (1 / prior_std**2 + 1 / obs_std**2)
    posterior_std = np.sqrt(posterior_var)
    multiplier = posterior_var / obs_std**2
    offset = posterior_var * prior_mean / prior_std**2
    return multiplier, offset, posterior_std


def generate_traces(num_traces, prior_mean, prior_std, obs_std):
    """Sample `num_traces` (x, obs) pairs from the generative model.

    Returns:
        (x, obs): two float np.ndarrays of shape (num_traces,).
    """
    x = np.random.normal(loc=prior_mean, scale=prior_std, size=num_traces)
    obs = np.random.normal(loc=x, scale=obs_std)
    return x, obs


class InferenceNetworkPQ(nn.Module):
    """Affine Gaussian proposal q(x | obs) trained on the pq loss
    E_p[-log q(x | obs)] (forward KL / wake-phase objective)."""

    def __init__(self, init_multiplier, init_offset, init_std):
        super(InferenceNetworkPQ, self).__init__()
        # float() so integer inits (e.g. 2) still yield float parameters.
        self.multiplier = nn.Parameter(torch.tensor([float(init_multiplier)]))
        self.offset = nn.Parameter(torch.tensor([float(init_offset)]))
        self.std = nn.Parameter(torch.tensor([float(init_std)]))

    def forward(self, x, obs):
        """Return the scalar pq loss for latent batch `x` and observations `obs`."""
        q_mean = self.multiplier * obs + self.offset
        q_std = self.std.expand_as(obs)
        return -torch.mean(Normal(loc=q_mean, scale=q_std).log_prob(x))


class InferenceNetworkQP(nn.Module):
    """Affine Gaussian proposal q(x | obs) trained on the qp loss
    E_q[log q - log p] (reverse KL / ELBO-style objective).

    The true latents are ignored; x is drawn from q by reparameterization
    so gradients flow into the proposal parameters.
    """

    def __init__(self, init_multiplier, init_offset, init_std,
                 prior_mean, prior_std, obs_std):
        super(InferenceNetworkQP, self).__init__()
        self.multiplier = nn.Parameter(torch.tensor([float(init_multiplier)]))
        self.offset = nn.Parameter(torch.tensor([float(init_offset)]))
        self.std = nn.Parameter(torch.tensor([float(init_std)]))
        # Fixed model (not learned) constants.
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.obs_std = obs_std

    def forward(self, _, obs):
        """Return the scalar qp loss; the first argument (true x) is unused."""
        q_mean = self.multiplier * obs + self.offset
        q_std = self.std.expand_as(obs)
        q = Normal(loc=q_mean, scale=q_std)
        # rsample() = reparameterized draw (eps * std + mean), so the sample
        # is differentiable w.r.t. multiplier/offset/std.
        x = q.rsample()
        prior = Normal(loc=float(self.prior_mean), scale=float(self.prior_std))
        likelihood = Normal(loc=x, scale=float(self.obs_std))
        # -(E_q[log p(x) + log p(obs | x) - log q(x | obs)]) = reverse-KL loss.
        return -torch.mean(
            prior.log_prob(x) + likelihood.log_prob(obs) - q.log_prob(x)
        )


def amortize_inference(prior_mean, prior_std, obs_std, num_iterations,
                       num_traces, learning_rate, init_multiplier, init_offset,
                       init_std, loss_type):
    """Train an inference network with SGD on freshly simulated traces.

    Args:
        prior_mean, prior_std, obs_std: generative-model constants.
        num_iterations: number of SGD steps.
        num_traces: traces simulated per step (batch size).
        learning_rate: SGD step size.
        init_multiplier, init_offset, init_std: network initialization.
        loss_type: 'qp' for InferenceNetworkQP, anything else for
            InferenceNetworkPQ.

    Returns:
        (multiplier_history, offset_history, std_history, loss_history),
        each a np.ndarray of shape (num_iterations,).
    """
    multiplier_history = np.zeros([num_iterations])
    offset_history = np.zeros([num_iterations])
    std_history = np.zeros([num_iterations])
    loss_history = np.zeros([num_iterations])

    if loss_type == 'qp':
        inference_network = InferenceNetworkQP(
            init_multiplier, init_offset, init_std,
            prior_mean, prior_std, obs_std)
    else:
        inference_network = InferenceNetworkPQ(
            init_multiplier, init_offset, init_std)
    sgd_optimizer = optim.SGD(inference_network.parameters(), lr=learning_rate)

    for i in range(num_iterations):
        # Fresh synthetic data each step (infinite-data regime).
        x, obs = generate_traces(num_traces, prior_mean, prior_std, obs_std)
        x_var = torch.from_numpy(x).float()
        obs_var = torch.from_numpy(obs).float()

        sgd_optimizer.zero_grad()
        loss = inference_network(x_var, obs_var)
        loss.backward()
        sgd_optimizer.step()

        multiplier_history[i] = inference_network.multiplier.item()
        offset_history[i] = inference_network.offset.item()
        std_history[i] = inference_network.std.item()
        loss_history[i] = loss.item()

    return multiplier_history, offset_history, std_history, loss_history


def main():
    """Train both networks, save histories to .npy, and save two figures:
    training curves and learned-vs-true posteriors on a grid of test obs."""
    # Plotting-only dependency; imported lazily so the numerical functions
    # above remain importable without matplotlib.
    import matplotlib.pyplot as plt

    prior_mean, prior_std, obs_std = 0, 1, 1
    num_iterations = 1000
    num_traces = 10000
    learning_rate = 0.01
    init_multiplier, init_offset, init_std = 2, 2, 2

    print('pq...')
    loss_type = 'pq'
    multiplier_history_pq, offset_history_pq, std_history_pq, loss_history_pq = \
        amortize_inference(prior_mean, prior_std, obs_std, num_iterations,
                           num_traces, learning_rate, init_multiplier,
                           init_offset, init_std, loss_type)
    for data, filename in zip(
        [multiplier_history_pq, offset_history_pq, std_history_pq,
         loss_history_pq],
        ['multiplier_history_pq.npy', 'offset_history_pq.npy',
         'std_history_pq.npy', 'loss_history_pq.npy']
    ):
        np.save(filename, data)
        print('Saved to {}'.format(filename))

    print('qp...')
    loss_type = 'qp'
    multiplier_history_qp, offset_history_qp, std_history_qp, loss_history_qp = \
        amortize_inference(prior_mean, prior_std, obs_std, num_iterations,
                           num_traces, learning_rate, init_multiplier,
                           init_offset, init_std, loss_type)
    for data, filename in zip(
        [multiplier_history_qp, offset_history_qp, std_history_qp,
         loss_history_qp],
        ['multiplier_history_qp.npy', 'offset_history_qp.npy',
         'std_history_qp.npy', 'loss_history_qp.npy']
    ):
        np.save(filename, data)
        print('Saved to {}'.format(filename))

    # Plot stats
    iterations = np.arange(num_iterations)
    true_multiplier, true_offset, true_std = get_proposal_params(
        prior_mean, prior_std, obs_std)
    fig, axs = plt.subplots(3, 1, sharex=True)
    fig.set_size_inches(3, 3)

    axs[0].plot(iterations, multiplier_history_qp, label='qp multiplier',
                color='darkgray', linestyle='solid')
    axs[0].plot(iterations, offset_history_qp, label='qp offset',
                color='darkgray', linestyle='dashed')
    axs[0].plot(iterations, std_history_qp, label='qp std',
                color='darkgray', linestyle='dotted')
    axs[0].plot(iterations, multiplier_history_pq, label='pq multiplier',
                color='lightgray', linestyle='solid')
    axs[0].plot(iterations, offset_history_pq, label='pq offset',
                color='lightgray', linestyle='dashed')
    axs[0].plot(iterations, std_history_pq, label='pq std',
                color='lightgray', linestyle='dotted')
    # Horizontal lines mark the analytically optimal parameters.
    axs[0].axhline(true_multiplier, color='black', linestyle='solid')
    axs[0].axhline(true_offset, color='black', linestyle='dashed')
    axs[0].axhline(true_std, color='black', linestyle='dotted')
    axs[0].legend(ncol=1, fontsize='small', loc='center left',
                  bbox_to_anchor=(1, 0.5))
    axs[0].set_ylabel(r'$\phi$')
    axs[0].spines['right'].set_visible(False)
    axs[0].spines['top'].set_visible(False)

    axs[1].plot(iterations, loss_history_qp, label='qp loss', color='black')
    axs[1].set_ylabel('$qp$ loss')
    axs[1].spines['right'].set_visible(False)
    axs[1].spines['top'].set_visible(False)

    axs[2].plot(iterations, loss_history_pq, label='pq loss', color='black')
    axs[2].set_ylabel('$pq$ loss')
    axs[2].spines['right'].set_visible(False)
    axs[2].spines['top'].set_visible(False)

    filenames = ['gaussian_unknown_mean_1.pdf', 'gaussian_unknown_mean_1.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))

    # Plot test obs
    learned_multiplier_qp = multiplier_history_qp[-1]
    learned_offset_qp = offset_history_qp[-1]
    learned_std_qp = std_history_qp[-1]
    learned_multiplier_pq = multiplier_history_pq[-1]
    learned_offset_pq = offset_history_pq[-1]
    learned_std_pq = std_history_pq[-1]

    num_test_obs = 5
    test_obs_min = -10
    test_obs_max = 10
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)
    num_xs = 100
    min_xs = -10
    max_xs = 10
    xs = np.linspace(min_xs, max_xs, num_xs)

    fig, axs = plt.subplots(1, num_test_obs)
    fig.set_size_inches(8, 1.5)
    for test_obs_idx, test_obs in enumerate(test_obss):
        posterior_mean, posterior_std = get_posterior_params(
            test_obs, prior_mean, prior_std, obs_std)
        prior_pdf = scipy.stats.norm.pdf(xs, loc=prior_mean, scale=prior_std)
        posterior_pdf = scipy.stats.norm.pdf(
            xs, loc=posterior_mean, scale=posterior_std)
        learned_posterior_pdf_qp = scipy.stats.norm.pdf(
            xs, loc=test_obs * learned_multiplier_qp + learned_offset_qp,
            scale=learned_std_qp)
        learned_posterior_pdf_pq = scipy.stats.norm.pdf(
            xs, loc=test_obs * learned_multiplier_pq + learned_offset_pq,
            scale=learned_std_pq)
        axs[test_obs_idx].scatter(x=test_obs, y=0, color='black',
                                  label='test obs', marker='x')
        axs[test_obs_idx].plot(xs, prior_pdf, label='prior',
                               color='lightgray')
        axs[test_obs_idx].plot(xs, posterior_pdf, label='posterior',
                               color='black', linestyle='solid', alpha=0.8)
        axs[test_obs_idx].plot(xs, learned_posterior_pdf_qp,
                               label='inference network qp', color='black',
                               linestyle='dashed', alpha=0.8)
        axs[test_obs_idx].plot(xs, learned_posterior_pdf_pq,
                               label='inference network pq', color='black',
                               linestyle='dotted', alpha=0.8)
        axs[test_obs_idx].spines['right'].set_visible(False)
        axs[test_obs_idx].spines['top'].set_visible(False)
        axs[test_obs_idx].set_yticks([])
        axs[test_obs_idx].set_xticks([])

    axs[num_test_obs // 2].legend(ncol=5, loc='upper center',
                                  bbox_to_anchor=(0.5, 0), fontsize='small')
    filenames = ['gaussian_unknown_mean_2.pdf', 'gaussian_unknown_mean_2.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()