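# train.py
#
# Trains an SNLIClassifier (from model.py) on the SNLI corpus, using torchnlp
# (pytorch-nlp) for dataset loading, text/label encoding, pretrained word
# vectors, and bucketed batching.
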
from functools import partial
import glob
import itertools
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

from torchnlp import word_to_vector
from torchnlp.datasets import snli_dataset
from torchnlp.encoders import LabelEncoder
from torchnlp.encoders.text import WhitespaceEncoder
from torchnlp.samplers import BucketBatchSampler

from model import SNLIClassifier
from util import get_args, makedirs, collate_fn

args = get_args()

if args.gpu >= 0:
    torch.cuda.set_device(args.gpu)

# load dataset
train, dev, test = snli_dataset(train=True, dev=True, test=True)
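
# Each split behaves as a list of dicts; as used below, every row carries
# 'premise', 'hypothesis', and 'label' string fields.
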
# Preprocess
for row in itertools.chain(train, dev, test):
    row['premise'] = row['premise'].lower()
    row['hypothesis'] = row['hypothesis'].lower()
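
# Lowercasing keeps the corpus consistent with lowercase pretrained embedding
# vocabularies (e.g. the GloVe '6B' vectors, assuming those are what
# --word_vectors selects).
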
# Make Encoders
sentence_corpus = [row['premise'] for row in itertools.chain(train, dev, test)]
sentence_corpus += [row['hypothesis'] for row in itertools.chain(train, dev, test)]
sentence_encoder = WhitespaceEncoder(sentence_corpus)

label_corpus = [row['label'] for row in itertools.chain(train, dev, test)]
label_encoder = LabelEncoder(label_corpus)
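
# WhitespaceEncoder builds its vocabulary by splitting each sentence on
# whitespace; encode() then maps a string to a LongTensor of token indices,
# e.g. (indices illustrative only):
#
#   sentence_encoder.encode('a man is sleeping')  # -> tensor([5, 23, 11, 87])
#
# LabelEncoder does the same for the label strings.
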
# Encode
for row in itertools.chain(train, dev, test):
    row['premise'] = sentence_encoder.encode(row['premise'])
    row['hypothesis'] = sentence_encoder.encode(row['hypothesis'])
    row['label'] = label_encoder.encode(row['label'])
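
# After this pass each field holds an encoded tensor rather than a string, so
# collate_fn (from util) presumably only needs to pad and stack rows into
# batches.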

config = args
config.n_embed = sentence_encoder.vocab_size
config.d_out = label_encoder.vocab_size
config.n_cells = config.n_layers

# double the number of cells for bidirectional networks
if config.birnn:
    config.n_cells *= 2
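
# n_cells sizes the recurrent state: one LSTM cell per layer, times two when
# the encoder runs in both directions.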

if args.resume_snapshot:
    model = torch.load(
        args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
else:
    model = SNLIClassifier(config)
    if args.word_vectors:
        # load pretrained word vectors into the embedding table
        word_vectors = word_to_vector.aliases[args.word_vectors]()
        for i, token in enumerate(sentence_encoder.vocab):
            model.embed.weight.data[i] = word_vectors[token]

if args.gpu >= 0:
    model.cuda()
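
# Note: word_to_vector aliases (e.g. 'glove.6B.300d') download the pretrained
# vectors on first use; with torchnlp's defaults, tokens absent from the
# pretrained vocabulary are assumed to come back as zero vectors.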

criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=args.lr)
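
# CrossEntropyLoss expects unnormalized logits, so the model output used as
# `answer` below should be raw scores rather than softmax probabilities.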

iterations = 0
start = time.time()
best_dev_acc = -1
header = '  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss     Accuracy  Dev/Accuracy'
dev_log_template = ' '.join(
    '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'
    .split(','))
log_template = ' '.join(
    '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
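
# The templates above are comma-separated format specs rejoined with spaces so
# that the printed columns line up under `header`.
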
makedirs(args.save_path)
print(header)

for epoch in range(args.epochs):
    n_correct, n_total = 0, 0
    train_sampler = SequentialSampler(train)
    train_batch_sampler = BucketBatchSampler(
        train_sampler, args.batch_size, True, sort_key=lambda r: len(train[r]['premise']))
    train_iterator = DataLoader(
        train,
        batch_sampler=train_batch_sampler,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available(),
        num_workers=0)
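
    # BucketBatchSampler groups row indices with similar premise lengths to cut
    # padding; its third positional argument is drop_last, so incomplete
    # batches are dropped.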

    for batch_idx, (premise_batch, hypothesis_batch, label_batch) in enumerate(train_iterator):

        # switch model to training mode, clear gradient accumulators
        model.train()
        torch.set_grad_enabled(True)
        opt.zero_grad()

        iterations += 1

        # forward pass
        answer = model(premise_batch, hypothesis_batch)

        # calculate accuracy of predictions in the current batch
        n_correct += (torch.max(answer, 1)[1].view(label_batch.size()) == label_batch).sum()
        n_total += premise_batch.size()[1]
        train_acc = 100. * n_correct / n_total

        # calculate loss of the network output with respect to the training labels
        loss = criterion(answer, label_batch)

        # backpropagate and take an optimizer step
        loss.backward()
        opt.step()
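
        # premise_batch is sequence-major ([seq_len, batch]) here, so size()[1]
        # above is the batch size; this assumes collate_fn does not produce
        # batch-first tensors.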

        # checkpoint model periodically
        if iterations % args.save_every == 0:
            snapshot_prefix = os.path.join(args.save_path, 'snapshot')
            snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(
                train_acc, loss.item(), iterations)
            torch.save(model, snapshot_path)
            for f in glob.glob(snapshot_prefix + '*'):
                if f != snapshot_path:
                    os.remove(f)
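
            # torch.save pickles the entire model object, so loading a snapshot
            # later requires model.py to be importable (the resume_snapshot
            # branch above relies on the same behavior).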

        # evaluate performance on validation set periodically
        if iterations % args.dev_every == 0:

            # switch model to evaluation mode
            model.eval()
            torch.set_grad_enabled(False)

            # calculate accuracy on validation set
            n_dev_correct, dev_loss = 0, 0
            dev_sampler = SequentialSampler(dev)
            dev_batch_sampler = BucketBatchSampler(
                dev_sampler, args.batch_size, True, sort_key=lambda r: len(dev[r]['premise']))
            dev_iterator = DataLoader(
                dev,
                batch_sampler=dev_batch_sampler,
                collate_fn=partial(collate_fn, train=False),
                pin_memory=torch.cuda.is_available(),
                num_workers=0)
            for dev_batch_idx, (premise_batch, hypothesis_batch,
                                label_batch) in enumerate(dev_iterator):
                answer = model(premise_batch, hypothesis_batch)
                n_dev_correct += (torch.max(
                    answer, 1)[1].view(label_batch.size()) == label_batch).sum()
                dev_loss = criterion(answer, label_batch)
            dev_acc = 100. * n_dev_correct / len(dev)

            print(
                dev_log_template.format(time.time() - start, epoch, iterations, 1 + batch_idx,
                                        len(train_sampler),
                                        100. * (1 + batch_idx) / len(train_sampler), loss.item(),
                                        dev_loss.item(), train_acc, dev_acc))
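
            # dev_loss above holds only the final validation batch's loss;
            # dev_acc, by contrast, is computed over the entire dev set.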

            # update best validation set accuracy
            if dev_acc > best_dev_acc:

                # found a model with a better validation set accuracy
                best_dev_acc = dev_acc
                snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
                snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(
                    dev_acc, dev_loss.item(), iterations)

                # save model, delete previous 'best_snapshot' files
                torch.save(model, snapshot_path)
                for f in glob.glob(snapshot_prefix + '*'):
                    if f != snapshot_path:
                        os.remove(f)

        elif iterations % args.log_every == 0:

            # print progress message
            print(
                log_template.format(time.time() - start, epoch, iterations, 1 + batch_idx,
                                    len(train_sampler), 100. * (1 + batch_idx) / len(train_sampler),
                                    loss.item(), ' ' * 8, n_correct / n_total * 100, ' ' * 12))