-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
102 lines (80 loc) · 2.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import sys
import argparse
import logging
import random
import torch
import gorilla
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, 'provider'))
sys.path.append(os.path.join(BASE_DIR, 'model'))
sys.path.append(os.path.join(BASE_DIR, 'model', 'pointnet2'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
from create_dataloaders import create_dataloaders
from solver import Solver, get_logger
from Net import Net, Loss
def get_parser():
    """Parse the command-line arguments of the training script.

    Despite its name, this function returns the *parsed arguments*, not the
    parser object.

    Returns:
        argparse.Namespace: namespace with two string attributes:
            ``gpus``   -- value of ``--gpus`` (default ``"0"``).
            ``config`` -- value of ``--config``, the path to the config file.

    Raises:
        SystemExit: raised by argparse when ``--config`` is missing or an
            unknown option is given (a usage message is printed first).
    """
    parser = argparse.ArgumentParser(description="Pose Estimation")
    parser.add_argument("--gpus", type=str, default="0", help="gpu num")
    # ``--config`` is mandatory: ``init`` splits this path to derive the
    # experiment name, so a missing value would otherwise only surface later
    # as an AttributeError on ``None``.
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="path to config file",
    )
    return parser.parse_args()
def init():
    """Prepare an experiment run from the command-line arguments.

    Steps performed, in order:
      1. parse the command line via ``get_parser``;
      2. derive the experiment name from the config file name (the part of
         the base name before its first dot);
      3. create ``log/<exp_name>`` and ``log/<exp_name>/ckpt``;
      4. load the config with ``gorilla.Config.fromfile`` and attach
         ``exp_name``, ``log_dir``, ``ckpt_dir`` and ``gpus`` to it;
      5. build the logger (INFO level for printing, WARNING level for the
         file ``<log_dir>/training_logger.log``);
      6. hand ``cfg.gpus`` to ``gorilla.utils.set_cuda_visible_devices``.

    Returns:
        tuple: ``(logger, cfg)`` -- the logger returned by ``get_logger`` and
        the populated config object.
    """
    args = get_parser()

    # basename() instead of splitting on a hard-coded "/" so the path
    # separator of the running platform is honoured as well.
    exp_name = os.path.basename(args.config).split(".")[0]
    log_dir = os.path.join("log", exp_name)
    # exist_ok=True replaces the check-then-create pattern; it also creates
    # the parent "log" directory and is safe if another process creates the
    # directory in between.
    os.makedirs(log_dir, exist_ok=True)

    cfg = gorilla.Config.fromfile(args.config)
    cfg.exp_name = exp_name
    cfg.log_dir = log_dir
    cfg.ckpt_dir = os.path.join(log_dir, 'ckpt')
    os.makedirs(cfg.ckpt_dir, exist_ok=True)
    cfg.gpus = args.gpus

    logger = get_logger(
        level_print=logging.INFO,
        level_save=logging.WARNING,
        path_file=os.path.join(log_dir, "training_logger.log"),
    )
    gorilla.utils.set_cuda_visible_devices(gpu_ids=cfg.gpus)

    return logger, cfg
if __name__ == "__main__":
    # Script entry point: set up the run, seed the RNGs, build the model,
    # loss and dataloaders, then hand everything to the Solver.
    logger, cfg = init()

    logger.warning(
        "************************ Start Logging ************************")
    logger.info(cfg)
    gpu_message = "using gpu: {}".format(cfg.gpus)
    logger.info(gpu_message)

    # Seed Python's ``random`` and torch (CPU, current CUDA device, all CUDA
    # devices) with the seed taken from the config, in that order.
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(cfg.rd_seed)

    # model
    logger.info("=> creating model ...")
    model = Net(cfg.pose_net).cuda()
    # Training always starts from the first epoch / iteration here.
    start_epoch, start_iter = 1, 0

    num_params = sum(gorilla.parameter_count(model).values())
    logger.warning("#Total parameters : {}".format(num_params))

    loss = Loss(cfg.loss).cuda()

    # dataloader
    dataloaders = create_dataloaders(cfg.train_dataset)
    # Call reset() on the dataset behind every dataloader before training
    # (what reset() does is defined by the dataset class, not shown here).
    for loader in dataloaders.values():
        loader.dataset.reset()

    # solver
    trainer = Solver(
        model=model,
        loss=loss,
        dataloaders=dataloaders,
        logger=logger,
        cfg=cfg,
        start_epoch=start_epoch,
        start_iter=start_iter,
    )
    trainer.solve()

    logger.info('\nFinish!\n')