-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path run_db_vae.py
63 lines (53 loc) · 1.82 KB
/
run_db_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Optional
import os
import random
import numpy as np
import torch
from DB_VAE.setup import args, DEVICE
from DB_VAE.evaluator import Evaluator
from DB_VAE.trainer import Trainer
from DB_VAE.logger import logger
# Anchor the working directory at the folder that holds this script, so any
# relative path opened from here on resolves against it rather than against
# wherever the interpreter was launched. This is module-level, so it also
# runs when the file is imported.
abspath = os.path.abspath(__file__)
dname = os.path.split(abspath)[0]
os.chdir(dname)
# Setting seeds for reproducability
def set_seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
def make_evaluator(args=args, device=DEVICE, trained_model: Optional = None):
    """Build an ``Evaluator`` bound to the given config and device.

    The returned object is meant to be driven with ``.eval``, or with
    ``.eval_on_setups`` for the automated run over several setups.

    Args:
        args: Parsed configuration; defaults to the ``args`` imported from
            ``DB_VAE.setup`` (bound once, when this function is defined).
        device: Device forwarded to the evaluator; defaults to ``DEVICE``
            (also bound at definition time).
        trained_model: Model to evaluate, forwarded as ``model``. When
            ``None``, presumably the Evaluator obtains a model on its own;
            NOTE(review): confirm in ``DB_VAE.evaluator``.

    Returns:
        A newly constructed ``Evaluator``.
    """
    evaluator_kwargs = {'args': args, 'device': device, 'model': trained_model}
    return Evaluator(**evaluator_kwargs)
if __name__ == "__main__":
    set_seed(seed=args.seed)

    # Modes that only need a Trainer: run_mode -> (log message, Trainer method).
    trainer_only_modes = {
        'train': ("Running training only", 'train'),
        'perturb': ("Running perturbation only", 'perturb'),
        'interpolate': ("Running interpolation only", 'interpolate'),
    }
    selected_mode = args.run_mode

    if selected_mode in trainer_only_modes:
        message, method_name = trainer_only_modes[selected_mode]
        logger.info(message)
        trainer = Trainer(args, DEVICE)
        getattr(trainer, method_name)()
    elif selected_mode == 'eval':
        logger.info("Running evaluation only")
        evaluator = make_evaluator()
        # NOTE(review): the literal string 'run_mode' (not args.run_mode) is
        # passed here and below; confirm this label is intended.
        evaluator.eval_on_setups('run_mode')
    else:
        # Every other value, including unrecognised ones, falls through to the
        # full pipeline: train, report best/worst, then evaluate this model.
        logger.info("Running training and evaluation of this model")
        trainer = Trainer(args, DEVICE)
        trainer.train()
        trainer.best_and_worst()
        evaluator = make_evaluator(trained_model=trainer.model)
        evaluator.eval_on_setups('run_mode')