diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py
index cc3c26e..d1996f7 100644
--- a/bayesian_torch/layers/flipout_layers/conv_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -100,6 +100,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = nn.Parameter(
@@ -150,7 +152,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv1d(x,
@@ -191,8 +193,11 @@ def forward(self, x):
                                      dilation=self.dilation,
                                      groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class Conv2dFlipout(BaseVariationalLayer_):
@@ -244,6 +249,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -299,7 +306,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv2d(x,
@@ -340,8 +347,11 @@ def forward(self, x):
                                      dilation=self.dilation,
                                      groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class Conv3dFlipout(BaseVariationalLayer_):
@@ -388,6 +398,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -448,7 +460,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv3d(x,
@@ -489,8 +501,11 @@ def forward(self, x):
                                      dilation=self.dilation,
                                      groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose1dFlipout(BaseVariationalLayer_):
@@ -537,6 +552,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -593,7 +610,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose1d(x,
@@ -635,8 +652,11 @@ def forward(self, x):
                                                dilation=self.dilation,
                                                groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose2dFlipout(BaseVariationalLayer_):
@@ -683,6 +703,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -743,7 +765,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose2d(x,
@@ -785,8 +807,11 @@ def forward(self, x):
                                                dilation=self.dilation,
                                                groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose3dFlipout(BaseVariationalLayer_):
@@ -838,6 +863,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -893,7 +920,7 @@ def init_parameters(self):
             self.prior_bias_mu.data.fill_(self.prior_mean)
             self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose3d(x,
@@ -935,5 +962,8 @@ def forward(self, x):
                                                dilation=self.dilation,
                                                groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
diff --git a/bayesian_torch/layers/flipout_layers/linear_flipout.py b/bayesian_torch/layers/flipout_layers/linear_flipout.py
index d7d577f..2538f1d 100644
--- a/bayesian_torch/layers/flipout_layers/linear_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/linear_flipout.py
@@ -90,6 +90,8 @@ def __init__(self,
                              torch.Tensor(out_features, in_features),
                              persistent=False)
 
+        self.kl = 0
+
         if bias:
             self.mu_bias = nn.Parameter(torch.Tensor(out_features))
             self.rho_bias = nn.Parameter(torch.Tensor(out_features))
@@ -123,7 +125,7 @@ def init_parameters(self):
             self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
             self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
         # sampling delta_W
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         delta_weight = (sigma_weight * self.eps_weight.data.normal_())
@@ -148,5 +150,9 @@ def forward(self, x):
         perturbed_outputs = F.linear(x * sign_input, delta_weight,
                                      bias) * sign_output
 
+        self.kl = kl
+
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
diff --git a/bayesian_torch/layers/flipout_layers/rnn_flipout.py b/bayesian_torch/layers/flipout_layers/rnn_flipout.py
index 38c222a..317ebc4 100644
--- a/bayesian_torch/layers/flipout_layers/rnn_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/rnn_flipout.py
@@ -76,6 +76,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,  # variance of weight --> sigma = log (1 + exp(rho))
         self.bias = bias
 
+        self.kl = 0
+
         self.ih = LinearFlipout(prior_mean=prior_mean,
                                 prior_variance=prior_variance,
                                 posterior_mu_init=posterior_mu_init,
@@ -92,7 +94,7 @@ def __init__(self,
                                 out_features=out_features * 4,
                                 bias=bias)
 
-    def forward(self, X, hidden_states=None):
+    def forward(self, X, hidden_states=None, return_kl=True):
 
         batch_size, seq_size, _ = X.size()
 
@@ -137,4 +139,7 @@ def forward(self, X, hidden_states=None):
         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
         c_ts = c_ts.transpose(0, 1).contiguous()
 
-        return hidden_seq, (hidden_seq, c_ts), kl
+        self.kl = kl
+        if return_kl:
+            return hidden_seq, (hidden_seq, c_ts), kl
+        return hidden_seq, (hidden_seq, c_ts)
diff --git a/bayesian_torch/layers/variational_layers/conv_variational.py b/bayesian_torch/layers/variational_layers/conv_variational.py
index 4311400..96b1db5 100644
--- a/bayesian_torch/layers/variational_layers/conv_variational.py
+++ b/bayesian_torch/layers/variational_layers/conv_variational.py
@@ -112,6 +112,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -160,7 +162,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -182,7 +184,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class Conv2dReparameterization(BaseVariationalLayer_):
@@ -239,6 +245,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -292,7 +300,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -313,8 +321,12 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl
 
-        return out, kl
+        if return_kl:
+            return out, kl
+        return out
 
 
 class Conv3dReparameterization(BaseVariationalLayer_):
@@ -371,6 +383,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -424,7 +438,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -446,7 +460,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose1dReparameterization(BaseVariationalLayer_):
@@ -504,6 +522,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -552,7 +572,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -575,7 +595,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose2dReparameterization(BaseVariationalLayer_):
@@ -633,6 +657,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size))
@@ -686,7 +712,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -709,7 +735,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose3dReparameterization(BaseVariationalLayer_):
@@ -768,6 +798,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -821,7 +853,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -844,4 +876,8 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
diff --git a/bayesian_torch/layers/variational_layers/linear_variational.py b/bayesian_torch/layers/variational_layers/linear_variational.py
index af113f5..bb3a296 100644
--- a/bayesian_torch/layers/variational_layers/linear_variational.py
+++ b/bayesian_torch/layers/variational_layers/linear_variational.py
@@ -83,6 +83,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
         self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
         self.register_buffer('eps_weight',
@@ -124,7 +126,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         weight = self.mu_weight + \
             (sigma_weight * self.eps_weight.data.normal_())
@@ -143,5 +145,9 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl
 
-        return out, kl
+        if return_kl:
+            return out, kl
+        return out
\ No newline at end of file
diff --git a/bayesian_torch/layers/variational_layers/rnn_variational.py b/bayesian_torch/layers/variational_layers/rnn_variational.py
index ab126ad..c36378c 100644
--- a/bayesian_torch/layers/variational_layers/rnn_variational.py
+++ b/bayesian_torch/layers/variational_layers/rnn_variational.py
@@ -77,6 +77,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.ih = LinearReparameterization(
             prior_mean=prior_mean,
             prior_variance=prior_variance,
@@ -95,7 +97,7 @@ def __init__(self,
             out_features=out_features * 4,
             bias=bias)
 
-    def forward(self, X, hidden_states=None):
+    def forward(self, X, hidden_states=None, return_kl=True):
 
         batch_size, seq_size, _ = X.size()
 
@@ -140,4 +142,8 @@ def forward(self, X, hidden_states=None):
         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
         c_ts = c_ts.transpose(0, 1).contiguous()
 
-        return hidden_seq, (hidden_seq, c_ts), kl
+        self.kl = kl
+
+        if return_kl:
+            return hidden_seq, (hidden_seq, c_ts), kl
+        return hidden_seq, (hidden_seq, c_ts)
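--
Usage sketch (not part of the patch): with these changes every variational/flipout layer caches its KL term on self.kl, and return_kl=False skips returning it. A minimal example for one layer; it assumes LinearFlipout is re-exported from bayesian_torch.layers, otherwise import it from bayesian_torch.layers.flipout_layers.linear_flipout.

    import torch
    from bayesian_torch.layers import LinearFlipout  # assumed re-export

    layer = LinearFlipout(in_features=16, out_features=8)
    x = torch.randn(4, 16)

    out, kl = layer(x)               # default behaviour, unchanged
    out = layer(x, return_kl=False)  # new: output only
    kl = layer.kl                    # KL term still available on the layer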