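"""Training script: fits a real/fake classifier on samples from I2GDataset.

Each dataset item provides an image, a blending mask, and a real/fake label.
The model comes from models.create_model; every epoch runs one pass over the
training loader followed by a validation loop, with early stopping once the
validation metric stops improving. Checkpoints and logs are written under
<checkpoints_dir>/<name>/.
"""
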
import time
import logging
import os
from collections import OrderedDict

import torch
import torchvision.transforms as transforms  # only used by the commented-out debugging snippet below
from torch.utils.data import DataLoader

from options.train_options import TrainOptions
from models import create_model
from data.I2G_dataset import I2GDataset
from utils.visualizer import Visualizer
from utils import pidfile, util
import utils.logging


def train(opt):
    torch.manual_seed(opt.seed)

    dset = I2GDataset(opt, os.path.join(opt.real_im_path, 'train'))
    dset.get32frames()
    # halves batch size since each batch returns both real and fake ims
    dl = DataLoader(dset, batch_size=opt.batch_size,
                    num_workers=opt.nThreads, pin_memory=False,
                    shuffle=True)
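    # get32frames() is defined on I2GDataset; judging from its use here and
    # again at the end of each epoch below, it appears to (re)sample the
    # subset of frames the loader serves, so each epoch trains on a fresh draw.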

    # setup class labeling
    assert opt.fake_class_id in [0, 1]
    fake_label = opt.fake_class_id
    real_label = 1 - fake_label
    logging.info("real label = %d" % real_label)
    logging.info("fake label = %d" % fake_label)

    dataset_size = len(dset)
    logging.info('# total images = %d' % dataset_size)
    logging.info('# total batches = %d' % len(dl))

    # setup model and visualizer
    model = create_model(opt)
    epoch, best_val_metric, best_val_ep = model.setup(opt)
    visualizer_losses = model.loss_names + \
        [n + '_val' for n in model.loss_names]
    visualizer = Visualizer(opt, visualizer_losses, model.visual_names)

    total_batches = epoch * len(dl)
    t_data = 0
    now = time.strftime("%c")
    logging.info(
        '================ Training Loss (%s) ================\n' % now)
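
    # main loop: train over the loader, validate once per epoch, and stop on
    # one of the conditions checked below (no val improvement for
    # 5 * patience epochs, a perfect val metric, or max_epochs reached)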
    while True:
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        for i, ims in enumerate(dl):
            images = ims['img'].to(opt.gpu_ids[0])
            masks = ims['mask'].to(opt.gpu_ids[0])
            labels = ims['label'].to(opt.gpu_ids[0])
            batch_data = dict(ims=images, masks=masks, labels=labels)
            # for i in range(20):
            #     img_save = transforms.ToPILImage()(ims['img'][i]).convert("RGB")
            #     img_save.save('test_masks/face_{}.png'.format(i))
            iter_start_time = time.time()
            if total_batches % opt.print_freq == 0:
                # time to load data
                t_data = iter_start_time - iter_data_time
            total_batches += 1
            epoch_iter += 1

            model.reset()
            model.set_input(batch_data)
            model.optimize_parameters()

            if epoch_iter % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = time.time() - iter_start_time
                visualizer.print_current_losses(
                    epoch, float(epoch_iter) / len(dl), total_batches,
                    losses, t, t_data)
                visualizer.plot_current_losses(total_batches, losses)
            if epoch_iter % opt.save_latest_freq == 0:
                logging.info('saving the latest model (epoch %d, total_batches %d)' %
                             (epoch, total_batches))
                model.save_networks('latest', epoch, best_val_metric,
                                    best_val_ep)
            model.reset()
            iter_data_time = time.time()

        # do validation loop at end of each epoch
        model.eval()
        val_start_time = time.time()
        val_losses = validate(model, opt)
        visualizer.plot_current_losses(epoch, val_losses)
        logging.info("Printing validation losses:")
        visualizer.print_current_losses(
            epoch, 0.0, total_batches, val_losses,
            time.time() - val_start_time, 0.0)
        model.train()
        model.reset()
        assert model.net_D.training

        # update best model and determine stopping conditions
        if val_losses[model.val_metric + '_val'] > best_val_metric:
            logging.info("Updating best val model at ep %d" % epoch)
            logging.info("The previous values: ep %d, val %0.2f" %
                         (best_val_ep, best_val_metric))
            best_val_ep = epoch
            best_val_metric = val_losses[model.val_metric + '_val']
            logging.info("The updated values: ep %d, val %0.2f" %
                         (best_val_ep, best_val_metric))
            model.save_networks('bestval', epoch, best_val_metric, best_val_ep)
            with open(os.path.join(model.save_dir, 'bestval_ep.txt'), 'a') as f:
                f.write('ep: %d %s: %f\n' % (epoch, model.val_metric + '_val',
                                             best_val_metric))
        elif epoch > (best_val_ep + 5 * opt.patience):
            # early stopping: no val improvement for 5 * patience epochs
            logging.info("Current epoch %d, last updated val at ep %d" %
                         (epoch, best_val_ep))
            logging.info("Stopping training...")
            break
        elif best_val_metric == 1:
            logging.info("Reached perfect val accuracy metric")
            logging.info("Stopping training...")
            break
        elif opt.max_epochs and epoch > opt.max_epochs:
            logging.info("Reached max epoch count")
            logging.info("Stopping training...")
            break
logging.info("Best val ep: %d" % best_val_ep)
logging.info("Best val metric: %0.2f" % best_val_metric)
# save final plots at end of each epoch
visualizer.save_final_plots()
if epoch % opt.save_epoch_freq == 0 and epoch > 0:
logging.info('saving the model at the end of epoch %d, total batches %d' % (
epoch, total_batches))
model.save_networks('latest', epoch, best_val_metric,
best_val_ep)
model.save_networks(epoch, epoch, best_val_metric, best_val_ep)
logging.info('End of epoch %d \t Time Taken: %d sec' %
(epoch, time.time() - epoch_start_time))
model.update_learning_rate(
metric=val_losses[model.val_metric + '_val'])
epoch += 1
dset.get32frames()
dl = DataLoader(dset, batch_size=opt.batch_size,
num_workers=opt.nThreads, pin_memory=False,
shuffle=True)

    # save model at the end of training
    visualizer.save_final_plots()
    model.save_networks('latest', epoch, best_val_metric,
                        best_val_ep)
    model.save_networks(epoch, epoch, best_val_metric, best_val_ep)
    logging.info("Finished Training")


def validate(model, opt):
    # --- start evaluation loop ---
    logging.info('Starting evaluation loop ...')
    model.reset()
    assert not model.net_D.training

    # note: the val dataset and loader are rebuilt on every call
    val_dset = I2GDataset(opt, os.path.join(opt.real_im_path, 'val'),
                          is_val=True)
    val_dset.get32frames()
    val_dl = DataLoader(val_dset, batch_size=opt.batch_size,
                        num_workers=opt.nThreads, pin_memory=False,
                        shuffle=True)

    val_losses = OrderedDict([(k + '_val', util.AverageMeter())
                              for k in model.loss_names])

    for ims in val_dl:
        images = ims['img'].to(opt.gpu_ids[0])
        masks = ims['mask'].to(opt.gpu_ids[0])
        labels = ims['label'].to(opt.gpu_ids[0])
        inputs = dict(ims=images, masks=masks, labels=labels)

        # forward pass
        model.reset()
        model.set_input(inputs)
        model.test(True)
        losses = model.get_current_losses()

        # update running val losses, weighted by batch size
        for k, v in losses.items():
            val_losses[k + '_val'].update(v, n=len(inputs['labels']))

    # replace each meter with its average
    for k, v in val_losses.items():
        val_losses[k] = v.avg
    return val_losses
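
# For reference, util.AverageMeter is assumed to follow the usual
# running-average pattern implied by its usage above (an update(val, n)
# method and an .avg attribute), i.e. roughly:
#
#   class AverageMeter:
#       def __init__(self):
#           self.avg, self.sum, self.count = 0.0, 0.0, 0
#       def update(self, val, n=1):
#           self.sum += val * n
#           self.count += n
#           self.avg = self.sum / self.count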

if __name__ == '__main__':
    options = TrainOptions(print_opt=False)
    opt = options.parse()

    # lock active experiment directory and write out options
    os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)
    pidfile.exit_if_job_done(os.path.join(opt.checkpoints_dir, opt.name))
    options.print_options(opt)

    # configure logging file
    logging_file = os.path.join(opt.checkpoints_dir, opt.name, 'log.txt')
    utils.logging.configure(logging_file, append=False)

    # run train loop
    train(opt)

    # mark done and release lock
    pidfile.mark_job_done(os.path.join(opt.checkpoints_dir, opt.name))
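
# Example invocation (illustrative only: flag names are inferred from the
# opt attributes used above; the authoritative definitions live in
# options/train_options.py):
#
#   python train_I2G.py --name i2g_run --real_im_path /path/to/frames \
#       --checkpoints_dir checkpoints --batch_size 32 --gpu_ids 0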