-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
82 lines (68 loc) · 2.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
import ruamel.yaml as yaml
import logging
from attrdict import AttrDict
from torch.utils.data import DataLoader, RandomSampler
# Custom Modules
from utils import weights_init, visualize_progress
from model import Generator, Discriminator, Discriminator_MD
from dataset import CustomDataset
from trainer import Trainer
if __name__ == "__main__":
    # Configure logging first: the root logger defaults to WARNING, which
    # would silently drop the INFO message about minibatch discrimination.
    logging.basicConfig(level=logging.INFO)

    # Load hyperparameters from YAML into an attribute-style dict.
    # NOTE: ruamel.yaml removed the module-level safe_load() in 0.18+;
    # the YAML(typ="safe") API works on all maintained versions.
    with open("config/hyperparameters.yaml") as f:
        cfg = AttrDict(yaml.YAML(typ="safe").load(f))

    # Prefer GPU when available.
    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset + "infinite" sampler: sampling with replacement and an
    # astronomically large num_samples means the dataloader never exhausts,
    # so the trainer controls how long training runs.
    dataset = CustomDataset(os.path.abspath(cfg.folder))
    train_sampler = RandomSampler(
        data_source=dataset,
        replacement=True,
        num_samples=int(1e100),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        sampler=train_sampler,  # a custom sampler is mutually exclusive with shuffle=True
        pin_memory=True,
        num_workers=4,
    )

    # Build the generator and the configured discriminator variant.
    G = Generator(z_dim=cfg.z_dim, n_feat=cfg.n_feat)
    if cfg.minibatch_discrimination:
        logging.info("Using minibatch discrimination")
        D = Discriminator_MD(batch_size=cfg.batch_size, n_feat=cfg.n_feat)
    else:
        D = Discriminator(n_feat=cfg.n_feat)

    # Apply custom weight initialization, then move models to the device.
    G.apply(weights_init)
    D.apply(weights_init)
    G = G.to(cfg.device)
    D = D.to(cfg.device)

    # Separate Adam optimizers for generator and discriminator, each with
    # its own learning rate, betas, and weight decay from the config.
    optimizerG = torch.optim.Adam(
        params=G.parameters(),
        lr=cfg.lr_G,
        betas=(cfg.beta1_G, cfg.beta2_G),
        weight_decay=cfg.weight_decay_G,
    )
    optimizerD = torch.optim.Adam(
        params=D.parameters(),
        lr=cfg.lr_D,
        betas=(cfg.beta1_D, cfg.beta2_D),
        weight_decay=cfg.weight_decay_D,
    )

    # Train, then assemble a GIF of the saved progress snapshots.
    trainer = Trainer(cfg, G, D, dataloader, optimizerG, optimizerD)
    trainer.train()
    visualize_progress(cfg.save_folder, cfg.max_epoch)