# train_pytorch.py (forked from will-wiki/softmasked-bert)
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from config import Config
from pytorch_data import CSC_DataSet, collate_fn
from model.SoftMasked_Bert import SoftMasked_Bert

config = Config()

# SIGHAN15 Chinese Spelling Correction data; the vocab file
# (../pretrain_model/chinese_wwm_ext_pytorch/vocab.txt) is supplied
# through config.vocab_file.
train_dataset = CSC_DataSet('./data/narts/SIGHAN15_train.csv', config.vocab_file)
test_dataset = CSC_DataSet('./data/narts/SIGHAN15_test.csv', config.vocab_file)
train_generator = DataLoader(train_dataset, batch_size=config.batch_size, collate_fn=collate_fn)
test_generator = DataLoader(test_dataset, batch_size=config.batch_size, collate_fn=collate_fn)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SoftMasked_Bert(config).to(device)
optimizer = Adam(model.parameters(), lr=config.lr)

# NLLLoss for the correction network (log-probabilities over the vocab),
# BCELoss for the detection network (per-token error probability).
criterion_n, criterion_b = nn.NLLLoss(), nn.BCELoss()
gamma = 0.8  # weight of the correction loss; (1 - gamma) weights the detection loss
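
# Assumed batch layout from collate_fn, inferred from how the tensors are
# used below (not verified against pytorch_data.py):
#   input_ids / input_mask / segment_ids: (batch, seq_len) BERT inputs
#   output_ids: (batch, seq_len) target token ids of the corrected sentence
#   labels:     (batch, seq_len) 0/1 flags marking erroneous characters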
def dev(dataloader, model):
    """Report sentence-level detection (d_acc) and correction (c_acc)
    accuracy: a sentence only counts as correct if every token matches."""
    model.eval()
    d_correct, c_correct, total_element = 0, 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for batch_data in dataloader:
            (batch_input_ids, batch_input_mask, batch_segment_ids,
             batch_output_ids, batch_labels) = batch_data
            batch_input_ids = batch_input_ids.to(device)
            batch_input_mask = batch_input_mask.to(device)
            batch_segment_ids = batch_segment_ids.to(device)
            batch_output_ids = batch_output_ids.to(device)
            batch_labels = batch_labels.to(device)
            output, prob = model(batch_input_ids, batch_input_mask, batch_segment_ids)
            # Correction: argmax over the vocab must reproduce the target sequence.
            output = output.argmax(dim=-1)
            c_correct += sum(output[i].equal(batch_output_ids[i]) for i in range(len(output)))
            # Detection: error probabilities rounded to 0/1 must match the labels.
            prob = torch.round(prob).long()
            d_correct += sum(prob[i].squeeze().equal(batch_labels[i]) for i in range(len(prob)))
            # Count sentences; len(batch_data) would count the five tensors
            # in the batch tuple instead.
            total_element += len(batch_input_ids)
    print("d_acc=", d_correct / total_element, "c_acc=", c_correct / total_element)
for epoch in range(config.epoch):
    model.train()
    avg_loss, total_element = 0, 0
    d_correct, c_correct = 0, 0
    for i, batch_data in enumerate(train_generator):
        (batch_input_ids, batch_input_mask, batch_segment_ids,
         batch_output_ids, batch_labels) = batch_data
        batch_input_ids = batch_input_ids.to(device)
        batch_input_mask = batch_input_mask.to(device)
        batch_segment_ids = batch_segment_ids.to(device)
        batch_output_ids = batch_output_ids.to(device)
        batch_labels = batch_labels.to(device)
        output, prob = model(batch_input_ids, batch_input_mask, batch_segment_ids)

        # Joint objective of Soft-Masked BERT:
        # loss = gamma * L_correction + (1 - gamma) * L_detection.
        loss_b = criterion_b(prob, batch_labels.float())
        loss_n = criterion_n(output.reshape(-1, output.size(-1)),
                             batch_output_ids.reshape(-1))
        loss = gamma * loss_n + (1 - gamma) * loss_b

        optimizer.zero_grad()
        loss.backward()  # retain_graph is unnecessary; each step builds a fresh graph
        optimizer.step()

        # Sentence-level accuracies, computed the same way as in dev().
        output = output.argmax(dim=-1)
        c_correct += sum(output[i].equal(batch_output_ids[i]) for i in range(len(output)))
        prob = torch.round(prob).long()
        d_correct += sum(prob[i].squeeze().equal(batch_labels[i]) for i in range(len(prob)))
        avg_loss += loss.item()
        total_element += len(batch_input_ids)  # sentences, not tuple length

        if i % 100 == 0:  # progress print; the interval is an arbitrary choice
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "d_acc": d_correct / total_element,
                "c_acc": c_correct / total_element
            }
            print(post_fix)

    print("EP%d" % epoch, "avg_loss=", avg_loss / len(train_generator),
          "d_acc=", d_correct / total_element, "c_acc=", c_correct / total_element)
    # Evaluate on the held-out split after every epoch.
    dev(test_generator, model)
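
# Not in the original script: a minimal checkpoint save after training
# (the filename is an assumption).
torch.save(model.state_dict(), './softmasked_bert.pt')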