# NOTE: this script targets the pre-0.4 PyTorch API (Variable, .data[0],
# Normal(mean=..., std=...), sample_n); newer PyTorch versions would need
# loc=/scale= keywords and plain Tensors instead of Variables.
import argparse
import numpy as np
import matplotlib.pyplot as plt
import scipy.special
import scipy.stats
import torch
import torch.nn as nn
import torch.distributions
from torch.autograd import Variable
import torch.optim as optim


def logsumexp(values, dim=0, keepdim=False):
    """Logsumexp of a Tensor/Variable.

    See https://en.wikipedia.org/wiki/LogSumExp.

    input:
        values: Tensor/Variable [dim_1, ..., dim_N]
        dim: n

    output:
        result Tensor/Variable
        [dim_1, ..., dim_{n - 1}, dim_{n + 1}, ..., dim_N] where

        result[i_1, ..., i_{n - 1}, i_{n + 1}, ..., i_N] =
            log(sum_{i_n = 1}^N exp(values[i_1, ..., i_N]))
    """
    values_max, _ = torch.max(values, dim=dim, keepdim=True)
    result = values_max + torch.log(torch.sum(
        torch.exp(values - values_max), dim=dim, keepdim=True
    ))
    return result if keepdim else result.squeeze(dim)


def lognormexp(values, dim=0):
    """Exponentiates, normalizes and takes log of a
    Tensor/Variable/np.ndarray.

    input:
        values: Tensor/Variable/np.ndarray [dim_1, ..., dim_N]
        dim: n

    output:
        result: Tensor/Variable/np.ndarray [dim_1, ..., dim_N] where

        result[i_1, ..., i_N] =
                                exp(values[i_1, ..., i_N])
            log( ---------------------------------------------------------- )
                  sum_{j = 1}^{dim_n} exp(values[i_1, ..., j, ..., i_N])
    """
    if isinstance(values, np.ndarray):
        log_denominator = scipy.special.logsumexp(
            values, axis=dim, keepdims=True
        )
        # log_numerator = values
        return values - log_denominator
    else:
        log_denominator = logsumexp(values, dim=dim, keepdim=True)
        # log_numerator = values
        return values - log_denominator


def exponentiate_and_normalize(values, dim=0):
    """Exponentiates and normalizes a Tensor/Variable/np.ndarray.

    input:
        values: Tensor/Variable/np.ndarray [dim_1, ..., dim_N]
        dim: n

    output:
        result: Tensor/Variable/np.ndarray [dim_1, ..., dim_N] where

        result[i_1, ..., i_N] =
                            exp(values[i_1, ..., i_N])
            ----------------------------------------------------------
              sum_{j = 1}^{dim_n} exp(values[i_1, ..., j, ..., i_N])
    """
    if isinstance(values, np.ndarray):
        return np.exp(lognormexp(values, dim=dim))
    else:
        return torch.exp(lognormexp(values, dim=dim))
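
# A quick, illustrative sanity check for the helpers above (values are
# approximate; exact printing depends on the PyTorch version):
#
#   >>> w = exponentiate_and_normalize(torch.Tensor([[1., 1., 1.]]), dim=1)
#   >>> # w is approximately [[1/3, 1/3, 1/3]]
#
# sample_ancestral_index below implements systematic resampling: one uniform
# draw per batch row places num_particles evenly spaced points on [0, 1),
# which are matched against the cumulative normalized weights. This yields
# lower-variance ancestor counts than independent multinomial resampling.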
def sample_ancestral_index(log_weight, cuda=False):
    """Sample ancestral index using systematic resampling.

    input:
        log_weight: log of unnormalized weights, Tensor/Variable
            [batch_size, num_particles]
    output:
        zero-indexed ancestral index: LongTensor/Variable
            [batch_size, num_particles]
    """
    if (
        isinstance(log_weight, Variable) and
        torch.sum(log_weight != log_weight).data[0] != 0
    ) or (
        (not isinstance(log_weight, Variable)) and
        torch.sum(log_weight != log_weight) != 0
    ):
        raise FloatingPointError('log_weight contains nan element(s)')

    batch_size, num_particles = log_weight.size()
    indices = np.zeros([batch_size, num_particles])

    uniforms = np.random.uniform(size=[batch_size, 1])
    pos = (uniforms + np.arange(0, num_particles)) / num_particles

    if isinstance(log_weight, Variable):
        normalized_weights = exponentiate_and_normalize(
            log_weight.cpu().data.numpy(), dim=1
        )
    else:
        normalized_weights = exponentiate_and_normalize(
            log_weight.cpu().numpy(), dim=1
        )

    # np.ndarray [batch_size, num_particles]
    cumulative_weights = np.cumsum(normalized_weights, axis=1)

    # hack to prevent numerical issues
    cumulative_weights = cumulative_weights / np.max(
        cumulative_weights, axis=1, keepdims=True
    )

    for batch in range(batch_size):
        indices[batch] = np.digitize(pos[batch], cumulative_weights[batch])

    if cuda:
        temp = torch.from_numpy(indices).long().cuda()
    else:
        temp = torch.from_numpy(indices).long()

    if isinstance(log_weight, Variable):
        return Variable(temp)
    else:
        return temp


def generate_traces(num_traces, mixture_probs, means, stds, obs_std):
    """Generate (z, x, obs) traces from the generative model
    z ~ Categorical(mixture_probs), x | z ~ Normal(means[z], stds[z]),
    obs | x ~ Normal(x, obs_std)."""
    normals = torch.distributions.Normal(
        mean=torch.Tensor(means),
        std=torch.Tensor(stds)
    )
    z_tensor = torch.multinomial(
        torch.Tensor([mixture_probs]), num_traces, True
    ).view(-1)
    x_tensor = torch.gather(
        normals.sample_n(num_traces), 1, z_tensor.unsqueeze(-1)
    ).squeeze(-1)
    obs_tensor = torch.distributions.Normal(
        x_tensor, torch.Tensor([obs_std]).expand_as(x_tensor)
    ).sample()

    return z_tensor.numpy(), x_tensor.numpy(), obs_tensor.numpy()


def get_posterior_samples(num_samples, obs, mixture_probs, means, stds,
                          obs_std, num_importance_samples=None):
    """Sample approximately from the posterior p(z, x | obs) by importance
    sampling with the prior as proposal, followed by systematic resampling.

    NOTE: num_importance_samples is currently unused; num_samples proposals
    are drawn and resampled."""
    if num_importance_samples is None:
        num_importance_samples = num_samples

    normals = torch.distributions.Normal(
        mean=torch.Tensor(means),
        std=torch.Tensor(stds)
    )
    z_tensor = torch.multinomial(
        torch.Tensor([mixture_probs]), num_samples, True
    ).view(-1)
    x_tensor = torch.gather(
        normals.sample_n(num_samples), 1, z_tensor.unsqueeze(-1)
    ).squeeze(-1)
    log_weights = torch.distributions.Normal(
        x_tensor, torch.Tensor([obs_std]).expand_as(x_tensor)
    ).log_prob(torch.Tensor([obs]))
    index = sample_ancestral_index(log_weights.unsqueeze(0)).squeeze(0)
    z_tensor_resampled = torch.gather(z_tensor, 0, index)
    x_tensor_resampled = torch.gather(x_tensor, 0, index)

    return z_tensor_resampled.numpy(), x_tensor_resampled.numpy()


def get_posterior_pdf(x_points, num_samples, obs, mixture_probs, means, stds,
                      obs_std, num_importance_samples=None):
    z, x = get_posterior_samples(
        num_samples, obs, mixture_probs, means, stds, obs_std,
        num_importance_samples
    )
    x_kde = scipy.stats.gaussian_kde(x)
    return x_kde.evaluate(x_points), [
        np.sum(z == i) / len(z) for i in range(len(mixture_probs))
    ]


def get_prior_pdf(x_points, mixture_probs, means, stds):
    x_pdf = np.zeros(x_points.shape)
    for mixture_prob, mean, std in zip(mixture_probs, means, stds):
        x_pdf += mixture_prob * scipy.stats.norm.pdf(
            x_points, loc=mean, scale=std
        )
    return x_pdf, np.array(mixture_probs)
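
# InferenceNetworkPQ is trained with the 'pq' loss: minimizing
# E_{p(z, x, obs)}[-log q(z, x | obs)] over traces from the generative
# model, i.e. (up to a constant) the forward KL divergence
# KL(p(z, x | obs) || q(z, x | obs)) averaged over p(obs). This is the
# sleep-phase update familiar from wake-sleep style training.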
class InferenceNetworkPQ(nn.Module):
    def __init__(self, num_mixtures):
        super(InferenceNetworkPQ, self).__init__()
        self.num_mixtures = num_mixtures
        self.obs_to_z_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, num_mixtures),
            nn.Softmax(dim=1)
        )
        self.obs_z_to_x_mean = nn.Sequential(
            nn.Linear(1 + num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_z_to_x_logstd = nn.Sequential(
            nn.Linear(1 + num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )

    def get_z_params(self, obs):
        return self.obs_to_z_params(obs.unsqueeze(-1))

    def get_x_params(self, obs, z):
        z_one_hot = Variable(
            torch.zeros(len(z), self.num_mixtures)
        ).scatter_(1, z.unsqueeze(-1), 1)
        obs_z = torch.cat([obs.unsqueeze(-1), z_one_hot], dim=1)
        mean = self.obs_z_to_x_mean(obs_z)
        std = torch.exp(self.obs_z_to_x_logstd(obs_z))
        return mean, std

    def get_zx_samples(self, obs, num_samples):
        obs_expanded = Variable(torch.Tensor([obs]).expand(num_samples))
        z_prob = self.get_z_params(obs_expanded)
        z = torch.multinomial(z_prob, 1, True).view(-1)
        x_mean, x_std = self.get_x_params(obs_expanded, z)
        x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
        return z.data, x.data

    def get_zx_pdf(self, x_points, obs, num_samples):
        z, x = self.get_zx_samples(obs, num_samples)
        z = z.numpy()
        x = x.numpy()
        x_kde = scipy.stats.gaussian_kde(x)
        return x_kde.evaluate(x_points), [
            np.sum(z == i) / len(z) for i in range(self.num_mixtures)
        ]

    def forward(self, z, x, obs):
        z_prob = self.get_z_params(obs)
        x_mean, x_std = self.get_x_params(obs, z)
        log_q_z = torch.gather(torch.log(z_prob), 1, z.unsqueeze(-1))
        log_q_x = torch.distributions.Normal(x_mean, x_std).log_prob(
            x.unsqueeze(-1)
        )
        return -torch.mean(log_q_z + log_q_x)
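
# InferenceNetworkQP is trained with the 'qp' loss: a stochastic estimate of
# the negative ELBO, E_{q(z, x | obs)}[log q(z, x | obs) - log p(z, x, obs)],
# whose minimization corresponds to the reverse KL
# KL(q(z, x | obs) || p(z, x | obs)). In forward() below, x is sampled with
# the reparameterization trick, while the discrete z needs a score-function
# (REINFORCE) term: adding almost_elbo.detach() * log_q_z to the surrogate
# loss makes its gradient include the likelihood-ratio term for the
# categorical parameters.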
class InferenceNetworkQP(nn.Module):
    def __init__(self, mixture_probs, means, stds, obs_std):
        super(InferenceNetworkQP, self).__init__()
        self.num_mixtures = len(mixture_probs)
        self.mixture_probs = mixture_probs
        self.means = means
        self.stds = stds
        self.obs_std = obs_std
        self.obs_to_z_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, self.num_mixtures),
            nn.Softmax(dim=1)
        )
        self.obs_z_to_x_mean = nn.Sequential(
            nn.Linear(1 + self.num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_z_to_x_logstd = nn.Sequential(
            nn.Linear(1 + self.num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )

    def get_z_params(self, obs):
        return self.obs_to_z_params(obs.unsqueeze(-1))

    def get_x_params(self, obs, z):
        z_one_hot = Variable(
            torch.zeros(len(z), self.num_mixtures)
        ).scatter_(1, z.unsqueeze(-1), 1)
        obs_z = torch.cat([obs.unsqueeze(-1), z_one_hot], dim=1)
        mean = self.obs_z_to_x_mean(obs_z)
        std = torch.exp(self.obs_z_to_x_logstd(obs_z))
        return mean, std

    def get_zx_samples(self, obs, num_samples):
        obs_expanded = Variable(torch.Tensor([obs]).expand(num_samples))
        z_prob = self.get_z_params(obs_expanded)
        z = torch.multinomial(z_prob, 1, True).view(-1)
        x_mean, x_std = self.get_x_params(obs_expanded, z)
        x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
        return z.data, x.data

    def get_zx_pdf(self, x_points, obs, num_samples):
        z, x = self.get_zx_samples(obs, num_samples)
        z = z.numpy()
        x = x.numpy()
        x_kde = scipy.stats.gaussian_kde(x)
        return x_kde.evaluate(x_points), [
            np.sum(z == i) / len(z) for i in range(self.num_mixtures)
        ]

    def forward(self, z_, x_, obs):
        # z_ and x_ are unused: the qp loss samples (z, x) from q itself.
        z_prob = self.get_z_params(obs)
        z = torch.multinomial(z_prob, 1, True).view(-1)
        x_mean, x_std = self.get_x_params(obs, z)
        # Reparameterized sample of x: noise * std + mean.
        x = Variable(torch.distributions.Normal(
            mean=torch.zeros(*obs.size()),
            std=torch.ones(*obs.size())
        ).sample()) * x_std.view(-1) + x_mean.view(-1)

        prior_x_mean = torch.gather(
            Variable(torch.Tensor(self.means).unsqueeze(0).expand(
                len(obs), self.num_mixtures
            )),
            1,
            z.unsqueeze(-1)
        ).view(-1)
        prior_x_std = torch.gather(
            Variable(torch.Tensor(self.stds).unsqueeze(0).expand(
                len(obs), self.num_mixtures
            )),
            1,
            z.unsqueeze(-1)
        ).view(-1)
        log_prior_z = torch.gather(
            Variable(torch.log(torch.Tensor(self.mixture_probs)).unsqueeze(0)
                     .expand(len(obs), self.num_mixtures)),
            1,
            z.unsqueeze(-1)
        ).view(-1)
        log_prior_x = torch.distributions.Normal(
            prior_x_mean, prior_x_std
        ).log_prob(x).view(-1)
        log_lik = torch.distributions.Normal(
            mean=x, std=self.obs_std
        ).log_prob(obs).view(-1)
        log_q_z = torch.log(
            torch.gather(z_prob, 1, z.unsqueeze(-1))
        ).view(-1)
        log_q_x = torch.distributions.Normal(x_mean, x_std).log_prob(
            x.unsqueeze(-1)
        ).view(-1)

        almost_elbo = log_prior_z + log_prior_x + log_lik - log_q_z - log_q_x
        # Surrogate loss: the detached REINFORCE term contributes the
        # score-function gradient for the discrete choice of z.
        return -torch.mean(almost_elbo + almost_elbo.detach() * log_q_z)


def amortize_inference(mixture_probs, means, stds, obs_std, num_iterations,
                       num_traces, learning_rate, loss_type):
    loss_history = np.zeros([num_iterations])
    if loss_type == 'qp':
        inference_network = InferenceNetworkQP(
            mixture_probs, means, stds, obs_std
        )
    else:
        inference_network = InferenceNetworkPQ(len(mixture_probs))
    sgd_optimizer = optim.SGD(inference_network.parameters(),
                              lr=learning_rate)

    for i in range(num_iterations):
        z, x, obs = generate_traces(
            num_traces, mixture_probs, means, stds, obs_std
        )
        z_var = Variable(torch.from_numpy(z).long())
        x_var = Variable(torch.from_numpy(x).float())
        obs_var = Variable(torch.from_numpy(obs).float())

        sgd_optimizer.zero_grad()
        # Both networks share the forward(z, x, obs) signature; the qp
        # network ignores z and x and samples from q internally.
        loss = inference_network(z_var, x_var, obs_var)
        loss.backward()
        sgd_optimizer.step()
        loss_history[i] = loss.data[0]

    return loss_history, inference_network
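
# Illustrative usage (mirrors the defaults set in main() below, with smaller
# iteration counts for a quick run):
#
#   loss_history, network = amortize_inference(
#       mixture_probs=[0.5, 0.5], means=[-5, 5], stds=[1, 1], obs_std=10,
#       num_iterations=1000, num_traces=100, learning_rate=0.01,
#       loss_type='qp')
#   z_samples, x_samples = network.get_zx_samples(obs=2.0, num_samples=100)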
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-pq', action='store_true')
    parser.add_argument('--checkpoint-qp', action='store_true')
    args = parser.parse_args()

    mixture_probs = [0.5, 0.5]
    means = [-5, 5]
    stds = [1, 1]
    obs_std = 10

    num_iterations = 100000
    num_traces = 1000
    learning_rate = 0.01

    inference_network_pq_filename = 'inference_network_pq.pt'
    inference_network_qp_filename = 'inference_network_qp.pt'

    if not args.checkpoint_pq:
        # Amortize inference
        loss_type = 'pq'
        loss_history_pq, inference_network_pq = amortize_inference(
            mixture_probs, means, stds, obs_std, num_iterations, num_traces,
            learning_rate, loss_type
        )
        filename = 'loss_history_pq.npy'
        np.save(filename, loss_history_pq)
        print('Saved to {}'.format(filename))
        torch.save(inference_network_pq.state_dict(),
                   inference_network_pq_filename)
        print('Saved to {}'.format(inference_network_pq_filename))
    else:
        # Load loss history
        filename = 'loss_history_pq.npy'
        loss_history_pq = np.load(filename)
        print('Loaded from {}'.format(filename))

        # Load inference artifact
        inference_network_pq = InferenceNetworkPQ(len(mixture_probs))
        inference_network_pq.load_state_dict(
            torch.load(inference_network_pq_filename)
        )
        print('Loaded from {}'.format(inference_network_pq_filename))

    if not args.checkpoint_qp:
        # Amortize inference
        loss_type = 'qp'
        loss_history_qp, inference_network_qp = amortize_inference(
            mixture_probs, means, stds, obs_std, num_iterations, num_traces,
            learning_rate, loss_type
        )
        filename = 'loss_history_qp.npy'
        np.save(filename, loss_history_qp)
        print('Saved to {}'.format(filename))
        torch.save(inference_network_qp.state_dict(),
                   inference_network_qp_filename)
        print('Saved to {}'.format(inference_network_qp_filename))
    else:
        # Load loss history
        filename = 'loss_history_qp.npy'
        loss_history_qp = np.load(filename)
        print('Loaded from {}'.format(filename))

        # Load inference artifact
        inference_network_qp = InferenceNetworkQP(
            mixture_probs, means, stds, obs_std
        )
        inference_network_qp.load_state_dict(
            torch.load(inference_network_qp_filename)
        )
        print('Loaded from {}'.format(inference_network_qp_filename))

    # Plot stats
    iterations = np.arange(num_iterations)
    fig, axs = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(3, 2.5)
    for ax in axs:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
    axs[0].plot(iterations, loss_history_pq, label='pq loss', color='black')
    axs[0].set_ylabel('$pq$ loss')
    axs[1].plot(iterations, loss_history_qp, label='qp loss', color='black')
    axs[1].set_ylabel('$qp$ loss')
    filenames = ['gaussian_mixture_nonmarginalized_1.pdf',
                 'gaussian_mixture_nonmarginalized_1.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))

    # Plot test obs
    num_test_obs = 5
    test_obs_min = -20
    test_obs_max = 20
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)
    num_posterior_samples = 10000
    num_inference_network_samples = 10000
    num_points = 100
    min_point = min([-10, test_obs_min])
    max_point = max([10, test_obs_max])
    bar_width = 0.1
    num_barplots = 4

    x_points = np.linspace(min_point, max_point, num_points)
    z_points = np.arange(len(mixture_probs))

    fig, axs = plt.subplots(2, num_test_obs, sharey='row')
    fig.set_size_inches(8, 2.5)

    for test_obs_idx, test_obs in enumerate(test_obss):
        x_pdf, z_pdf = get_posterior_pdf(
            x_points, num_posterior_samples, test_obs, mixture_probs, means,
            stds, obs_std
        )
        x_pq_pdf, z_pq_pdf = inference_network_pq.get_zx_pdf(
            x_points, test_obs, num_inference_network_samples
        )
        x_qp_pdf, z_qp_pdf = inference_network_qp.get_zx_pdf(
            x_points, test_obs, num_inference_network_samples
        )
        x_prior_pdf, z_prior_pdf = get_prior_pdf(
            x_points, mixture_probs, means, stds
        )

        i = 0
        axs[0][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_prior_pdf, width=bar_width, color='lightgray',
            edgecolor='lightgray', fill=True, label='prior'
        )
        i = 1
        axs[0][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_pdf, width=bar_width, color='black', fill=True,
            label='posterior'
        )
        i = 2
        axs[0][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_qp_pdf, width=bar_width, color='black', fill=False,
            linestyle='dashed', label='inference network qp'
        )
        i = 3
        axs[0][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_pq_pdf, width=bar_width, color='black', fill=False,
            linestyle='dotted', label='inference network pq'
        )
        axs[0][test_obs_idx].set_xticks(z_points)
        axs[0][test_obs_idx].set_yticks([])
        axs[0][test_obs_idx].set_ylim(0, 1)
        axs[0][test_obs_idx].xaxis.set_ticks_position('none')
        axs[0][test_obs_idx].spines['top'].set_visible(False)
        axs[0][test_obs_idx].spines['right'].set_visible(False)

        axs[1][test_obs_idx].plot(x_points, x_prior_pdf, color='lightgray',
                                  label='prior')
        axs[1][test_obs_idx].plot(x_points, x_pdf, color='black',
                                  label='posterior')
        axs[1][test_obs_idx].plot(x_points, x_qp_pdf, color='black',
                                  linestyle='dashed',
                                  label='inference network qp')
        axs[1][test_obs_idx].plot(x_points, x_pq_pdf, color='black',
                                  linestyle='dotted',
                                  label='inference network pq')
        axs[1][test_obs_idx].scatter(x=test_obs, y=0, color='black',
                                     label='test obs', marker='x')
        axs[1][test_obs_idx].set_yticks([])
        axs[1][test_obs_idx].spines['top'].set_visible(False)
        axs[1][test_obs_idx].spines['right'].set_visible(False)

    axs[0][0].set_ylabel('$z$')
    axs[0][0].set_yticks([0, 1])
    axs[1][0].set_ylabel('$x$')
    # Draw the shared legend once, under the middle column.
    axs[1][num_test_obs // 2].legend(
        loc='upper center', bbox_to_anchor=(0.5, -0.3), ncol=5,
        fontsize='small'
    )
    fig.tight_layout()
    filenames = ['gaussian_mixture_nonmarginalized_2.pdf',
                 'gaussian_mixture_nonmarginalized_2.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))
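
    # The next figure traces the learned mappings themselves: for a sweep of
    # observations well outside the bulk of the training data, it records
    # q(z = 0 | obs) and the mean/std of q(x | z, obs) under both networks,
    # showing how each amortized posterior extrapolates.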
    # Plot inference network mappings
    num_test_obs = 100
    test_obs_min = -500
    test_obs_max = 500
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)

    z_prob_pq = np.zeros([num_test_obs])
    z_prob_qp = np.zeros([num_test_obs])
    x_mean_given_z1_qp = np.zeros([num_test_obs])
    x_std_given_z1_qp = np.zeros([num_test_obs])
    x_mean_given_z2_qp = np.zeros([num_test_obs])
    x_std_given_z2_qp = np.zeros([num_test_obs])
    x_mean_given_z1_pq = np.zeros([num_test_obs])
    x_std_given_z1_pq = np.zeros([num_test_obs])
    x_mean_given_z2_pq = np.zeros([num_test_obs])
    x_std_given_z2_pq = np.zeros([num_test_obs])

    for test_obs_idx, test_obs in enumerate(test_obss):
        z_prob_qp[test_obs_idx] = inference_network_qp.get_z_params(
            Variable(torch.Tensor([test_obs]))
        ).data[0, 0]
        x_mean_given_z1_qp_var, x_std_given_z1_qp_var = \
            inference_network_qp.get_x_params(
                Variable(torch.Tensor([test_obs])),
                Variable(torch.LongTensor([0]))
            )
        x_mean_given_z2_qp_var, x_std_given_z2_qp_var = \
            inference_network_qp.get_x_params(
                Variable(torch.Tensor([test_obs])),
                Variable(torch.LongTensor([1]))
            )
        x_mean_given_z1_qp[test_obs_idx] = x_mean_given_z1_qp_var.data[0]
        x_std_given_z1_qp[test_obs_idx] = x_std_given_z1_qp_var.data[0]
        x_mean_given_z2_qp[test_obs_idx] = x_mean_given_z2_qp_var.data[0]
        x_std_given_z2_qp[test_obs_idx] = x_std_given_z2_qp_var.data[0]

        z_prob_pq[test_obs_idx] = inference_network_pq.get_z_params(
            Variable(torch.Tensor([test_obs]))
        ).data[0, 0]
        x_mean_given_z1_pq_var, x_std_given_z1_pq_var = \
            inference_network_pq.get_x_params(
                Variable(torch.Tensor([test_obs])),
                Variable(torch.LongTensor([0]))
            )
        x_mean_given_z2_pq_var, x_std_given_z2_pq_var = \
            inference_network_pq.get_x_params(
                Variable(torch.Tensor([test_obs])),
                Variable(torch.LongTensor([1]))
            )
        x_mean_given_z1_pq[test_obs_idx] = x_mean_given_z1_pq_var.data[0]
        x_std_given_z1_pq[test_obs_idx] = x_std_given_z1_pq_var.data[0]
        x_mean_given_z2_pq[test_obs_idx] = x_mean_given_z2_pq_var.data[0]
        x_std_given_z2_pq[test_obs_idx] = x_std_given_z2_pq_var.data[0]
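
    # Five stacked panels: q(z = 0 | obs), then mean and std of
    # q(x | z = 0, obs), then mean and std of q(x | z = 1, obs), each
    # comparing the qp (dashed) and pq (dotted) networks.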
    fig, axs = plt.subplots(5, 1, sharex=True)
    fig.set_size_inches(3, 5.5)
    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    axs[0].plot(test_obss, z_prob_qp, color='black', linestyle='dashed',
                label='qp')
    axs[0].plot(test_obss, z_prob_pq, color='black', linestyle='dotted',
                label='pq')
    axs[0].set_ylabel('$q(z = 0)$', fontsize='small')

    axs[1].plot(test_obss, x_mean_given_z1_qp, color='black',
                linestyle='dashed', label='qp')
    axs[1].plot(test_obss, x_mean_given_z1_pq, color='black',
                linestyle='dotted', label='pq')
    axs[1].set_ylabel('mean of\n$q(x | z = 0)$', fontsize='small')

    axs[2].plot(test_obss, x_std_given_z1_qp, color='black',
                linestyle='dashed', label='qp')
    axs[2].plot(test_obss, x_std_given_z1_pq, color='black',
                linestyle='dotted', label='pq')
    axs[2].set_ylabel('std of\n$q(x | z = 0)$', fontsize='small')

    axs[3].plot(test_obss, x_mean_given_z2_qp, color='black',
                linestyle='dashed', label='qp')
    axs[3].plot(test_obss, x_mean_given_z2_pq, color='black',
                linestyle='dotted', label='pq')
    axs[3].set_ylabel('mean of\n$q(x | z = 1)$', fontsize='small')

    axs[4].plot(test_obss, x_std_given_z2_qp, color='black',
                linestyle='dashed', label='qp')
    axs[4].plot(test_obss, x_std_given_z2_pq, color='black',
                linestyle='dotted', label='pq')
    axs[4].set_ylabel('std of\n$q(x | z = 1)$', fontsize='small')

    axs[-1].set_xlabel('test obs')
    axs[-1].legend(fontsize='small', loc='upper center',
                   bbox_to_anchor=(0.5, -0.55), ncol=2)

    filenames = ['gaussian_mixture_nonmarginalized_3.pdf',
                 'gaussian_mixture_nonmarginalized_3.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()