import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import logging, AutoTokenizer, AutoModel
from config import get_config
from data import load_dataset
from model import (Transformer, Gru_Model, BiLstm_Model, Lstm_Model, Rnn_Model, TextCNN_Model,
                   Transformer_CNN_RNN, Transformer_Attention, Transformer_CNN_RNN_Attention)
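

# Wraps a pretrained encoder (BERT or RoBERTa) with one of several
# classification heads (FNN, RNN/GRU/LSTM/BiLSTM, TextCNN, attention, and
# their combinations) and runs the full train/evaluation loop.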
class Niubility:
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.logger.info('> creating model {}'.format(args.model_name))
        # Create the tokenizer and the pretrained base model
        if args.model_name == 'bert':
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.input_size = 768
            base_model = AutoModel.from_pretrained('bert-base-uncased')
        elif args.model_name == 'roberta':
            self.tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True)
            self.input_size = 768
            base_model = AutoModel.from_pretrained('roberta-base')
        else:
            raise ValueError('unknown model')
        # Select the classification head that sits on top of the base model
        if args.method_name == 'fnn':
            self.Mymodel = Transformer(base_model, args.num_classes, self.input_size)
        elif args.method_name == 'gru':
            self.Mymodel = Gru_Model(base_model, args.num_classes, self.input_size)
        elif args.method_name == 'lstm':
            self.Mymodel = Lstm_Model(base_model, args.num_classes, self.input_size)
        elif args.method_name == 'bilstm':
            self.Mymodel = BiLstm_Model(base_model, args.num_classes, self.input_size)
        elif args.method_name == 'rnn':
            self.Mymodel = Rnn_Model(base_model, args.num_classes, self.input_size)
        elif args.method_name == 'textcnn':
            self.Mymodel = TextCNN_Model(base_model, args.num_classes)
        elif args.method_name == 'attention':
            self.Mymodel = Transformer_Attention(base_model, args.num_classes)
        elif args.method_name == 'lstm+textcnn':
            self.Mymodel = Transformer_CNN_RNN(base_model, args.num_classes)
        elif args.method_name == 'lstm_textcnn_attention':
            self.Mymodel = Transformer_CNN_RNN_Attention(base_model, args.num_classes)
        else:
            raise ValueError('unknown method')
        self.Mymodel.to(args.device)
        if args.device.type == 'cuda':
            self.logger.info('> cuda memory allocated: {}'.format(torch.cuda.memory_allocated(args.device.index)))
        self._print_args()

    def _print_args(self):
        self.logger.info('> training arguments:')
        for arg in vars(self.args):
            self.logger.info(f">>> {arg}: {getattr(self.args, arg)}")
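
    # Run one optimization epoch over the training set: forward pass,
    # cross-entropy loss, backward pass, optimizer step. Returns the
    # sample-weighted mean loss and the accuracy.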
    def _train(self, dataloader, criterion, optimizer):
        train_loss, n_correct, n_train = 0, 0, 0
        # Switch the model to training mode
        self.Mymodel.train()
        for inputs, targets in tqdm(dataloader, disable=self.args.backend, ascii=' >='):
            inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
            targets = targets.to(self.args.device)
            predicts = self.Mymodel(inputs)
            loss = criterion(predicts, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * targets.size(0)
            n_correct += (torch.argmax(predicts, dim=1) == targets).sum().item()
            n_train += targets.size(0)
        return train_loss / n_train, n_correct / n_train
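
    # Evaluate on the test set without gradient tracking; returns the
    # sample-weighted mean loss and the accuracy.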
    def _test(self, dataloader, criterion):
        test_loss, n_correct, n_test = 0, 0, 0
        # Switch the model to evaluation mode
        self.Mymodel.eval()
        with torch.no_grad():
            for inputs, targets in tqdm(dataloader, disable=self.args.backend, ascii=' >='):
                inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
                targets = targets.to(self.args.device)
                predicts = self.Mymodel(inputs)
                loss = criterion(predicts, targets)
                test_loss += loss.item() * targets.size(0)
                n_correct += (torch.argmax(predicts, dim=1) == targets).sum().item()
                n_test += targets.size(0)
        return test_loss / n_test, n_correct / n_test
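
    # Full training loop: build the dataloaders, train for num_epoch epochs,
    # track the best test accuracy (ties broken by lower loss), and save the
    # accuracy and loss curves as images.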
    def run(self):
        # Optionally print the model's parameters:
        # for name, layer in self.Mymodel.named_parameters(recurse=True):
        #     print(name, layer.shape, sep=" ")
        train_dataloader, test_dataloader = load_dataset(tokenizer=self.tokenizer,
                                                         train_batch_size=self.args.train_batch_size,
                                                         test_batch_size=self.args.test_batch_size,
                                                         model_name=self.args.model_name,
                                                         method_name=self.args.method_name,
                                                         workers=self.args.workers)
        _params = filter(lambda x: x.requires_grad, self.Mymodel.parameters())
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
        l_acc, l_trloss, l_teloss, l_epo = [], [], [], []
        # Track the best test loss and accuracy across epochs
        best_loss, best_acc = float('inf'), 0
        for epoch in range(self.args.num_epoch):
            train_loss, train_acc = self._train(train_dataloader, criterion, optimizer)
            test_loss, test_acc = self._test(test_dataloader, criterion)
            l_epo.append(epoch)
            l_acc.append(test_acc)
            l_trloss.append(train_loss)
            l_teloss.append(test_loss)
            if test_acc > best_acc or (test_acc == best_acc and test_loss < best_loss):
                best_acc, best_loss = test_acc, test_loss
            self.logger.info(
                '{}/{} - {:.2f}%'.format(epoch + 1, self.args.num_epoch, 100 * (epoch + 1) / self.args.num_epoch))
            self.logger.info('[train] loss: {:.4f}, acc: {:.2f}'.format(train_loss, train_acc * 100))
            self.logger.info('[test] loss: {:.4f}, acc: {:.2f}'.format(test_loss, test_acc * 100))
        self.logger.info('best loss: {:.4f}, best acc: {:.2f}'.format(best_loss, best_acc * 100))
        self.logger.info('log saved: {}'.format(self.args.log_name))
        # Plot the training curves; clear the figure between plots so each
        # image contains only its own curve
        plt.plot(l_epo, l_acc)
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.savefig('acc.png')
        plt.clf()
        plt.plot(l_epo, l_teloss)
        plt.ylabel('test-loss')
        plt.xlabel('epoch')
        plt.savefig('teloss.png')
        plt.clf()
        plt.plot(l_epo, l_trloss)
        plt.ylabel('train-loss')
        plt.xlabel('epoch')
        plt.savefig('trloss.png')


if __name__ == '__main__':
    # Silence non-error warnings from transformers
    logging.set_verbosity_error()
    args, logger = get_config()
    nb = Niubility(args, logger)
    nb.run()
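
# Example invocation (hypothetical flag names; the actual CLI options are
# whatever config.get_config defines):
#   python main.py --model_name bert --method_name lstm_textcnn_attention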