-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
59 lines (51 loc) · 1.36 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
import random
import torch
import utils.multiprocessing as mpu
from config import get_cfg
from train import train
from test import test
def _launch_job(cfg, func):
    """
    Run a job function (``train`` or ``test``) according to the GPU setup in ``cfg``.

    With more than one GPU configured, the job is started through
    ``torch.multiprocessing.spawn`` with ``mpu.run`` as the per-process entry
    point, one process per GPU. Otherwise ``func`` is called directly in the
    current process.

    Args:
        cfg: the config object returned by ``get_cfg()``. This helper reads
            ``NUM_GPUS``, ``DIST_INIT_METHOD``, ``SHARD_ID``, ``NUM_SHARDS`` and
            ``DIST_BACKEND`` from it, and forwards ``cfg`` itself to the job.
        func: the job entry point. In the single-process case it is called as
            ``func(cfg=cfg)``. In the multi-process case it is passed
            positionally to ``mpu.run``.
    """
    if cfg.NUM_GPUS > 1:
        # spawn() calls mpu.run(process_index, *args) in each of the
        # NUM_GPUS child processes; the argument order below is the one
        # mpu.run expects and must be kept identical for train and test.
        torch.multiprocessing.spawn(
            mpu.run,
            nprocs=cfg.NUM_GPUS,
            args=(
                cfg.NUM_GPUS,
                func,
                cfg.DIST_INIT_METHOD,
                cfg.SHARD_ID,
                cfg.NUM_SHARDS,
                cfg.DIST_BACKEND,
                cfg,
            ),
            # Non-daemonic children: spawn() waits for them to finish.
            daemon=False,
        )
    else:
        func(cfg=cfg)


def main():
    """
    Main function to spawn the train and test process.

    Loads the config via ``get_cfg()``. It then runs training if
    ``cfg.TRAIN.ENABLE`` is set, followed by testing if ``cfg.TEST.ENABLE`` is
    set. Each phase uses the same single-/multi-GPU launch logic.
    """
    cfg = get_cfg()
    if cfg.TRAIN.ENABLE:
        _launch_job(cfg, train)
    if cfg.TEST.ENABLE:
        _launch_job(cfg, test)
if __name__ == "__main__":
    # Configure the global multiprocessing start method once, before main()
    # starts any train/test work.
    torch.multiprocessing.set_start_method(method="forkserver")
    main()