-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
106 lines (83 loc) · 4.06 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import warnings
from argparse import ArgumentParser
import pickle
import paddle
from utils.utils import get_config
from network import get_networks
from dataloader import load_data, load_localization_data
from loss_functions import MseDirectionLoss, DirectionOnlyLoss
from test_functions import detection_test, localization_test
# Command-line interface for the training script; parsed in main().
parser = ArgumentParser()
# Path to the YAML training configuration (read via utils.get_config).
parser.add_argument('--config', type=str, default='configs/config.yaml', help="training configuration")
# Root directory of the dataset; consumed by the dataloaders.
parser.add_argument('--dataset_root', type=str, default=None)
# Name of the in-distribution (normal) class to train on; copied into config.
parser.add_argument('--normal_class', type=str, default='capsule')
# Directory where model/optimizer checkpoints are written.
parser.add_argument('--save_dir', type=str, default='./output/')
# Silence library warnings (e.g. deprecation noise) for cleaner training logs.
warnings.filterwarnings('ignore')
def train(args, config):
    """Train the student network against the (frozen) VGG teacher.

    Runs ``num_epochs`` epochs of distillation training, evaluates
    detection/localization ROC-AUC every 10 epochs (and at the final
    epoch), and writes model/optimizer checkpoints under
    ``<args.save_dir>/<normal_class>/``.

    Args:
        args: Parsed CLI namespace; ``save_dir`` is used for checkpoints
            and the namespace is forwarded to the dataloaders.
        config: Dict loaded from the YAML config. Keys used here:
            ``direction_loss_only``, ``normal_class``, ``learning_rate``,
            ``num_epochs``, ``lamda``.
    """
    direction_loss_only = config["direction_loss_only"]
    normal_class = config["normal_class"]
    learning_rate = float(config['learning_rate'])
    num_epochs = config["num_epochs"]
    lamda = config['lamda']

    train_dataloader, test_dataloader = load_data(args, config)
    test_loc_dataloader, ground_truth = load_localization_data(args, config)
    vgg, model = get_networks(config)

    # Distillation loss: direction-only, or MSE + direction weighted by lamda.
    if direction_loss_only:
        criterion = DirectionOnlyLoss()
    else:
        criterion = MseDirectionLoss(lamda)

    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=learning_rate)

    # BUGFIX: checkpoints are saved under args.save_dir, so the directory must
    # be created there too. Previously this was hard-coded to "./output/",
    # which made paddle.save crash for any other --save_dir value.
    save_dir = os.path.join(args.save_dir, normal_class)
    os.makedirs(save_dir, exist_ok=True)

    losses = []
    best_detection_roc_auc = 0
    best_loc_roc_auc = 0

    for epoch in range(num_epochs + 1):
        model.train()
        epoch_loss = 0
        for data in train_dataloader:
            X = data[0]
            # Grayscale input: replicate the single channel to 3 so it matches
            # the VGG input. BUGFIX: Paddle tensors have no torch-style
            # .repeat(1, 3, 1, 1); the equivalent op is paddle.tile.
            if X.shape[1] == 1:
                X = paddle.tile(X, repeat_times=[1, 3, 1, 1])
            output_pred = model(X)
            output_real = vgg(X)
            total_loss = criterion(output_pred, output_real)
            # Track losses for logging.
            epoch_loss += total_loss.item()
            losses.append(total_loss.item())
            # Backprop and update the student; the VGG teacher is not updated.
            total_loss.backward()
            optimizer.step()
            model.clear_gradients()
        print('[Train] epoch [{}/{}], loss:{:.4f} class:{}'.format(epoch, num_epochs, epoch_loss, normal_class))

        # Evaluate every 10 epochs and at the final epoch; track the best
        # scores seen so far and checkpoint after each evaluation.
        if (epoch % 10 == 0 and epoch != 0) or epoch == num_epochs:
            detection_roc_auc = detection_test(model=model,
                                               vgg=vgg,
                                               test_dataloader=test_dataloader,
                                               config=config)
            localization_roc_auc = localization_test(model=model,
                                                     vgg=vgg,
                                                     test_dataloader=test_loc_dataloader,
                                                     ground_truth=ground_truth,
                                                     config=config)
            print(f"[Eval] {normal_class} class RocAUC detection: {detection_roc_auc} "
                  f"localization: {localization_roc_auc} at epoch {epoch}")
            if detection_roc_auc > best_detection_roc_auc:
                print(f"[Eval] best detection_roc_auc at epoch {epoch}")
                best_detection_roc_auc = detection_roc_auc
            if localization_roc_auc > best_loc_roc_auc:
                print(f"[Eval] best localization_roc_auc at epoch {epoch}")
                best_loc_roc_auc = localization_roc_auc
            # Per-epoch checkpoint (model weights + optimizer state).
            paddle.save(model.state_dict(), os.path.join(save_dir, f'model_{epoch}.pdparams'))
            paddle.save(optimizer.state_dict(), os.path.join(save_dir, f'model_{epoch}.pdopt'))

    # Final checkpoint after the last epoch.
    paddle.save(model.state_dict(), os.path.join(save_dir, 'final_model.pdparams'))
    paddle.save(optimizer.state_dict(), os.path.join(save_dir, 'final_model.pdopt'))
def main():
    """Entry point: parse CLI args, load the YAML config, and run training."""
    cli_args = parser.parse_args()
    cfg = get_config(cli_args.config)
    # The normal class chosen on the command line overrides the config value.
    cfg['normal_class'] = cli_args.normal_class
    train(cli_args, cfg)
# Standard script guard: run training only when executed directly.
if __name__ == '__main__':
    main()