# hubconf.py -- forked from codertimo/BERT-pytorch
import argparse
import random

import torch
import numpy as np

from bert_pytorch import parse_args
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import BERTDataset, WordVocab
from bert_pytorch.model import BERT
from torch.utils.data import DataLoader

# Fixed seeds and deterministic cuDNN so repeated runs are reproducible.
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def skipIfNotImplemented(func):
    """Run ``func``, but print a skip message if it raises NotImplementedError."""
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except NotImplementedError:
            print('skipped since {} is not implemented'.format(func.__name__))
    return wrapper


class Model:
    """Wraps the BERT-pytorch pre-training pipeline (vocab, datasets, dataloaders,
    BERTTrainer) and exposes get_module()/train()/eval() for a given device/JIT setting.
    """

    def __init__(self, device=None, jit=False):
        self.device = device
        self.jit = jit
        args = parse_args(args=[
            '--train_dataset', 'data/corpus.small',
            '--test_dataset', 'data/corpus.small',
            '--vocab_path', 'data/vocab.small',
            '--output_path', 'bert.model',
        ])  # Avoid reading sys.argv here
        args.with_cuda = self.device == 'cuda'
        args.script = self.jit

        print("Loading Vocab", args.vocab_path)
        vocab = WordVocab.load_vocab(args.vocab_path)
        print("Vocab Size: ", len(vocab))

        train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                    corpus_lines=args.corpus_lines, on_memory=args.on_memory)
        test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
            if args.test_dataset is not None else None

        print("Creating Dataloader")
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
            if test_dataset is not None else None

        print("Building BERT model")
        bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
        if args.script:
            print("Scripting BERT model")
            bert = torch.jit.script(bert)

        self.trainer = BERTTrainer(bert, len(vocab),
                                   train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                                   lr=args.lr, betas=(args.adam_beta1, args.adam_beta2),
                                   weight_decay=args.adam_weight_decay,
                                   with_cuda=args.with_cuda, cuda_devices=args.cuda_devices,
                                   log_freq=args.log_freq, debug=args.debug)

        # One batch from the training loader supplies the example inputs
        # (token ids and segment labels) expected by the model's forward pass.
        example_batch = next(iter(train_data_loader))
        self.example_inputs = (example_batch['bert_input'].to(self.device),
                               example_batch['segment_label'].to(self.device))

    def get_module(self):
        return self.trainer.model, self.example_inputs

    @skipIfNotImplemented
    def eval(self, niter=1):
        trainer = self.trainer
        _, data = next(enumerate(trainer.test_data))
        for _ in range(niter):
            data = {key: value.to(trainer.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = trainer.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL (negative log likelihood) loss of the is_next classification result
            # 2-2. NLLLoss of predicting the masked token words
            # 2-3. Add next_loss and mask_loss: see "3.4 Pre-training Procedure"
            next_loss = trainer.criterion(next_sent_output, data["is_next"])
            mask_loss = trainer.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
            loss = next_loss + mask_loss

    @skipIfNotImplemented
    def train(self, niter=1):
        trainer = self.trainer
        _, data = next(enumerate(trainer.train_data))
        for _ in range(niter):
            data = {key: value.to(trainer.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = trainer.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL (negative log likelihood) loss of the is_next classification result
            # 2-2. NLLLoss of predicting the masked token words
            # 2-3. Add next_loss and mask_loss: see "3.4 Pre-training Procedure"
            next_loss = trainer.criterion(next_sent_output, data["is_next"])
            mask_loss = trainer.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
            loss = next_loss + mask_loss

            # 3. backward and optimization (train only)
            trainer.optim_schedule.zero_grad()
            loss.backward()
            trainer.optim_schedule.step_and_update_lr()


if __name__ == '__main__':
    for device in ['cpu', 'cuda']:
        for jit in [True, False]:
            print("Testing device {}, JIT {}".format(device, jit))
            m = Model(device=device, jit=jit)
            bert, example_inputs = m.get_module()
            bert(*example_inputs)
            m.train()
            m.eval()
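
# A minimal usage sketch (comments only), mirroring the smoke test above: an
# external harness could drive the wrapper roughly like this, assuming the
# data/corpus.small and data/vocab.small files referenced in Model.__init__ exist:
#
#     m = Model(device='cpu', jit=False)
#     module, example_inputs = m.get_module()
#     module(*example_inputs)   # one forward pass (next-sentence + masked-LM outputs)
#     m.train(niter=1)          # one optimizer step through BERTTrainer
#     m.eval(niter=1)           # one loss-only pass over a test batch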