From 5f0e43fb961431437d33abe5d70251cf8067d14d Mon Sep 17 00:00:00 2001
From: Eduardo Carvalho
Date: Thu, 6 Feb 2020 16:39:38 +0100
Subject: [PATCH] feat: fix shared layers with independent batchnorm

---
 pytorch_tabnet/tab_network.py | 86 +++++++++++++++++++++++------------
 1 file changed, 56 insertions(+), 30 deletions(-)

diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 59bbb2db..9c11b85f 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -127,15 +127,16 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
         self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)
 
         if self.n_shared > 0:
-            shared_feat_transform = GLU_Block(self.post_embed_dim,
-                                              n_d+n_a,
-                                              n_glu=self.n_shared,
-                                              virtual_batch_size=self.virtual_batch_size,
-                                              first=True,
-                                              momentum=momentum,
-                                              device=self.device)
+            shared_feat_transform = torch.nn.ModuleList()
+            for i in range(self.n_shared):
+                if i == 0:
+                    shared_feat_transform.append(Linear(self.post_embed_dim, 2*(n_d + n_a), bias=False))
+                else:
+                    shared_feat_transform.append(Linear(n_d + n_a, 2*(n_d + n_a), bias=False))
+
         else:
             shared_feat_transform = None
+
         self.initial_splitter = FeatTransformer(self.post_embed_dim, n_d+n_a,
                                                 shared_feat_transform,
                                                 n_glu=self.n_independent,
                                                 virtual_batch_size=self.virtual_batch_size,
@@ -244,7 +245,7 @@ def forward(self, priors, processed_feat):
 
 
 class FeatTransformer(torch.nn.Module):
-    def __init__(self, input_dim, output_dim, shared_blocks, n_glu,
+    def __init__(self, input_dim, output_dim, shared_layers, n_glu,
                  virtual_batch_size=128, momentum=0.02, device='cpu'):
         super(FeatTransformer, self).__init__()
         """
@@ -256,19 +257,13 @@ def __init__(self, input_dim, output_dim, shared_blocks, n_glu,
             Input size
         - output_dim : int
             Outpu_size
-        - shared_blocks : torch.nn.Module
+        - shared_blocks : torch.nn.ModuleList
            The shared block that should be common to every step
         - momentum : float
            Float value between 0 and 1 which will be used for momentum in batch norm
         """
 
-        self.shared = shared_blocks
-        if self.shared is not None:
-            for l in self.shared.glu_layers:
-                l.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,
-                           momentum=momentum, device=device)
-
-        if self.shared is None:
+        if shared_layers is None:
             self.specifics = GLU_Block(input_dim, output_dim,
                                        n_glu=n_glu,
                                        first=True,
@@ -276,6 +271,13 @@
                                        virtual_batch_size=virtual_batch_size,
                                        momentum=momentum,
                                        device=device)
         else:
+            self.shared = GLU_Block(input_dim, output_dim,
+                                    n_glu=n_glu,
+                                    first=True,
+                                    shared_layers=shared_layers,
+                                    virtual_batch_size=virtual_batch_size,
+                                    momentum=momentum,
+                                    device=device)
             self.specifics = GLU_Block(output_dim, output_dim,
                                        n_glu=n_glu,
                                        virtual_batch_size=virtual_batch_size,
@@ -284,7 +286,11 @@ def forward(self, x):
     def forward(self, x):
         if self.shared is not None:
+            # print('-------before----------')
+            # print(self.shared.glu_layers[0].bn.bn.running_mean)
             x = self.shared(x)
+            # print('-------after-----------')
+            # print(self.shared.glu_layers[0].bn.bn.running_mean)
         x = self.specifics(x)
         return x
 
 
@@ -293,24 +299,41 @@ class GLU_Block(torch.nn.Module):
     """
         Independant GLU block, specific to each step
     """
-    def __init__(self, input_dim, output_dim, n_glu=2, first=False,
+    def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=None,
                  virtual_batch_size=128, momentum=0.02, device='cpu'):
         super(GLU_Block, self).__init__()
         self.first = first
+        self.shared_layers = shared_layers
         self.n_glu = n_glu
         self.glu_layers = torch.nn.ModuleList()
         self.scale = torch.sqrt(torch.FloatTensor([0.5]).to(device))
-        for glu_id in range(self.n_glu):
-            if glu_id == 0:
-                self.glu_layers.append(GLU_Layer(input_dim, output_dim,
-                                                 virtual_batch_size=virtual_batch_size,
-                                                 momentum=momentum,
-                                                 device=device))
-            else:
-                self.glu_layers.append(GLU_Layer(output_dim, output_dim,
-                                                 virtual_batch_size=virtual_batch_size,
-                                                 momentum=momentum,
-                                                 device=device))
+
+        if shared_layers:
+            for glu_id in range(self.n_glu):
+                if glu_id == 0:
+                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
+                                                     fc=shared_layers[glu_id],
+                                                     virtual_batch_size=virtual_batch_size,
+                                                     momentum=momentum,
+                                                     device=device))
+                else:
+                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
+                                                     fc=shared_layers[glu_id],
+                                                     virtual_batch_size=virtual_batch_size,
+                                                     momentum=momentum,
+                                                     device=device))
+        else:
+            for glu_id in range(self.n_glu):
+                if glu_id == 0:
+                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
+                                                     virtual_batch_size=virtual_batch_size,
+                                                     momentum=momentum,
+                                                     device=device))
+                else:
+                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
+                                                     virtual_batch_size=virtual_batch_size,
+                                                     momentum=momentum,
+                                                     device=device))
 
     def forward(self, x):
         if self.first:  # the first layer of the block has no scale multiplication
@@ -326,12 +349,15 @@ def forward(self, x):
 
 
 class GLU_Layer(torch.nn.Module):
-    def __init__(self, input_dim, output_dim,
+    def __init__(self, input_dim, output_dim, fc=None,
                  virtual_batch_size=128, momentum=0.02, device='cpu'):
         super(GLU_Layer, self).__init__()
 
         self.output_dim = output_dim
-        self.fc = Linear(input_dim, 2*output_dim, bias=False)
+        if fc:
+            self.fc = fc
+        else:
+            self.fc = Linear(input_dim, 2*output_dim, bias=False)
         initialize_glu(self.fc, input_dim, 2*output_dim)
 
         self.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,
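For reference, below is a minimal standalone sketch of the weight-sharing pattern this patch introduces: the fully-connected (fc) weights of the GLU layers are shared across steps through a ModuleList of Linear modules, while every GLU layer builds its own batch-norm module, so each step keeps independent normalization statistics. The names SketchGLULayer and build_shared_fcs are illustrative only (they are not part of pytorch_tabnet), and plain BatchNorm1d stands in here for the library's ghost batch norm (GBN).

import torch
from torch.nn import BatchNorm1d, Linear, Module, ModuleList


class SketchGLULayer(Module):
    """One GLU layer: a shared (or private) Linear followed by a private BatchNorm1d."""

    def __init__(self, input_dim, output_dim, fc=None, momentum=0.02):
        super().__init__()
        self.output_dim = output_dim
        # Reuse the shared Linear when one is given, otherwise create a private one,
        # mirroring the `if fc: ... else: ...` branch added to GLU_Layer above.
        self.fc = fc if fc is not None else Linear(input_dim, 2 * output_dim, bias=False)
        # The batch norm is always created here, so it is never shared between steps.
        self.bn = BatchNorm1d(2 * output_dim, momentum=momentum)

    def forward(self, x):
        x = self.bn(self.fc(x))
        # GLU: first half of the features gated by a sigmoid of the second half.
        return x[:, :self.output_dim] * torch.sigmoid(x[:, self.output_dim:])


def build_shared_fcs(input_dim, hidden_dim, n_shared):
    """Build the list of shared Linear modules, as in TabNetNoEmbeddings.__init__ above."""
    fcs = ModuleList()
    for i in range(n_shared):
        in_dim = input_dim if i == 0 else hidden_dim
        fcs.append(Linear(in_dim, 2 * hidden_dim, bias=False))
    return fcs


if __name__ == "__main__":
    shared_fcs = build_shared_fcs(input_dim=16, hidden_dim=8, n_shared=2)
    # Two decision steps built from the same shared Linear modules.
    step_a = ModuleList([SketchGLULayer(16, 8, fc=shared_fcs[0]),
                         SketchGLULayer(8, 8, fc=shared_fcs[1])])
    step_b = ModuleList([SketchGLULayer(16, 8, fc=shared_fcs[0]),
                         SketchGLULayer(8, 8, fc=shared_fcs[1])])
    assert step_a[0].fc is step_b[0].fc        # Linear weights are shared
    assert step_a[0].bn is not step_b[0].bn    # batch-norm statistics are independent
    x = torch.randn(4, 16)
    out = step_b[1](step_a[0](x))
    print(out.shape)  # torch.Size([4, 8])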