import argparse
import numpy as np
import matplotlib.pyplot as plt
import scipy.special
import scipy.stats
import torch
import torch.nn as nn
import torch.distributions
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from numbers import Number


def logsumexp(values, dim=0, keepdim=False):
    """Logsumexp of a Tensor/Variable.

    See https://en.wikipedia.org/wiki/LogSumExp.

    input:
        values: Tensor/Variable [dim_1, ..., dim_N]
        dim: n

    output:
        result: Tensor/Variable
            [dim_1, ..., dim_{n - 1}, dim_{n + 1}, ..., dim_N] where

        result[i_1, ..., i_{n - 1}, i_{n + 1}, ..., i_N] =
            log(sum_{i_n = 1}^{dim_n} exp(values[i_1, ..., i_N]))
    """
    values_max, _ = torch.max(values, dim=dim, keepdim=True)
    result = values_max + torch.log(torch.sum(
        torch.exp(values - values_max), dim=dim, keepdim=True
    ))
    return result if keepdim else result.squeeze(dim)


def lognormexp(values, dim=0):
    """Exponentiates, normalizes and takes log of a Tensor/Variable/np.ndarray.

    input:
        values: Tensor/Variable/np.ndarray [dim_1, ..., dim_N]
        dim: n

    output:
        result: Tensor/Variable/np.ndarray [dim_1, ..., dim_N] where

        result[i_1, ..., i_N] =
                                 exp(values[i_1, ..., i_N])
            log( ------------------------------------------------------------ )
                  sum_{j = 1}^{dim_n} exp(values[i_1, ..., j, ..., i_N])
    """
    if isinstance(values, np.ndarray):
        log_denominator = scipy.special.logsumexp(
            values, axis=dim, keepdims=True
        )
        # log_numerator = values
        return values - log_denominator
    else:
        log_denominator = logsumexp(values, dim=dim, keepdim=True)
        # log_numerator = values
        return values - log_denominator


def exponentiate_and_normalize(values, dim=0):
    """Exponentiates and normalizes a Tensor/Variable/np.ndarray.

    input:
        values: Tensor/Variable/np.ndarray [dim_1, ..., dim_N]
        dim: n

    output:
        result: Tensor/Variable/np.ndarray [dim_1, ..., dim_N] where

        result[i_1, ..., i_N] =
                            exp(values[i_1, ..., i_N])
            ------------------------------------------------------------
             sum_{j = 1}^{dim_n} exp(values[i_1, ..., j, ..., i_N])
    """
    if isinstance(values, np.ndarray):
        return np.exp(lognormexp(values, dim=dim))
    else:
        return torch.exp(lognormexp(values, dim=dim))
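
# Illustrative sanity check (added for exposition; not part of the original
# experiment): after exponentiate_and_normalize, every slice along `dim`
# should sum to one. The helper name below is hypothetical and the function
# is never called by the script.
def _example_exponentiate_and_normalize():
    values = torch.Tensor([[1.0, 2.0, 3.0],
                           [0.0, 0.0, 0.0]])
    normalized = exponentiate_and_normalize(values, dim=1)
    # Each row should sum to (approximately) 1.
    assert np.allclose(torch.sum(normalized, dim=1).numpy(), 1.0, atol=1e-6)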

def sample_ancestral_index(log_weight, cuda=False):
    """Sample ancestral index using systematic resampling.

    input:
        log_weight: log of unnormalized weights, Tensor/Variable
            [batch_size, num_particles]
    output:
        zero-indexed ancestral index: LongTensor/Variable
            [batch_size, num_particles]
    """
    if (
        isinstance(log_weight, Variable) and
        torch.sum(log_weight != log_weight).data[0] != 0
    ) or (
        (not isinstance(log_weight, Variable)) and
        torch.sum(log_weight != log_weight) != 0
    ):
        raise FloatingPointError('log_weight contains nan element(s)')

    batch_size, num_particles = log_weight.size()
    indices = np.zeros([batch_size, num_particles])

    uniforms = np.random.uniform(size=[batch_size, 1])
    pos = (uniforms + np.arange(0, num_particles)) / num_particles

    if isinstance(log_weight, Variable):
        normalized_weights = exponentiate_and_normalize(
            log_weight.cpu().data.numpy(), dim=1
        )
    else:
        normalized_weights = exponentiate_and_normalize(
            log_weight.cpu().numpy(), dim=1
        )

    # np.ndarray [batch_size, num_particles]
    cumulative_weights = np.cumsum(normalized_weights, axis=1)

    # hack to prevent numerical issues
    cumulative_weights = cumulative_weights / np.max(
        cumulative_weights, axis=1, keepdims=True
    )

    for batch in range(batch_size):
        indices[batch] = np.digitize(pos[batch], cumulative_weights[batch])

    if cuda:
        temp = torch.from_numpy(indices).long().cuda()
    else:
        temp = torch.from_numpy(indices).long()

    if isinstance(log_weight, Variable):
        return Variable(temp)
    else:
        return temp


def generate_traces(num_traces, num_clusters_probs, mean_1, std_1,
                    mixture_probs, means_2, stds_2, obs_std,
                    generate_obs=True):
    traces = []
    for trace_idx in range(num_traces):
        trace = []
        # number of clusters k (renamed from num_traces to avoid shadowing
        # the function argument)
        num_clusters = np.random.choice(
            np.array([1, 2], dtype=float), replace=True, p=num_clusters_probs
        )
        trace.append(num_clusters)
        if num_clusters == 1:
            x = np.random.normal(mean_1, std_1)
            trace.append(x)
        else:
            z = np.random.choice(
                np.arange(len(mixture_probs), dtype=float),
                replace=True, p=mixture_probs
            )
            trace.append(z)
            x = np.random.normal(means_2[int(z)], stds_2[int(z)])
            trace.append(x)
        if generate_obs:
            y = np.random.normal(x, obs_std)
            trace.append(y)
        traces.append(trace)
    return traces


def get_pdf_from_traces(traces_without_obs, k_points, z_points, x_points):
    ks = []
    zs = []
    xs = []
    for trace in traces_without_obs:
        if len(trace) == 2:
            ks.append(trace[0])
            xs.append(trace[1])
        else:
            ks.append(trace[0])
            zs.append(trace[1])
            xs.append(trace[2])

    return [np.sum(np.array(ks) == i) / len(ks) for i in k_points], \
        [np.sum(np.array(zs) == i) / len(zs) for i in z_points], \
        scipy.stats.gaussian_kde(xs).evaluate(x_points)


def get_prior_pdf(x_points, num_samples, num_clusters_probs, mean_1, std_1,
                  mixture_probs, means_2, stds_2, obs_std):
    traces = generate_traces(num_samples, num_clusters_probs, mean_1, std_1,
                             mixture_probs, means_2, stds_2, obs_std, False)
    return get_pdf_from_traces(traces,
                               range(1, 1 + len(num_clusters_probs)),
                               range(len(mixture_probs)),
                               x_points)


def get_posterior_pdf(x_points, num_samples, obs, num_clusters_probs, mean_1,
                      std_1, mixture_probs, means_2, stds_2, obs_std,
                      num_importance_samples=None):
    if num_importance_samples is None:
        num_importance_samples = num_samples

    traces = generate_traces(num_importance_samples, num_clusters_probs,
                             mean_1, std_1, mixture_probs, means_2, stds_2,
                             obs_std, False)
    log_weights = np.zeros([num_importance_samples])
    for trace_idx, trace in enumerate(traces):
        log_weights[trace_idx] = scipy.stats.norm.logpdf(
            obs, trace[-1], obs_std
        )
    weights = np.exp(lognormexp(np.array(log_weights)))
    # Resample trace indices; traces have unequal lengths, so sample indices
    # instead of passing the ragged list to np.random.choice directly.
    resampled_indices = np.random.choice(
        num_importance_samples, size=num_samples, replace=True, p=weights
    )
    resampled_traces = [traces[i] for i in resampled_indices]
    return get_pdf_from_traces(resampled_traces,
                               range(1, 1 + len(num_clusters_probs)),
                               range(len(mixture_probs)),
                               x_points)
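
# Illustrative usage (added for exposition; not part of the original script):
# draw a handful of traces from the generative model. The argument values
# mirror the defaults used in main() but are otherwise hypothetical, and the
# helper is never called.
def _example_generate_traces():
    traces = generate_traces(
        num_traces=5, num_clusters_probs=[0.5, 0.5], mean_1=0, std_1=1,
        mixture_probs=[0.5, 0.5], means_2=[-5, 5], stds_2=[1, 1], obs_std=1
    )
    # Each trace is [k, x, y] (one cluster) or [k, z, x, y] (two clusters),
    # where y is the observation generated from x.
    for trace in traces:
        print(trace)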

class InferenceNetworkPQ(nn.Module):
    def __init__(self, num_clusters_max, num_mixtures):
        super(InferenceNetworkPQ, self).__init__()
        self.num_clusters_max = num_clusters_max
        self.num_mixtures = num_mixtures
        self.obs_to_k_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, num_clusters_max),
            nn.Softmax(dim=1)
        )
        self.obs_k_to_x_mean = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_to_x_logstd = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_to_z_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, num_mixtures),
            nn.Softmax(dim=1)
        )
        self.obs_k_z_to_x_mean = nn.Sequential(
            nn.Linear(1 + num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_z_to_x_logstd = nn.Sequential(
            nn.Linear(1 + num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )

    def get_k_params(self, obs):
        prob = self.obs_to_k_params(obs.unsqueeze(-1))
        return prob

    def get_x_params_from_obs_k(self, obs, k):
        mean = self.obs_k_to_x_mean(obs.unsqueeze(-1))
        std = torch.exp(self.obs_k_to_x_logstd(obs.unsqueeze(-1)))
        return mean, std

    def get_z_params_from_obs_k(self, obs, k):
        prob = self.obs_k_to_z_params(obs.unsqueeze(-1))
        return prob

    def get_x_params_from_obs_k_z(self, obs, k, z):
        z_one_hot = Variable(torch.zeros(len(z), self.num_mixtures)).scatter_(
            1, z.long().unsqueeze(-1), 1
        )
        obs_z = torch.cat([obs.unsqueeze(-1), z_one_hot], dim=1)
        mean = self.obs_k_z_to_x_mean(obs_z)
        std = torch.exp(self.obs_k_z_to_x_logstd(obs_z))
        return mean, std

    def sort_traces(self, traces):
        short_traces = []
        long_traces = []
        for trace in traces:
            if len(trace) == 3:
                short_traces.append(trace)
            else:
                long_traces.append(trace)

        short_traces_tensor = torch.Tensor(short_traces)
        long_traces_tensor = torch.Tensor(long_traces)
        k_short, x_short, obs_short = short_traces_tensor.t()
        k_long, z_long, x_long, obs_long = long_traces_tensor.t()

        return Variable(k_short), Variable(x_short), Variable(obs_short), \
            Variable(k_long), Variable(z_long), Variable(x_long), \
            Variable(obs_long)

    def get_traces(self, obs, num_samples):
        obs_var = Variable(torch.Tensor([obs]))
        traces = []
        for sample_idx in range(num_samples):
            trace = []
            k_prob = self.get_k_params(obs_var)
            k = torch.multinomial(k_prob, 1, True).view(-1) + 1
            trace.append(k.data[0])
            if k.data[0] == 1:
                x_mean, x_std = self.get_x_params_from_obs_k(obs_var, k)
                x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
                trace.append(x.data[0])
            else:
                z_prob = self.get_z_params_from_obs_k(obs_var, k)
                z = torch.multinomial(z_prob, 1, True).view(-1)
                trace.append(z.data[0])
                x_mean, x_std = self.get_x_params_from_obs_k_z(obs_var, k, z)
                x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
                trace.append(x.data[0])
            traces.append(trace)
        return traces

    def get_pdf(self, x_points, obs, num_samples):
        traces = self.get_traces(obs, num_samples)
        return get_pdf_from_traces(traces,
                                   range(1, 1 + self.num_clusters_max),
                                   range(self.num_mixtures),
                                   x_points)

    def forward(self, traces):
        k_short, x_short, obs_short, k_long, z_long, x_long, obs_long = \
            self.sort_traces(traces)

        k_short_prob = self.get_k_params(obs_short)
        x_short_mean, x_short_std = self.get_x_params_from_obs_k(
            obs_short, k_short
        )
        log_q_k_short = torch.gather(
            torch.log(k_short_prob), 1, k_short.long().unsqueeze(-1) - 1
        )
        log_q_x_short = torch.distributions.Normal(
            x_short_mean, x_short_std
        ).log_prob(x_short.unsqueeze(-1))

        k_long_prob = self.get_k_params(obs_long)
        z_long_prob = self.get_z_params_from_obs_k(obs_long, k_long)
        x_long_mean, x_long_std = self.get_x_params_from_obs_k_z(
            obs_long, k_long, z_long
        )
        log_q_k_long = torch.gather(
            torch.log(k_long_prob), 1, k_long.long().unsqueeze(-1) - 1
        )
        log_q_z_long = torch.gather(
            torch.log(z_long_prob), 1, z_long.long().unsqueeze(-1)
        )
        log_q_x_long = torch.distributions.Normal(
            x_long_mean, x_long_std
        ).log_prob(x_long.unsqueeze(-1))

        return -(
            torch.sum(log_q_k_short + log_q_x_short) +
            torch.sum(log_q_k_long + log_q_z_long + log_q_x_long)
        ) / len(traces)
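
# Illustrative training step for the pq objective (added for exposition; not
# part of the original script). The pq loss above is the mean negative log
# density of prior-generated traces under q, so a step looks like ordinary
# maximum-likelihood training on samples from the model. Values below are
# hypothetical; the sketch assumes the batch contains both short and long
# traces, and the helper is never called.
def _example_pq_step():
    network = InferenceNetworkPQ(num_clusters_max=2, num_mixtures=2)
    optimizer = optim.Adam(network.parameters(), lr=1e-3)
    traces = generate_traces(100, [0.5, 0.5], 0, 1, [0.5, 0.5], [-5, 5], [1, 1], 1)
    optimizer.zero_grad()
    loss = network(traces)
    loss.backward()
    optimizer.step()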

class InferenceNetworkQP(nn.Module):
    def __init__(self, num_clusters_probs, mean_1, std_1, mixture_probs,
                 means_2, stds_2, obs_std):
        super(InferenceNetworkQP, self).__init__()
        self.num_clusters_max = len(num_clusters_probs)
        self.num_mixtures = len(mixture_probs)
        self.num_clusters_probs = num_clusters_probs
        self.mean_1 = mean_1
        self.std_1 = std_1
        self.mixture_probs = mixture_probs
        self.means_2 = means_2
        self.stds_2 = stds_2
        self.obs_std = obs_std
        self.obs_to_k_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, self.num_clusters_max),
            nn.Softmax(dim=1)
        )
        self.obs_k_to_x_mean = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_to_x_logstd = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_to_z_params = nn.Sequential(
            nn.Linear(1, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, self.num_mixtures),
            nn.Softmax(dim=1)
        )
        self.obs_k_z_to_x_mean = nn.Sequential(
            nn.Linear(1 + self.num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )
        self.obs_k_z_to_x_logstd = nn.Sequential(
            nn.Linear(1 + self.num_mixtures, 8),
            nn.Tanh(),
            nn.Linear(8, 8),
            nn.Tanh(),
            nn.Linear(8, 1)
        )

    def get_k_params(self, obs):
        prob = self.obs_to_k_params(obs.unsqueeze(-1))
        return prob

    def get_x_params_from_obs_k(self, obs, k):
        mean = self.obs_k_to_x_mean(obs.unsqueeze(-1))
        std = torch.exp(self.obs_k_to_x_logstd(obs.unsqueeze(-1)))
        return mean, std

    def get_z_params_from_obs_k(self, obs, k):
        prob = self.obs_k_to_z_params(obs.unsqueeze(-1))
        return prob

    def get_x_params_from_obs_k_z(self, obs, k, z):
        z_one_hot = Variable(torch.zeros(len(z), self.num_mixtures)).scatter_(
            1, z.long().unsqueeze(-1), 1
        )
        obs_z = torch.cat([obs.unsqueeze(-1), z_one_hot], dim=1)
        mean = self.obs_k_z_to_x_mean(obs_z)
        std = torch.exp(self.obs_k_z_to_x_logstd(obs_z))
        return mean, std

    def sort_traces(self, traces):
        short_traces = []
        long_traces = []
        for trace in traces:
            if len(trace) == 3:
                short_traces.append(trace)
            else:
                long_traces.append(trace)

        short_traces_tensor = torch.Tensor(short_traces)
        long_traces_tensor = torch.Tensor(long_traces)
        k_short, x_short, obs_short = short_traces_tensor.t()
        k_long, z_long, x_long, obs_long = long_traces_tensor.t()

        return Variable(k_short), Variable(x_short), Variable(obs_short), \
            Variable(k_long), Variable(z_long), Variable(x_long), \
            Variable(obs_long)

    def get_traces(self, obs, num_samples):
        obs_var = Variable(torch.Tensor([obs]))
        traces = []
        for sample_idx in range(num_samples):
            trace = []
            k_prob = self.get_k_params(obs_var)
            k = torch.multinomial(k_prob, 1, True).view(-1) + 1
            trace.append(k.data[0])
            if k.data[0] == 1:
                x_mean, x_std = self.get_x_params_from_obs_k(obs_var, k)
                x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
                trace.append(x.data[0])
            else:
                z_prob = self.get_z_params_from_obs_k(obs_var, k)
                z = torch.multinomial(z_prob, 1, True).view(-1)
                trace.append(z.data[0])
                x_mean, x_std = self.get_x_params_from_obs_k_z(obs_var, k, z)
                x = torch.distributions.Normal(x_mean, x_std).sample().view(-1)
                trace.append(x.data[0])
            traces.append(trace)
        return traces

    def get_pdf(self, x_points, obs, num_samples):
        traces = self.get_traces(obs, num_samples)
        return get_pdf_from_traces(traces,
                                   range(1, 1 + self.num_clusters_max),
                                   range(self.num_mixtures),
                                   x_points)

    def forward(self, obs):
        k_prob = self.get_k_params(obs)
        k = torch.multinomial(k_prob, 1, True).view(-1) + 1
        obs_long, obs_short, k_long, k_short = \
            obs[k == 2], obs[k == 1], k[k == 2], k[k == 1]
        log_q_k = torch.log(
            torch.gather(k_prob, 1, k.unsqueeze(-1) - 1)
        ).view(-1)
        loss = Variable(torch.Tensor([0]))

        if len(obs_long) > 0:
            # long traces
            z_long_prob = self.get_z_params_from_obs_k(obs_long, k_long)
            z_long = torch.multinomial(z_long_prob, 1, True).view(-1)
            x_long_mean, x_long_std = self.get_x_params_from_obs_k_z(
                obs_long, k_long, z_long
            )
            x_long = Variable(torch.distributions.Normal(
                torch.zeros(*obs_long.size()), torch.ones(*obs_long.size())
            ).sample()) * x_long_std.view(-1) + x_long_mean.view(-1)

            prior_x_long_mean = torch.gather(
                Variable(torch.Tensor(self.means_2).unsqueeze(0).expand(
                    len(obs_long), self.num_mixtures
                )),
                1, z_long.unsqueeze(-1)
            ).view(-1)
            prior_x_long_std = torch.gather(
                Variable(torch.Tensor(self.stds_2).unsqueeze(0).expand(
                    len(obs_long), self.num_mixtures
                )),
                1, z_long.unsqueeze(-1)
            ).view(-1)
            log_prior_k_long = torch.gather(
                Variable(torch.log(torch.Tensor(self.num_clusters_probs))
                         .unsqueeze(0)
                         .expand(len(obs_long), self.num_clusters_max)),
                1, k_long.unsqueeze(-1) - 1
            ).view(-1)
            log_prior_z_long = torch.gather(
                Variable(torch.log(torch.Tensor(self.mixture_probs))
                         .unsqueeze(0)
                         .expand(len(obs_long), self.num_mixtures)),
                1, z_long.unsqueeze(-1)
            ).view(-1)
            log_prior_x_long = torch.distributions.Normal(
                prior_x_long_mean, prior_x_long_std
            ).log_prob(x_long).view(-1)
            log_lik_long = torch.distributions.Normal(
                x_long, self.obs_std
            ).log_prob(obs_long).view(-1)

            log_q_k_long = log_q_k[k == 2]
            log_q_z_long = torch.log(
                torch.gather(z_long_prob, 1, z_long.unsqueeze(-1))
            ).view(-1)
            log_q_x_long = torch.distributions.Normal(
                x_long_mean, x_long_std
            ).log_prob(x_long.unsqueeze(-1)).view(-1)

            long_elbo = log_prior_k_long + log_prior_z_long + \
                log_prior_x_long + log_lik_long - \
                log_q_k_long - log_q_z_long - log_q_x_long
            loss = loss - torch.sum(
                long_elbo + long_elbo.detach() * (log_q_k_long + log_q_z_long)
            ) / len(obs)

        if len(obs_short) > 0:
            # short traces
            x_short_mean, x_short_std = self.get_x_params_from_obs_k(
                obs_short, k_short
            )
            x_short = Variable(torch.distributions.Normal(
                torch.zeros(*obs_short.size()), torch.ones(*obs_short.size())
            ).sample()) * x_short_std.view(-1) + x_short_mean.view(-1)

            log_prior_k_short = torch.gather(
                Variable(torch.log(torch.Tensor(self.num_clusters_probs))
                         .unsqueeze(0)
                         .expand(len(obs_short), self.num_clusters_max)),
                1, k_short.unsqueeze(-1) - 1
            ).view(-1)
            log_prior_x_short = torch.distributions.Normal(
                self.mean_1, self.std_1
            ).log_prob(x_short).view(-1)
            log_lik_short = torch.distributions.Normal(
                x_short, self.obs_std
            ).log_prob(obs_short).view(-1)

            log_q_k_short = log_q_k[k == 1]
            log_q_x_short = torch.distributions.Normal(
                x_short_mean, x_short_std
            ).log_prob(x_short.unsqueeze(-1)).view(-1)

            short_elbo = log_prior_k_short + log_prior_x_short + \
                log_lik_short - log_q_k_short - log_q_x_short
            loss = loss - torch.sum(
                short_elbo + short_elbo.detach() * log_q_k_short
            ) / len(obs)

        return loss
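
# Illustrative training step for the qp objective (added for exposition; not
# part of the original script). The qp loss is a single-sample stochastic
# estimate of the negative ELBO, with score-function (REINFORCE) surrogate
# terms for the discrete choices of k and z, computed from observations
# alone. Values below are hypothetical and the helper is never called.
def _example_qp_step():
    network = InferenceNetworkQP([0.5, 0.5], 0, 1, [0.5, 0.5], [-5, 5], [1, 1], 1)
    optimizer = optim.Adam(network.parameters(), lr=1e-3)
    traces = generate_traces(100, [0.5, 0.5], 0, 1, [0.5, 0.5], [-5, 5], [1, 1], 1)
    obs = Variable(torch.Tensor([trace[-1] for trace in traces]))
    optimizer.zero_grad()
    loss = network(obs)
    loss.backward()
    optimizer.step()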

def amortize_inference(num_clusters_probs, mean_1, std_1, mixture_probs,
                       means_2, stds_2, obs_std, num_iterations, num_traces,
                       learning_rate, loss_type):
    loss_history = np.zeros([num_iterations])
    if loss_type == 'qp':
        inference_network = InferenceNetworkQP(
            num_clusters_probs, mean_1, std_1, mixture_probs, means_2, stds_2,
            obs_std
        )
    else:
        inference_network = InferenceNetworkPQ(
            len(num_clusters_probs), len(mixture_probs)
        )
    # optimizer = optim.SGD(inference_network.parameters(), lr=learning_rate)
    optimizer = optim.Adam(inference_network.parameters(), lr=learning_rate)

    for i in range(num_iterations):
        traces = generate_traces(num_traces, num_clusters_probs, mean_1,
                                 std_1, mixture_probs, means_2, stds_2,
                                 obs_std)
        optimizer.zero_grad()
        if loss_type == 'qp':
            obs = Variable(torch.Tensor([trace[-1] for trace in traces]))
            loss = inference_network(obs)
        else:
            loss = inference_network(traces)
        loss.backward()
        optimizer.step()
        loss_history[i] = loss.data[0]
        if i % 100 == 0:
            print('Iteration {}: Loss = {}'.format(i, loss_history[i]))

    return loss_history, inference_network
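
# Illustrative call (added for exposition; not part of the original script):
# a short pq run with small, hypothetical settings; main() below uses 10000
# iterations and 1000 traces per iteration. The helper is never called.
def _example_amortize_inference():
    loss_history, network = amortize_inference(
        [0.5, 0.5], 0, 1, [0.5, 0.5], [-5, 5], [1, 1], 1,
        num_iterations=10, num_traces=100, learning_rate=0.001,
        loss_type='pq'
    )
    print('Final loss: {}'.format(loss_history[-1]))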

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-qp', action='store_true')
    parser.add_argument('--checkpoint-pq', action='store_true')
    args = parser.parse_args()

    num_clusters_probs = [0.5, 0.5]
    mean_1 = 0
    std_1 = 1
    mixture_probs = [0.5, 0.5]
    means_2 = [-5, 5]
    stds_2 = [1, 1]
    obs_std = 1

    num_iterations = 10000
    num_traces = 1000
    learning_rate = 0.001
    inference_network_pq_filename = 'inference_network_pq.pt'
    inference_network_qp_filename = 'inference_network_qp.pt'

    if not args.checkpoint_pq:
        # Amortize inference
        loss_type = 'pq'
        loss_history_pq, inference_network_pq = amortize_inference(
            num_clusters_probs, mean_1, std_1, mixture_probs, means_2, stds_2,
            obs_std, num_iterations, num_traces, learning_rate, loss_type
        )
        for [data, filename] in zip([loss_history_pq],
                                    ['loss_history_pq.npy']):
            np.save(filename, data)
            print('Saved to {}'.format(filename))
        torch.save(inference_network_pq.state_dict(),
                   inference_network_pq_filename)
        print('Saved to {}'.format(inference_network_pq_filename))
    else:
        # Load inference artifact
        filename = 'loss_history_pq.npy'
        loss_history_pq = np.load(filename)
        print('Loaded from {}'.format(filename))
        inference_network_pq = InferenceNetworkPQ(
            len(num_clusters_probs), len(mixture_probs)
        )
        inference_network_pq.load_state_dict(
            torch.load(inference_network_pq_filename)
        )
        print('Loaded from {}'.format(inference_network_pq_filename))

    if not args.checkpoint_qp:
        # Amortize inference
        loss_type = 'qp'
        loss_history_qp, inference_network_qp = amortize_inference(
            num_clusters_probs, mean_1, std_1, mixture_probs, means_2, stds_2,
            obs_std, num_iterations, num_traces, learning_rate, loss_type
        )
        for [data, filename] in zip([loss_history_qp],
                                    ['loss_history_qp.npy']):
            np.save(filename, data)
            print('Saved to {}'.format(filename))
        torch.save(inference_network_qp.state_dict(),
                   inference_network_qp_filename)
        print('Saved to {}'.format(inference_network_qp_filename))
    else:
        # Load loss history
        filename = 'loss_history_qp.npy'
        loss_history_qp = np.load(filename)
        print('Loaded from {}'.format(filename))

        # Load inference artifact
        inference_network_qp = InferenceNetworkQP(
            num_clusters_probs, mean_1, std_1, mixture_probs, means_2, stds_2,
            obs_std
        )
        inference_network_qp.load_state_dict(
            torch.load(inference_network_qp_filename)
        )
        print('Loaded from {}'.format(inference_network_qp_filename))

    # Plot stats
    iterations = np.arange(num_iterations)
    fig, axs = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(3, 2.5)
    for ax in axs:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    axs[0].plot(iterations, loss_history_pq, label='pq loss', color='black')
    axs[0].set_ylabel('$pq$ loss')
    axs[1].plot(iterations, loss_history_qp, label='qp loss', color='black')
    axs[1].set_ylabel('$qp$ loss')

    filenames = ['gaussian_mixture_openuniverse_1.pdf',
                 'gaussian_mixture_openuniverse_1.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))
    # Plot test obs
    num_test_obs = 5
    test_obs_min = -7
    test_obs_max = 7
    test_obss = np.linspace(test_obs_min, test_obs_max, num=num_test_obs)

    num_prior_samples = 1000
    num_posterior_samples = 10000
    num_inference_network_samples = 10000
    num_points = 100
    min_point = min([-10, test_obs_min])
    max_point = max([10, test_obs_max])
    bar_width = 0.1
    num_barplots = 4

    x_points = np.linspace(min_point, max_point, num_points)
    z_points = np.arange(len(mixture_probs))
    k_points = np.arange(len(num_clusters_probs)) + 1

    fig, axs = plt.subplots(3, num_test_obs, sharey='row')
    fig.set_size_inches(8, 3.25)
    for axs_ in axs:
        for ax in axs_:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

    for test_obs_idx, test_obs in enumerate(test_obss):
        k_prior_pdf, z_prior_pdf, x_prior_pdf = get_prior_pdf(
            x_points, num_prior_samples, num_clusters_probs, mean_1, std_1,
            mixture_probs, means_2, stds_2, obs_std
        )
        k_posterior_pdf, z_posterior_pdf, x_posterior_pdf = get_posterior_pdf(
            x_points, num_posterior_samples, test_obs, num_clusters_probs,
            mean_1, std_1, mixture_probs, means_2, stds_2, obs_std
        )
        k_pq_pdf, z_pq_pdf, x_pq_pdf = inference_network_pq.get_pdf(
            x_points, test_obs, num_inference_network_samples
        )
        k_qp_pdf, z_qp_pdf, x_qp_pdf = inference_network_qp.get_pdf(
            x_points, test_obs, num_inference_network_samples
        )

        i = 0
        axs[0][test_obs_idx].bar(
            k_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            k_prior_pdf, width=bar_width, color='lightgray',
            edgecolor='lightgray', fill=True, label='prior'
        )
        i = 1
        axs[0][test_obs_idx].bar(
            k_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            k_posterior_pdf, width=bar_width, color='black',
            edgecolor='black', fill=True, label='posterior'
        )
        i = 2
        axs[0][test_obs_idx].bar(
            k_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            k_qp_pdf, width=bar_width, color='black', fill=False,
            linestyle='dashed', label='inference network qp'
        )
        i = 3
        axs[0][test_obs_idx].bar(
            k_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            k_pq_pdf, width=bar_width, color='black', fill=False,
            linestyle='dotted', label='inference network pq'
        )
        axs[0][test_obs_idx].set_xticks(k_points)
        axs[0][test_obs_idx].set_ylim([0, 1])
        axs[0][test_obs_idx].set_yticks([0, 1])
        axs[0][0].set_ylabel('k')

        i = 0
        axs[1][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_prior_pdf, width=bar_width, color='lightgray',
            edgecolor='lightgray', fill=True, label='prior'
        )
        i = 1
        axs[1][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_posterior_pdf, width=bar_width, color='black',
            edgecolor='black', fill=True, label='posterior'
        )
        i = 2
        axs[1][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_qp_pdf, width=bar_width, color='black', fill=False,
            linestyle='dashed', label='inference network qp'
        )
        i = 3
        axs[1][test_obs_idx].bar(
            z_points + 0.5 * bar_width * (2 * i + 1 - num_barplots),
            z_pq_pdf, width=bar_width, color='black', fill=False,
            linestyle='dotted', label='inference network pq'
        )
        axs[1][test_obs_idx].set_xticks(z_points)
        axs[1][test_obs_idx].set_ylim([0, 1])
        axs[1][test_obs_idx].set_yticks([0, 1])
        axs[1][0].set_ylabel('z')

        axs[2][test_obs_idx].plot(x_points, x_prior_pdf, color='lightgray',
                                  label='prior')
        axs[2][test_obs_idx].plot(x_points, x_posterior_pdf, color='black',
                                  label='posterior')
        axs[2][test_obs_idx].plot(x_points, x_qp_pdf, color='black',
                                  linestyle='dashed',
                                  label='inference network qp')
        axs[2][test_obs_idx].plot(x_points, x_pq_pdf, color='black',
                                  linestyle='dotted',
                                  label='inference network pq')
        axs[2][test_obs_idx].scatter(x=test_obs, y=0, color='black',
                                     label='test obs', marker='x')
        axs[2][test_obs_idx].set_yticks([])
        axs[2][0].set_ylabel('x')

        axs[-1][test_obs_idx // 2].legend(loc='upper center',
                                          bbox_to_anchor=(0.5, -0.35),
                                          ncol=5, fontsize='small')

    fig.tight_layout()
    filenames = ['gaussian_mixture_openuniverse_2.pdf',
                 'gaussian_mixture_openuniverse_2.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()