context_gating.py
import torch
from torch import nn
import torch.nn.functional as F


class Gated_Embedding_Unit(nn.Module):
    """Projects an input feature, applies Context Gating, and L2-normalises the result."""

    def __init__(self, input_dimension, output_dimension):
        super().__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = Context_Gating(output_dimension)

    def forward(self, x):
        x = self.fc(x)             # linear projection to the output dimension
        x = self.cg(x)             # context gating unit
        x = F.normalize(x, dim=1)  # L2-normalise each embedding
        return x


class Context_Gating(nn.Module):
    """Context Gating: y = x * sigmoid(W x + b), optionally with batch norm on the gates."""

    def __init__(self, dimension, add_batch_norm=False):
        super().__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        if add_batch_norm:
            # Only create the batch-norm layer when it is actually used.
            self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        x1 = self.fc(x)
        if self.add_batch_norm:
            x1 = self.batch_norm(x1)
        # GLU splits the concatenation in half along dim 1 and returns
        # x * sigmoid(x1), i.e. the context-gated output.
        x = torch.cat((x, x1), 1)
        return F.glu(x, 1)
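
# A minimal usage sketch, assuming batched 2D inputs; the batch size and
# feature dimensions below are arbitrary values chosen for illustration.
if __name__ == "__main__":
    features = torch.randn(8, 2048)        # batch of 8 input feature vectors
    geu = Gated_Embedding_Unit(2048, 512)
    embeddings = geu(features)             # shape: (8, 512)
    print(embeddings.shape)                # torch.Size([8, 512])
    print(embeddings.norm(dim=1))          # ~1.0 per row (L2-normalised)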