Modules_ori.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward layer implemented with two 1x1 convolutions."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        # Conv1d expects (N, C, T), so move the feature dim in front of the time dim.
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        # Residual connection followed by layer normalization.
        output = self.layer_norm(output + residual)
        return output
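

# A minimal usage sketch for PositionwiseFeedForward (illustrative only; the
# batch/sequence/feature sizes below are assumptions, not values taken from
# this repository).
def _demo_positionwise_feed_forward():
    batch, seq_len, d_in, d_hid = 2, 10, 64, 256  # hypothetical sizes
    ffn = PositionwiseFeedForward(d_in, d_hid, dropout=0.1)
    x = torch.randn(batch, seq_len, d_in)  # (N, T, C)
    out = ffn(x)
    assert out.shape == (batch, seq_len, d_in)  # output keeps the input shape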
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_units, num_heads, dropout_rate):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0

        self.linear_q = nn.Linear(hidden_size, num_units)
        self.linear_k = nn.Linear(hidden_size, num_units)
        self.linear_v = nn.Linear(hidden_size, num_units)
        self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys):
        """
        :param queries: A 3d tensor with shape of [N, T_q, C_q]
        :param keys: A 3d tensor with shape of [N, T_k, C_k]
        :return: A 3d tensor with shape of (N, T_q, C)
        """
        Q = self.linear_q(queries)  # (N, T_q, C)
        K = self.linear_k(keys)  # (N, T_k, C)
        V = self.linear_v(keys)  # (N, T_k, C)

        # Split and Concat
        split_size = self.hidden_size // self.num_heads
        Q_ = torch.cat(torch.split(Q, split_size, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Multiplication
        matmul_output = torch.bmm(Q_, K_.transpose(1, 2)) / self.hidden_size ** 0.5  # (h*N, T_q, T_k)

        # Key Masking
        key_mask = torch.sign(torch.abs(keys.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_k)
        key_mask_reshaped = key_mask.unsqueeze(1).repeat(1, queries.shape[1], 1)  # (h*N, T_q, T_k)
        key_paddings = torch.ones_like(matmul_output) * (-2 ** 32 + 1)
        matmul_output_m1 = torch.where(torch.eq(key_mask_reshaped, 0), key_paddings, matmul_output)  # (h*N, T_q, T_k)

        # Causality - Future Blinding
        diag_vals = torch.ones_like(matmul_output[0, :, :])  # (T_q, T_k)
        tril = torch.tril(diag_vals)  # (T_q, T_k)
        causality_mask = tril.unsqueeze(0).repeat(matmul_output.shape[0], 1, 1)  # (h*N, T_q, T_k)
        causality_paddings = torch.ones_like(causality_mask) * (-2 ** 32 + 1)
        matmul_output_m2 = torch.where(torch.eq(causality_mask, 0), causality_paddings, matmul_output_m1)  # (h*N, T_q, T_k)

        # Activation
        matmul_output_sm = self.softmax(matmul_output_m2)  # (h*N, T_q, T_k)

        # Query Masking
        query_mask = torch.sign(torch.abs(queries.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_q)
        query_mask = query_mask.unsqueeze(-1).repeat(1, 1, keys.shape[1])  # (h*N, T_q, T_k)
        matmul_output_qm = matmul_output_sm * query_mask

        # Dropout
        matmul_output_dropout = self.dropout(matmul_output_qm)

        # Weighted Sum
        output_ws = torch.bmm(matmul_output_dropout, V_)  # (h*N, T_q, C/h)

        # Restore Shape
        output = torch.cat(torch.split(output_ws, output_ws.shape[0] // self.num_heads, dim=0), dim=2)  # (N, T_q, C)

        # Residual Connection
        output_res = output + queries
        return output_res
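

# A minimal usage sketch for MultiHeadAttention (illustrative only; the sizes
# below are assumptions, not values taken from this repository). Note that the
# residual connection and the head split both assume num_units == hidden_size,
# and that padded positions are expected to be all-zero vectors so the key and
# query masks can detect them from the row sums.
def _demo_multi_head_attention():
    batch, seq_len, hidden_size, num_heads = 2, 10, 64, 4  # hypothetical sizes
    attn = MultiHeadAttention(hidden_size, num_units=hidden_size,
                              num_heads=num_heads, dropout_rate=0.1)
    seq = torch.randn(batch, seq_len, hidden_size)  # (N, T, C)
    seq[:, -2:, :] = 0.0  # zero out the last two steps to mimic padding
    out = attn(seq, seq)  # self-attention: queries and keys are the same sequence
    assert out.shape == (batch, seq_len, hidden_size)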