forked from enoonIT/DANN
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnew_main.py
89 lines (76 loc) · 3.11 KB
/
new_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import random
import time
import torch.backends.cudnn as cudnn
import torch.utils.data
from dataset.data_loader import get_dataloader
from logger import Logger
from models.model import get_net
from test import test
from train.optim import get_optimizer_and_scheduler
from train.utils import get_name, get_folder_name, ensure_dir, train_epoch, get_args, do_pretraining, simple_tuned
# Training driver for multi-source DANN-style domain adaptation: parse CLI
# args, seed the RNGs, build source/target dataloaders and the network, then
# train for args.epochs epochs, logging accuracies and checkpointing each epoch.
args = get_args()
print(args)

# A fresh random seed is drawn per run and baked into the run name.
manual_seed = random.randint(1, 1000)
run_name = get_name(args, manual_seed)
print("Working on " + run_name)

# Throwaway runs (--tmp_log) log under /tmp instead of logs/.
log_folder = "/tmp/" if args.tmp_log else "logs/"
folder_name = get_folder_name(args.source, args.target)
logger = Logger("{}/{}/{}".format(log_folder, folder_name, run_name))

model_root = 'models'
cuda = True
cudnn.benchmark = True

lr = args.lr
batch_size = args.batch_size
image_size = args.image_size
# Bigger images get a smaller eval batch — presumably to fit GPU memory.
test_batch_size = 256 if image_size > 100 else 1000
n_epoch = args.epochs
dann_weight = args.DANN_weight
entropy_weight = args.entropy_loss_weight
source_dataset_names = args.source
target_dataset_name = args.target

random.seed(manual_seed)
torch.manual_seed(manual_seed)

# One domain class per source dataset plus one more (presumably the target).
args.domain_classes = 1 + len(args.source)

dataloader_source = get_dataloader(args.source, batch_size, image_size, args.data_aug_mode, args.source_limit)
dataloader_target = get_dataloader(args.target, batch_size, image_size, args.data_aug_mode, args.target_limit)
print("Len source %d, len target %d" % (len(dataloader_source), len(dataloader_target)))

# load model
my_net = get_net(args)

# setup optimizer
# NOTE(review): the optimizer is created before the .cuda() move below; torch
# Module.cuda() normally rebinds parameter storage in place so the optimizer
# keeps tracking the same Parameter objects — confirm against the pinned torch.
optimizer, scheduler = get_optimizer_and_scheduler(args.optimizer, my_net, args.epochs, args.lr,
                                                   args.keep_pretrained_fixed)
if cuda:
    my_net = my_net.cuda()

# Optional DECO warm-up phase before the main adversarial training.
if args.deco_pretrain > 0:
    do_pretraining(args.deco_pretrain, dataloader_source, dataloader_target, my_net, logger)

start = time.time()

# training
# tune_stats is only enabled for the simple_tuned augmentation mode;
# presumably it re-tunes normalization statistics at test time — verify in test().
tune_stats = args.data_aug_mode == simple_tuned

for epoch in range(n_epoch):
    # NOTE(review): stepping the scheduler at the top of each epoch is the
    # pre-1.1 PyTorch convention, and get_lr() is deprecated in newer releases
    # in favor of get_last_lr() — confirm against the pinned torch version.
    scheduler.step()
    logger.scalar_summary("aux/lr", scheduler.get_lr()[0], epoch)
    train_epoch(epoch, dataloader_source, dataloader_target, optimizer, my_net, logger, n_epoch, cuda, dann_weight,
                entropy_weight, scheduler, args.generalization)

    # Evaluate on each source domain, then on the target domain.
    my_net.set_deco_mode("source")
    for d, source in enumerate(source_dataset_names):
        s_acc = test(source, epoch, my_net, image_size, d, test_batch_size, limit=args.source_limit, tune_stats=tune_stats)
        # Single-source runs use a plain tag; multi-source runs tag per dataset.
        source_name = "acc/source" if len(source_dataset_names) == 1 else "acc/source_%s" % source
        logger.scalar_summary(source_name, s_acc, epoch)

    my_net.set_deco_mode("target")
    t_acc = test(target_dataset_name, epoch, my_net, image_size, len(args.source), test_batch_size, limit=args.target_limit, tune_stats=tune_stats)
    logger.scalar_summary("acc/target", t_acc, epoch)

    # Checkpoint every epoch. NOTE(review): this pickles the entire module
    # object rather than a state_dict, which ties checkpoints to the exact
    # class/module layout — intentional here, but fragile across refactors.
    save_path = '{}/{}/{}_{}.pth'.format(model_root, folder_name, run_name, epoch)
    print("Network saved to {}".format(save_path))
    ensure_dir(save_path)
    torch.save(my_net, save_path)

print('done, it took %g' % (time.time() - start))