-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
89 lines (78 loc) · 2.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import numpy as np
import torch
import torch.utils.data as Data
import torch.optim as optim
import matplotlib.pyplot as plt
from characterLoader import characterLoader
from nets.NewNet import NewNet
bs = 8
lr = 0.0001
epoch = 50
stepLength = 20
trainpath = "dataset/train_data"
valpath = "dataset/val_data"
traindata = characterLoader(trainpath)
valdata = characterLoader(valpath)
data_loader = Data.DataLoader(
dataset=traindata, # torch TensorDataset format
batch_size=bs, # mini batch size
shuffle=True, # 要不要打乱数据 (打乱比较好)
num_workers=2, # 多线程来读数据
)
val_loader = Data.DataLoader(
dataset=valdata,
batch_size=bs,
num_workers=2
)
net = NewNet()
optimizer = optim.Adam(net.parameters(), lr=lr) # optimize all cnn parameters
criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.NLLLoss()
# criterion = torch.nn.functional.nll_loss()
def train(e, data_loader):
sum_loss = 0
for step, (path, input, label) in enumerate(data_loader):
optimizer.zero_grad()
outputs = net.forward(input)
# print("outputs ", outputs.size())
# print("label ", label.size())
loss = criterion(outputs, label)
loss.backward()
optimizer.step()
sum_loss += loss.item()
print('train epoch %d loss:%.03f'
% (e, sum_loss))
return sum_loss
def validate(val_loader):
sum_loss = 0
for step, (path, input, label) in enumerate(val_loader):
outputs = net.forward(input)
loss = criterion(outputs, label)
sum_loss += loss.item()
print('validation epoch %d loss:%.03f'
% (e, sum_loss))
return sum_loss
if __name__ == "__main__":
val_loss = np.inf
trainloss = []
valloss = []
for e in range(epoch):
train_loss = train(e, data_loader)
trainloss.append(train_loss)
if e % 5 == 0:
loss = validate(val_loader)
valloss.append(loss)
if loss < val_loss:
val_loss = loss
torch.save(net.state_dict(), "model/negative_minLoss_model.pkl")
if e == epoch - 1:
torch.save(net.state_dict(), "model/negative_epoch_{}_model.pkl".format(e))
plt.figure(0)
x = [i for i in range(len(trainloss))]
plt.plot(x, trainloss)
plt.savefig("train_loss.jpg")
plt.close(0)
x = [i for i in range(len(valloss))]
plt.plot(x, valloss)
plt.savefig("val_loss.jpg")