-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathbabi_main.py
353 lines (309 loc) · 14.2 KB
/
babi_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
from babi_loader import BabiDataset, pad_collate
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.utils.data import DataLoader
def position_encoding(embedded_sentence):
    """Collapse each sentence's token embeddings into one vector with DMN+ position weights.

    For token ``s`` (0-based, out of ``slen`` tokens) and embedding component
    ``e`` (0-based, out of ``elen`` components) the weight is::

        (1 - s / (slen - 1)) - (e / (elen - 1)) * (1 - 2 * s / (slen - 1))

    The weighted embeddings are then summed over the token axis.

    Args:
        embedded_sentence: tensor of shape (#batch, #sentence, #token, #embedding).

    Returns:
        Tensor of shape (#batch, #sentence, #embedding), on the same device and
        with the same dtype as ``embedded_sentence``.
    """
    _, _, slen, elen = embedded_sentence.size()
    # With a single token (or a single embedding component) the denominators
    # would be zero; clamping them to 1 keeps the formula defined and leaves
    # every larger size unchanged.
    s_den = max(slen - 1, 1)
    e_den = max(elen - 1, 1)
    s = torch.arange(slen, dtype=torch.float64).unsqueeze(1)  # (#token, 1)
    e = torch.arange(elen, dtype=torch.float64).unsqueeze(0)  # (1, #embedding)
    weights = (1 - s / s_den) - (e / e_den) * (1 - 2 * s / s_den)
    # Put the weights wherever the input lives instead of assuming CUDA.
    weights = weights.to(device=embedded_sentence.device, dtype=embedded_sentence.dtype)
    # (#token, #embedding) broadcasts over (#batch, #sentence, #token, #embedding);
    # summing dim 2 removes the token axis.
    return torch.sum(embedded_sentence * weights, dim=2)
class AttentionGRUCell(nn.Module):
    """Attention-gated GRU cell used by the DMN+ episodic memory.

    A per-example attention gate ``g`` takes the place of the usual GRU update
    gate, so the new state is ``g * h_tilda + (1 - g) * C``.
    """

    def __init__(self, input_size, hidden_size):
        super(AttentionGRUCell, self).__init__()
        self.hidden_size = hidden_size
        # Layers are created and initialised in the same order as before so a
        # seeded run draws identical initial weights.
        self.Wr = nn.Linear(input_size, hidden_size)
        init.xavier_normal_(self.Wr.weight)
        self.Ur = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal_(self.Ur.weight)
        self.W = nn.Linear(input_size, hidden_size)
        init.xavier_normal_(self.W.weight)
        self.U = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal_(self.U.weight)

    def forward(self, fact, C, g):
        """Advance the cell by one sentence.

        Args:
            fact: (#batch, input_size) representation of the current sentence.
            C: (#batch, hidden_size) previous hidden state.
            g: (#batch,) attention gate for this sentence.

        Returns:
            (#batch, hidden_size) updated hidden state.
        """
        r = torch.sigmoid(self.Wr(fact) + self.Ur(C))
        h_tilda = torch.tanh(self.W(fact) + r * self.U(C))
        # Broadcast the per-example gate across the hidden dimension.
        g = g.unsqueeze(1).expand_as(h_tilda)
        return g * h_tilda + (1 - g) * C
class AttentionGRU(nn.Module):
    """Run an AttentionGRUCell over the sentences of a story and return the final state."""

    def __init__(self, input_size, hidden_size):
        super(AttentionGRU, self).__init__()
        self.hidden_size = hidden_size
        self.AGRUCell = AttentionGRUCell(input_size, hidden_size)

    def forward(self, facts, G):
        """
        facts.size() -> (#batch, #sentence, input_size)
        G.size() -> (#batch, #sentence), one attention gate per sentence
        returns C with C.size() -> (#batch, #hidden)
        """
        batch_num, sen_num, _ = facts.size()
        # Zero initial state allocated with the facts' device/dtype instead of
        # hard-coded CUDA. It is already (#batch, #hidden), so this also works
        # when input_size != hidden_size and when #sentence == 0.
        C = facts.new_zeros(batch_num, self.hidden_size)
        for sid in range(sen_num):
            C = self.AGRUCell(facts[:, sid, :], C, G[:, sid])
        return C
class EpisodicMemory(nn.Module):
    """One DMN+ memory hop: attend over facts, summarise them, and produce the next memory."""

    def __init__(self, hidden_size):
        super(EpisodicMemory, self).__init__()
        self.AGRU = AttentionGRU(hidden_size, hidden_size)
        self.z1 = nn.Linear(4 * hidden_size, hidden_size)
        self.z2 = nn.Linear(hidden_size, 1)
        self.next_mem = nn.Linear(3 * hidden_size, hidden_size)
        init.xavier_normal_(self.z1.weight)
        init.xavier_normal_(self.z2.weight)
        init.xavier_normal_(self.next_mem.weight)

    def make_interaction(self, facts, questions, prevM):
        """Compute attention weights over sentences.

        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, 1, #hidden)
        prevM.size() -> (#batch, 1, #hidden = #embedding)
        z.size() -> (#batch, #sentence, 4 x #embedding)
        returns G with G.size() -> (#batch, #sentence); each row sums to 1
        """
        batch_num, sen_num, embedding_size = facts.size()
        questions = questions.expand_as(facts)
        prevM = prevM.expand_as(facts)
        z = torch.cat([
            facts * questions,
            facts * prevM,
            torch.abs(facts - questions),
            torch.abs(facts - prevM),
        ], dim=2)
        z = z.view(-1, 4 * embedding_size)
        G = torch.tanh(self.z1(z))
        G = self.z2(G)
        G = G.view(batch_num, -1)
        # Normalise explicitly over the sentence axis; the implicit-dim form
        # of softmax is deprecated.
        G = F.softmax(G, dim=1)
        return G

    def forward(self, facts, questions, prevM):
        """Run one hop and return the next memory.

        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, 1, #hidden)
        prevM.size() -> (#batch, 1, #hidden = #embedding)
        concat.size() -> (#batch, 3 x #embedding)
        returns next_mem with next_mem.size() -> (#batch, 1, #hidden)
        """
        G = self.make_interaction(facts, questions, prevM)
        C = self.AGRU(facts, G)
        concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
        next_mem = F.relu(self.next_mem(concat))
        return next_mem.unsqueeze(1)
class QuestionModule(nn.Module):
    """Encode a tokenised question into a single GRU summary vector.

    ``vocab_size`` is accepted for symmetry with the other modules but is not
    used; the embedding layer is passed in at call time.
    """

    def __init__(self, vocab_size, hidden_size):
        super(QuestionModule, self).__init__()
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, questions, word_embedding):
        """
        questions: (#batch, #token) token ids
        word_embedding: callable mapping ids to (#batch, #token, #embedding)
        returns: (#batch, 1, #hidden), the GRU's final hidden state
        """
        token_vectors = word_embedding(questions)
        gru_result = self.gru(token_vectors)
        # Second element is the final hidden state, (1, #batch, #hidden) for
        # this single-layer, unidirectional GRU.
        final_state = gru_result[1]
        return final_state.permute(1, 0, 2)
class InputModule(nn.Module):
    """Encode story sentences with position encoding, then a bidirectional GRU across sentences."""

    def __init__(self, vocab_size, hidden_size):
        super(InputModule, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, bidirectional=True, batch_first=True)
        # Same parameters, in the same order, as iterating the GRU's state_dict().
        for name, param in self.gru.named_parameters():
            if 'weight' in name:
                init.xavier_normal_(param)
        self.dropout = nn.Dropout(0.1)

    def forward(self, contexts, word_embedding):
        """
        contexts.size() -> (#batch, #sentence, #token) token ids
        word_embedding() -> (#batch, #sentence x #token, #embedding)
        position_encoding() -> (#batch, #sentence, #embedding)
        returns facts with facts.size() -> (#batch, #sentence, #hidden = #embedding)
        """
        batch_num, sen_num, token_num = contexts.size()
        contexts = contexts.view(batch_num, -1)
        contexts = word_embedding(contexts)
        contexts = contexts.view(batch_num, sen_num, token_num, -1)
        contexts = position_encoding(contexts)
        contexts = self.dropout(contexts)
        # Zero initial state for both directions, on the input's device/dtype.
        h0 = contexts.new_zeros(2, batch_num, self.hidden_size)
        facts, _ = self.gru(contexts, h0)
        # Sum the forward and backward halves. Use the instance's hidden size:
        # the previous code read a module-level global that only exists when
        # this file is run as a script.
        facts = facts[:, :, :self.hidden_size] + facts[:, :, self.hidden_size:]
        return facts
class AnswerModule(nn.Module):
    """Map the final memory plus the question encoding to vocabulary logits."""

    def __init__(self, vocab_size, hidden_size):
        super(AnswerModule, self).__init__()
        projection = nn.Linear(2 * hidden_size, vocab_size)
        init.xavier_normal(projection.state_dict()['weight'])
        self.z = projection
        self.dropout = nn.Dropout(0.1)

    def forward(self, M, questions):
        """
        M: (#batch, 1, #hidden) final memory
        questions: (#batch, 1, #hidden) question encoding
        returns: (#batch, vocab_size) unnormalised scores
        """
        joined = torch.cat((self.dropout(M), questions), 2)
        return self.z(joined.squeeze(1))
class DMNPlus(nn.Module):
    """Dynamic Memory Network+ : input, question, episodic-memory and answer modules."""

    def __init__(self, hidden_size, vocab_size, num_hop=3, qa=None):
        super(DMNPlus, self).__init__()
        self.num_hop = num_hop
        self.qa = qa  # only read by interpret_indexed_tensor (needs .IVOCAB)
        # Device placement is left to the caller's model.cuda(), which moves
        # this embedding together with every other submodule.
        self.word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True)
        # NOTE(review): this re-initialises the whole table, including the
        # zero row nn.Embedding creates for padding_idx=0, so pad tokens get a
        # random non-zero vector. Kept as before; confirm it is intended.
        init.uniform_(self.word_embedding.weight, a=-(3**0.5), b=3**0.5)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.criterion = nn.CrossEntropyLoss(reduction='sum')
        self.input_module = InputModule(vocab_size, hidden_size)
        self.question_module = QuestionModule(vocab_size, hidden_size)
        self.memory = EpisodicMemory(hidden_size)
        self.answer_module = AnswerModule(vocab_size, hidden_size)

    def forward(self, contexts, questions):
        """
        contexts.size() -> (#batch, #sentence, #token) -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #token) -> (#batch, 1, #hidden)
        returns preds with preds.size() -> (#batch, vocab_size) logits
        """
        facts = self.input_module(contexts, self.word_embedding)
        questions = self.question_module(questions, self.word_embedding)
        M = questions
        for hop in range(self.num_hop):
            M = self.memory(facts, questions, M)
        preds = self.answer_module(M, questions)
        return preds

    def interpret_indexed_tensor(self, var):
        """Print token-id tensors as words using ``self.qa.IVOCAB`` (debugging aid)."""
        # .item() reads a Python number from a 0-dim tensor; the old
        # `.data[0]` raises IndexError on 0-dim tensors in current PyTorch.
        if len(var.size()) == 3:
            # var -> n x #sen x #token
            for n, sentences in enumerate(var):
                for i, sentence in enumerate(sentences):
                    s = ' '.join([self.qa.IVOCAB[elem.item()] for elem in sentence])
                    print(f'{n}th of batch, {i}th sentence, {s}')
        elif len(var.size()) == 2:
            # var -> n x #token
            for n, sentence in enumerate(var):
                s = ' '.join([self.qa.IVOCAB[elem.item()] for elem in sentence])
                print(f'{n}th of batch, {s}')
        elif len(var.size()) == 1:
            # var -> n (one token per batch)
            for n, token in enumerate(var):
                s = self.qa.IVOCAB[token.item()]
                print(f'{n}th of batch, {s}')

    def get_loss(self, contexts, questions, targets):
        """Return (summed cross-entropy + L2 penalty, accuracy) for one batch.

        The penalty is 0.001 * sum of squares over every parameter of the
        model. Accuracy is a 0-dim tensor holding the fraction of examples
        whose arg-max prediction equals ``targets``.
        """
        output = self(contexts, questions)
        loss = self.criterion(output, targets)
        reg_loss = 0
        for param in self.parameters():
            reg_loss += 0.001 * torch.sum(param * param)
        # Compare against the `targets` argument; the previous code read a
        # global `answers` that only exists inside this file's __main__ loop.
        acc = self._accuracy(output, targets)
        return loss + reg_loss, acc

    @staticmethod
    def _accuracy(logits, targets):
        """Fraction of rows whose arg-max logit equals the target id, as a 0-dim tensor."""
        # softmax is monotonic per row, so taking arg-max of the raw logits
        # gives the same predictions without an extra softmax.
        _, pred_ids = torch.max(logits, dim=1)
        corrects = (pred_ids == targets)
        return torch.mean(corrects.float())
def _to_device(batch):
    """Move one (contexts, questions, answers) batch from pad_collate onto the GPU.

    Contexts and questions are cast to int64 token ids; answers keep their dtype.
    """
    contexts, questions, answers = batch
    return contexts.long().cuda(), questions.long().cuda(), answers.cuda()


def _snapshot(model):
    """Return an independent copy of ``model.state_dict()``.

    state_dict() returns tensors that share storage with the live parameters,
    so storing it uncloned would keep tracking later training steps instead
    of freezing the best weights.
    """
    return {name: tensor.clone() for name, tensor in model.state_dict().items()}


def _train_epoch(model, optim, dset, task_id, epoch):
    """Run one shuffled pass over the 'train' split, printing progress every 20 batches."""
    dset.set_mode('train')
    train_loader = DataLoader(
        dset, batch_size=100, shuffle=True, collate_fn=pad_collate
    )
    model.train()
    total_acc = 0.0
    cnt = 0
    for batch_idx, data in enumerate(train_loader):
        optim.zero_grad()
        contexts, questions, answers = _to_device(data)
        batch_size = contexts.size(0)
        loss, acc = model.get_loss(contexts, questions, answers)
        loss.backward()
        total_acc += float(acc) * batch_size
        cnt += batch_size
        if batch_idx % 20 == 0:
            # .item(): `loss.data[0]` raises on 0-dim tensors in current PyTorch.
            print(f'[Task {task_id}, Epoch {epoch}] [Training] loss : {loss.item(): {10}.{8}}, acc : {total_acc / cnt: {5}.{4}}, batch_idx : {batch_idx}')
        optim.step()


def _evaluate(model, dset, mode):
    """Return the batch-size-weighted mean accuracy on split ``mode`` ('valid' or 'test').

    Switches the model to eval mode (dropout off) and runs without autograd.
    Returns 0.0 for an empty split instead of dividing by zero.
    """
    dset.set_mode(mode)
    loader = DataLoader(
        dset, batch_size=100, shuffle=False, collate_fn=pad_collate
    )
    model.eval()
    total_acc = 0.0
    cnt = 0
    with torch.no_grad():
        for data in loader:
            contexts, questions, answers = _to_device(data)
            batch_size = contexts.size(0)
            _, acc = model.get_loss(contexts, questions, answers)
            total_acc += float(acc) * batch_size
            cnt += batch_size
    return total_acc / cnt if cnt else 0.0


if __name__ == '__main__':
    for run in range(10):
        for task_id in range(1, 21):
            dset = BabiDataset(task_id)
            vocab_size = len(dset.QA.VOCAB)
            hidden_size = 80
            model = DMNPlus(hidden_size, vocab_size, num_hop=3, qa=dset.QA)
            model.cuda()
            early_stopping_cnt = 0
            early_stopping_flag = False
            # A float, so the ': {5}.{4}' format spec below never receives an
            # int (which would raise ValueError).
            best_acc = 0.0
            # Fallback so the test phase has weights even if validation never improves.
            best_state = _snapshot(model)
            optim = torch.optim.Adam(model.parameters())
            for epoch in range(256):
                if early_stopping_flag:
                    print(f'[Run {run}, Task {task_id}] Early Stopping at Epoch {epoch}, Valid Accuracy : {best_acc: {5}.{4}}')
                    break
                _train_epoch(model, optim, dset, task_id, epoch)
                total_acc = _evaluate(model, dset, 'valid')
                if total_acc > best_acc:
                    best_acc = total_acc
                    best_state = _snapshot(model)
                    early_stopping_cnt = 0
                else:
                    early_stopping_cnt += 1
                    if early_stopping_cnt > 20:
                        early_stopping_flag = True
                print(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}')
                with open('log.txt', 'a') as fp:
                    fp.write(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}' + '\n')
                if total_acc == 1.0:
                    break
            # Restore the best validated weights once (not per test batch) and
            # evaluate in eval mode.
            model.load_state_dict(best_state)
            test_acc = _evaluate(model, dset, 'test')
            print(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc: {5}.{4}}')
            os.makedirs('models', exist_ok=True)
            with open(f'models/task{task_id}_epoch{epoch}_run{run}_acc{test_acc}.pth', 'wb') as fp:
                torch.save(model.state_dict(), fp)
            # Log the test accuracy (the previous code wrote the validation value here).
            with open('log.txt', 'a') as fp:
                fp.write(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc: {5}.{4}}' + '\n')