# Toy comparison of the REINFORCE, REBAR and RELAX gradient estimators of
# d/d(theta) E_{p(b | theta)}[f(b)] for a single Bernoulli variable b.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


def isnan(x):
    if isinstance(x, Variable):
        return np.isnan(torch.sum(x).data[0])
    elif torch.is_tensor(x):
        return np.isnan(torch.sum(x))
    elif isinstance(x, nn.Module):
        for param in x.parameters():
            if isnan(param):
                return True
        return False


def logit(x):
    return torch.log(x) - torch.log(1 - x)


def safe_sigmoid(x, epsilon=1e-6):
    # Squash into (epsilon, 1 - epsilon) so that log(theta) and log(1 - theta)
    # stay finite during training.
    return F.sigmoid(x) * (1 - 2 * epsilon) + epsilon


def heaviside(x):
    return x >= 0


def reparam(u, theta, epsilon=1e-6):
    # z = g(u, theta): logistic reparameterization such that
    # heaviside(z) ~ Bernoulli(theta) for u ~ Uniform(0, 1).
    return torch.log(theta + epsilon) - torch.log(1 - theta + epsilon) + \
        torch.log(u + epsilon) - torch.log(1 - u + epsilon)


def conditional_reparam(v, theta, b, epsilon=1e-6):
    # z_tilde = g_tilde(v, b, theta): a sample of z given heaviside(z) == b.
    # Branches on b.data[0], so it only supports a single sample.
    if b.data[0] == 1:
        return torch.log(v / ((1 - v) * (1 - theta)) + 1 + epsilon)
    else:
        return -torch.log(v / ((1 - v) * theta) + 1 + epsilon)


def bernoulli_logpdf(b, theta, epsilon=1e-6):
    # Branches on b.data[0], so it only supports a single sample.
    if b.data[0] == 1:
        return torch.log(theta + epsilon)
    else:
        return torch.log(1 - theta + epsilon)


def continuous_relaxation(z, temperature, epsilon=1e-6):
    # sigma_lambda(z): a tempered sigmoid that relaxes heaviside(z).
    return 1 / (1 + torch.exp(-z / (temperature + epsilon)))


class RebarControlVariate(nn.Module):
    """REBAR control variate eta * f(sigma_lambda(z)) with learnable eta and lambda."""

    def __init__(self, f, temperature_init=1):
        super(RebarControlVariate, self).__init__()
        self.f = f
        self.log_temperature = nn.Parameter(
            torch.log(torch.Tensor([temperature_init])))
        self.log_multiplier = nn.Parameter(torch.Tensor([0]))

    def forward(self, z):
        return torch.exp(self.log_multiplier) * \
            self.f(continuous_relaxation(z, torch.exp(self.log_temperature)))


class RelaxControlVariate(nn.Module):
    """RELAX control variate f(sigma_lambda(z)) plus a free-form neural surrogate."""

    def __init__(self, f, temperature_init=1):
        super(RelaxControlVariate, self).__init__()
        self.f = f
        self.log_temperature = nn.Parameter(
            torch.log(torch.Tensor([temperature_init])))
        self.surrogate = nn.Sequential(
            nn.Linear(1, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
            nn.ReLU()
        )

    def forward(self, z):
        return self.f(continuous_relaxation(z, torch.exp(self.log_temperature))) + \
            self.surrogate(z.unsqueeze(-1)).squeeze(-1)


def reinforce_estimator(f, theta, num_samples):
    # Surrogate whose gradient wrt theta is the REINFORCE estimator
    # f(b) * d/d(theta) log p(b | theta).
    theta_expanded = theta.expand(num_samples)
    b = torch.bernoulli(theta_expanded).detach()
    return f(b) * bernoulli_logpdf(b, theta_expanded)


def relax_estimator(f, c, theta, num_samples):
    # Surrogate whose gradient wrt theta is the REBAR/RELAX estimator
    # [f(b) - c(z_tilde)] * d/d(theta) log p(b | theta) + d/d(theta) [c(z) - c(z_tilde)].
    theta_expanded = theta.expand(num_samples)
    u = Variable(torch.rand(num_samples))
    v = Variable(torch.rand(num_samples))
    z = reparam(u, theta_expanded)
    b = heaviside(z)
    z_tilde = conditional_reparam(v, theta_expanded, b)
    return (f(b) - c(z_tilde)).detach() * bernoulli_logpdf(b, theta_expanded) + \
        c(z) - c(z_tilde)
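

# A hedged sanity check (illustrative only, not called from main(); the function
# name check_reparam, the tolerance and the sample budget are arbitrary choices):
# the estimators above rely on heaviside(reparam(u, theta)) being
# Bernoulli(theta)-distributed for u ~ Uniform(0, 1), and on
# conditional_reparam(v, theta, b) sampling z conditioned on heaviside(z) == b,
# i.e. agreeing with b in sign.
def check_reparam(theta=0.3, num_samples=100000):
    theta_expanded = Variable(torch.Tensor([theta])).expand(num_samples)
    u = Variable(torch.rand(num_samples))
    b = heaviside(reparam(u, theta_expanded)).float()
    # The empirical frequency of b == 1 should be close to theta.
    assert abs(torch.mean(b).data[0] - theta) < 0.01

    v = Variable(torch.rand(num_samples))
    b_one = Variable(torch.ones(num_samples))
    z_tilde = conditional_reparam(v, theta_expanded, b_one)
    # Conditioned on b == 1, every sample of z_tilde should be non-negative.
    assert torch.min(z_tilde).data[0] >= 0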


def train_reinforce(init_theta, f, num_iterations, learning_rate=0.01):
    pre_sigmoid_theta = Variable(
        logit(torch.Tensor([init_theta])), requires_grad=True)
    theta_history = np.zeros([num_iterations])
    pre_sigmoid_theta_optimizer = optim.Adam([pre_sigmoid_theta], lr=learning_rate)
    for i in range(num_iterations):
        pre_sigmoid_theta_optimizer.zero_grad()
        re = reinforce_estimator(f, safe_sigmoid(pre_sigmoid_theta), 1)
        re.backward()
        pre_sigmoid_theta_optimizer.step()
        if isnan(pre_sigmoid_theta):
            print('Iteration {}: theta is nan'.format(i))
            break
        theta_history[i] = safe_sigmoid(pre_sigmoid_theta).data[0]
        if i % 1000 == 0:
            print('Reinforce iteration {}'.format(i))
    return theta_history


def train_with_control_variate(init_theta, f, c, num_iterations, name,
                               learning_rate=0.01):
    # Shared training loop for REBAR and RELAX (the original bodies were
    # identical): theta follows the relax_estimator gradient, while the control
    # variate parameters follow the gradient of the squared single-sample
    # estimate, i.e. the estimator's variance.
    pre_sigmoid_theta = Variable(
        logit(torch.Tensor([init_theta])), requires_grad=True)
    theta_history = np.zeros([num_iterations])
    temperature_history = np.zeros([num_iterations])
    pre_sigmoid_theta_optimizer = optim.Adam([pre_sigmoid_theta], lr=learning_rate)
    c_optimizer = optim.Adam(c.parameters(), lr=learning_rate)
    for i in range(num_iterations):
        # Step on theta, keeping the graph so that the gradient itself can be
        # differentiated wrt the control variate parameters below.
        pre_sigmoid_theta_optimizer.zero_grad()
        re = relax_estimator(f, c, safe_sigmoid(pre_sigmoid_theta), 1)
        re.backward(create_graph=True)
        pre_sigmoid_theta_grad_detached = pre_sigmoid_theta.grad.detach()
        pre_sigmoid_theta_optimizer.step()
        if isnan(pre_sigmoid_theta):
            print('Iteration {}: pre_sigmoid_theta is nan'.format(i))
            break

        # Step on the control variate parameters to reduce the estimator
        # variance, using d(g^2)/d(phi) = 2 * g * dg/d(phi).
        c_optimizer.zero_grad()
        pre_sigmoid_theta.grad.backward(
            2 * pre_sigmoid_theta_grad_detached / len(pre_sigmoid_theta_grad_detached))
        c_optimizer.step()
        if isnan(c):
            print('Iteration {}: c is nan'.format(i))
            break

        theta_history[i] = safe_sigmoid(pre_sigmoid_theta).data[0]
        temperature_history[i] = torch.exp(c.log_temperature).data[0]
        if i % 1000 == 0:
            print('{} iteration {}'.format(name, i))
    return pre_sigmoid_theta, c, theta_history, temperature_history


def train_rebar(init_theta, f, num_iterations, learning_rate=0.01):
    return train_with_control_variate(
        init_theta, f, RebarControlVariate(f), num_iterations, 'Rebar',
        learning_rate)


def train_relax(init_theta, f, num_iterations, learning_rate=0.01):
    return train_with_control_variate(
        init_theta, f, RelaxControlVariate(f), num_iterations, 'Relax',
        learning_rate)
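

# Another hedged sanity check (illustrative only, not called from main(); the
# name check_unbiasedness and the sample budget are arbitrary): all of the
# estimators above are unbiased for d/d(theta) E_{p(b | theta)}[f(b)], which for
# f(b) = (b - t)^2 equals f(1) - f(0) = 1 - 2 * t exactly, so the empirical mean
# of many single-sample gradients should land near that value for each estimator.
def check_unbiasedness(theta=0.3, t=0.45, num_mc_samples=10000):
    def f(b):
        return (b.float() - t)**2

    c = RebarControlVariate(f)
    reinforce_grads = np.zeros([num_mc_samples])
    relax_grads = np.zeros([num_mc_samples])
    for i in range(num_mc_samples):
        theta_var = Variable(torch.Tensor([theta]), requires_grad=True)
        reinforce_estimator(f, theta_var, 1).backward()
        reinforce_grads[i] = theta_var.grad.data[0]

        theta_var = Variable(torch.Tensor([theta]), requires_grad=True)
        relax_estimator(f, c, theta_var, 1).backward()
        relax_grads[i] = theta_var.grad.data[0]
    print('exact {:.3f}, reinforce {:.3f}, relax {:.3f}'.format(
        1 - 2 * t, np.mean(reinforce_grads), np.mean(relax_grads)))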


def main():
    torch.manual_seed(2)
    np.random.seed(1)
    init_theta = 0.3
    t = 0.499
    # The exact gradient of E_{p(b | theta)}[f(b)] wrt theta is
    # f(1) - f(0) = 1 - 2 * t.

    def f(b):
        return (b.float() - t)**2

    num_iterations = 100000

    reinforce_theta_history = train_reinforce(init_theta, f, num_iterations)
    rebar_pre_sigmoid_theta, rebar_c, rebar_theta_history, \
        rebar_temperature_history = train_rebar(init_theta, f, num_iterations)
    relax_pre_sigmoid_theta, relax_c, relax_theta_history, \
        relax_temperature_history = train_relax(init_theta, f, num_iterations)

    # Plot 1: theta and temperature trajectories, and the learned control
    # variates as a function of the uniform noise u.
    fig, axs = plt.subplots(1, 3)
    fig.set_size_inches(8, 3)
    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    axs[0].plot(reinforce_theta_history, color='black', label='reinforce')
    axs[0].plot(rebar_theta_history, color='gray', label='rebar')
    axs[0].plot(relax_theta_history, color='lightgray', label='relax')
    axs[0].set_xlabel('Iteration')
    axs[0].set_ylabel(r'$\theta$')
    axs[0].set_ylim(0, 1)
    axs[0].legend()

    axs[1].plot(rebar_temperature_history, color='gray', label='rebar')
    axs[1].plot(relax_temperature_history, color='lightgray', label='relax')
    axs[1].set_xlabel('Iteration')
    axs[1].set_ylabel(r'$\lambda$')
    axs[1].legend()

    num_u_points = 100
    min_u_point = 0.01
    max_u_point = 0.99
    u_points = Variable(torch.linspace(min_u_point, max_u_point, num_u_points))
    rebar_c_points = rebar_c(reparam(
        u_points, F.sigmoid(rebar_pre_sigmoid_theta).expand(num_u_points)))
    relax_c_points = relax_c(reparam(
        u_points, F.sigmoid(relax_pre_sigmoid_theta).expand(num_u_points)))
    axs[2].plot(u_points.data.numpy(), rebar_c_points.data.numpy(), color='gray',
                label=r'$\eta f(\sigma_\lambda(g(u, \theta)))$ (rebar)')
    axs[2].plot(u_points.data.numpy(), relax_c_points.data.numpy(),
                color='lightgray', label=r'$c_\phi(g(u, \theta))$ (relax)')
    axs[2].set_xlabel('$u$')
    axs[2].set_xlim(0, 1)
    axs[2].legend()

    fig.tight_layout()
    filenames = ['rebar_relax_1.png', 'rebar_relax_1.pdf']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))

    # Plot 2: variance of the single-sample gradient estimators across theta.
    num_theta_points = 100
    min_theta_point = 0.01
    max_theta_point = 0.99
    num_mc_samples = 1000
    theta_points = np.linspace(min_theta_point, max_theta_point, num_theta_points)
    reinforce_estimator_samples = np.zeros([num_theta_points, num_mc_samples])
    rebar_estimator_samples = np.zeros([num_theta_points, num_mc_samples])
    relax_estimator_samples = np.zeros([num_theta_points, num_mc_samples])
    for theta_idx, theta in enumerate(theta_points):
        for mc_sample_idx in range(num_mc_samples):
            theta_var = Variable(torch.Tensor([theta]), requires_grad=True)
            reinforce_estimator(f, theta_var, 1).backward()
            reinforce_estimator_samples[theta_idx, mc_sample_idx] = \
                theta_var.grad.data[0]

            theta_var = Variable(torch.Tensor([theta]), requires_grad=True)
            relax_estimator(f, rebar_c, theta_var, 1).backward()
            rebar_estimator_samples[theta_idx, mc_sample_idx] = \
                theta_var.grad.data[0]

            theta_var = Variable(torch.Tensor([theta]), requires_grad=True)
            relax_estimator(f, relax_c, theta_var, 1).backward()
            relax_estimator_samples[theta_idx, mc_sample_idx] = \
                theta_var.grad.data[0]

    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(5, 3)
    ax.plot(theta_points, np.var(reinforce_estimator_samples, axis=1),
            color='black', label='reinforce')
    ax.plot(theta_points, np.var(rebar_estimator_samples, axis=1),
            color='gray', label='rebar')
    ax.plot(theta_points, np.var(relax_estimator_samples, axis=1),
            color='lightgray', label='relax')
    ax.set_ylabel('Variance of the estimator of\n' +
                  r'$\nabla_{\theta} E_{p(b | \theta)}[f(b)]$')
    ax.set_xlabel(r'$\theta$')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, 1)
    ax.set_yscale('log')
    ax.legend()
    fig.tight_layout()
    filenames = ['rebar_relax_2.png', 'rebar_relax_2.pdf']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()