-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_resnet.py
134 lines (109 loc) · 4.59 KB
/
train_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from framework_pytorch.core import Trainer
from framework_pytorch.utils import loss_fuctions as LF, process_methods as P, utils as U
from framework_pytorch.utils.data_augmentation import RandomForegroundCrop
from customize.cmr_dataset import CMRDataset
from customize.timm_models import ResNet
from customize.model_cls import ClsModel
from customize import _tools as T
# ---- training parameters ----
epochs = 10
learning_rate = 1e-4
# Training and evaluation share the same batch size.
train_batch_size = eval_batch_size = 10
slice_ratio = None      # forwarded to T.fill_imgs when building the file lists
save_frequency = None   # forwarded to Trainer.train
log_train_image, log_validation_image = False, True
# Loss functions passed to ClsModel; presumably maps loss -> weight (confirm):
# plain cross-entropy with weight 1.0.
loss_function = {LF.CrossEntropy(): 1.0}
# ---- model settings ----
# Passed to ResNet as `pretrained` / `freeze`.
pretrained, freeze = True, False
# NOTE(review): drop_rate is not passed to the model anywhere in this script.
drop_rate = 0.5
# Number of input image channels given to ResNet.
input_channels = 1
# ---- data settings ----
crop_size = [224, 224]
foreground_ratio = 0.8  # passed to RandomForegroundCrop
# NOTE(review): test_num_samples is not referenced in this script.
test_num_samples = 50
# True: keep one sample list per class; False: pool all classes into one list.
balance = True
# True: magnitude images (P.Magnitude); False: the raw ('org') image.
mag = True
mag_name = 'mag' if mag else 'org'

# ---- data and results paths ----
data_path = './data_cls'
# Per-fold results are written under ./results/resnet_<mag_name>/fold<k>/.
# The './results/' prefix must be applied exactly once (it used to be
# prepended a second time, giving './results/./results/resnet_...').
root_output_path = f'./results/resnet_{mag_name}'
# The split file holds a pickled object; .item() unwraps the 0-d object
# array returned by np.load (presumably a dict of fold/class sample lists).
# allow_pickle=True executes pickle data: only load trusted, local files.
data_list = np.load(f'{data_path}/5fold_random.npy', allow_pickle=True).item()
# Number of target classes: one sample list per class and one output per class.
num_classes = 3

# Use the GPU when available, otherwise the CPU. This does not depend on the
# fold, so it is computed once rather than on every iteration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def collate_fn(batch):
    """Merge a list of per-sample dicts into one dict of tensors.

    Samples are accumulated with the project helper ``U.dict_concat``
    (starting from ``None``); every merged value is then converted with
    ``torch.tensor`` in place, and the merged dict is returned.

    Raises:
        ValueError: if ``batch`` is empty (nothing to merge).
    """
    batch_dict = None
    for sample in batch:
        batch_dict = U.dict_concat(batch_dict, sample)
    if batch_dict is None:
        raise ValueError('collate_fn received an empty batch')
    # Overwrite values in place so the container type from dict_concat is kept.
    for key in batch_dict:
        batch_dict[key] = torch.tensor(batch_dict[key])
    return batch_dict


# 5-fold cross-validation: fold `ifold` is held out for validation and the
# remaining folds are used for training. A fresh network, optimizer and model
# are built on every iteration, so no weights carry over between folds.
for ifold in range(1, 6):
    output_path = root_output_path + f'/fold{ifold}/'

    # Per-class sample lists of the held-out fold (keys 'fold_<k>_c<class>').
    valid_1 = data_list[f'fold_{ifold}_c1']
    valid_2 = data_list[f'fold_{ifold}_c2']
    valid_3 = data_list[f'fold_{ifold}_c3']
    # Concatenate the other folds class by class, in fold order.
    train_1, train_2, train_3 = [], [], []
    for i in range(1, 6):
        if i == ifold:
            continue
        train_1 += data_list[f'fold_{i}_c1']
        train_2 += data_list[f'fold_{i}_c2']
        train_3 += data_list[f'fold_{i}_c3']

    # Expand each list into image files matching `suffix` via the project
    # helper T.fill_imgs (call order kept: train classes, then valid classes).
    suffix = '*_pdorg.tif'
    train_1 = T.fill_imgs(train_1, slice_ratio, suffix)
    train_2 = T.fill_imgs(train_2, slice_ratio, suffix)
    train_3 = T.fill_imgs(train_3, slice_ratio, suffix)
    valid_1 = T.fill_imgs(valid_1, slice_ratio, suffix)
    valid_2 = T.fill_imgs(valid_2, slice_ratio, suffix)
    valid_3 = T.fill_imgs(valid_3, slice_ratio, suffix)
    if not balance:
        # Without class balancing, pool every class into the first list and
        # pass None for the other two.
        train_1 = train_1 + train_2 + train_3
        train_2 = None
        train_3 = None
        valid_1 = valid_1 + valid_2 + valid_3
        valid_2 = None
        valid_3 = None

    img_suffix = 'pdorg.tif'
    lab_suffix = 'org_lab'
    seg_suffix = 'pdorg_lab.png'
    # Pre-processing per data key:
    # - label: one-hot float32 array of shape (1, num_classes); assumes the
    #   stored label is a 1-based class id (TODO confirm).
    # - segmentation: replaced by the placeholder [0]. NOTE(review): presumably
    #   RandomForegroundCrop consumes the mask before this runs — confirm.
    pre = {lab_suffix: [lambda x: np.expand_dims(np.array([i == x - 1 for i in range(num_classes)], np.float32), 0)],
           seg_suffix: [lambda x: [0]]}
    if mag:
        pre.update({img_suffix: [P.Magnitude(), P.Transpose([2, 0, 1]), P.ExpandDim(0)]})
    else:
        pre.update({img_suffix: [P.ExpandDim(-1), P.Transpose([2, 0, 1]), P.ExpandDim(0)]})
    aug = RandomForegroundCrop(img_suffix, seg_suffix, tar_size=crop_size, foreground_ratio=foreground_ratio)

    # Build the datasets (see CMRDataset). Shuffling is requested from the
    # training dataset itself (shuffle=True); the DataLoaders do not shuffle.
    # NOTE(review): the validation set also goes through the random crop `aug`,
    # so validation metrics may vary between runs — confirm this is intended.
    train_set = CMRDataset(train_1, train_2, train_3, [img_suffix, lab_suffix], pre, aug, seg_suffix=seg_suffix, shuffle=True)
    valid_set = CMRDataset(valid_1, valid_2, valid_3, [img_suffix, lab_suffix], pre, aug, seg_suffix=seg_suffix)
    trainloader = DataLoader(train_set, batch_size=train_batch_size, collate_fn=collate_fn)
    validloader = DataLoader(valid_set, batch_size=eval_batch_size, collate_fn=collate_fn)

    # Graph logging input is left as None. To log the graph, pass e.g.
    # torch.tensor(train_set[0][img_suffix]).to(device) instead.
    random_img = None

    # NOTE(review): drop_rate from the model settings is not passed to ResNet;
    # confirm whether the wrapper accepts it.
    net = ResNet(num_classes=num_classes, pretrained=pretrained, input_channels=input_channels, freeze=freeze)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    model = ClsModel(net, optimizer, device, img_suffix, lab_suffix, loss_functions=loss_function)

    trainer = Trainer(model)
    trainer.train(trainloader, validloader,
                  epochs=epochs,
                  output_path=output_path,
                  log_train_image=log_train_image,
                  log_validation_image=log_validation_image,
                  log_graph_input=random_img,
                  save_frequency=save_frequency)