-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunexpwb.py
52 lines (41 loc) · 1.22 KB
/
runexpwb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
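'''
Run an experiment with wandb logging.

Usage:
    python runexpwb.py --setting bag

Available settings: bag, tfbind8, tfbind10, qm9str, sehstr.

Note: wandb does not work when an experiment script is run as a module from
a subdirectory (e.g., `python -m exps.chess.chessgfn`), so wandb.init is
called here, in this top-level script, instead.
'''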
import torch
import wandb
import options
from attrdict import AttrDict
from exps.bag import bag
from exps.tfbind8 import tfbind8_oracle
from exps.tfbind10 import tfbind10
from exps.qm9str import qm9str
from exps.sehstr import sehstr
# Maps each ``--setting`` value to a callable that runs that experiment.
# Every entry is a small lambda, so the experiment module's ``main`` is only
# looked up when that setting is actually run, not when this table is built.
setting_calls = dict(
    bag=lambda args: bag.main(args),
    tfbind8=lambda args: tfbind8_oracle.main(args),
    tfbind10=lambda args: tfbind10.main(args),
    qm9str=lambda args: qm9str.main(args),
    sehstr=lambda args: sehstr.main(args),
)
def main(args):
    """Run the experiment selected by ``args.setting``.

    Args:
        args: Attribute-style config object (in this script, an AttrDict
            built from ``wandb.config``). Only ``args.setting`` is read here;
            the whole object is forwarded unchanged to the experiment.

    Raises:
        KeyError: If ``args.setting`` is not a key of ``setting_calls``. The
            message lists the valid settings.
    """
    print(f'Using {args.setting=} ...')
    try:
        exp_f = setting_calls[args.setting]
    except KeyError:
        # Re-raise with the valid choices so a typo in --setting is easy to
        # diagnose. Keep KeyError so any existing handler still matches.
        valid = ', '.join(sorted(setting_calls))
        raise KeyError(
            f'Unknown setting {args.setting!r}; expected one of: {valid}'
        ) from None
    exp_f(args)
if __name__ == '__main__':
    # Parse CLI options, then hand them to wandb so they are recorded as the
    # run config. The experiment then reads its settings back from
    # wandb.config (wrapped in AttrDict for attribute access) rather than
    # from the parsed options. Presumably this is so values injected by wandb
    # (e.g. from a sweep) take effect -- TODO confirm.
    args = options.parse_args()
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        config=args,
        mode=args.wandb_mode,
    )
    args = AttrDict(wandb.config)

    # Use the literal string 'None' when wandb reports no (or an empty) run name.
    args.run_name = wandb.run.name or 'None'

    # Prefer the GPU whenever torch can see one.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f'{device=}')
    args.device = device

    main(args)