-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfull_train.py
77 lines (63 loc) · 2.55 KB
/
full_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Starting point for training the models on the whole dataset
"""
import os
import argparse
from datetime import datetime
from configparser import ConfigParser
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from utils.data import SVHNDataset
from trainer import Trainer
def argparser(argv=None):
    """
    Configure the command-line arguments parser and parse the arguments

    :param argv: optional list of argument strings to parse; when ``None``
        (the default) argparse falls back to ``sys.argv[1:]``, which is the
        behaviour of the original zero-argument call
    :return: the arguments parsed, exposing the single positional value
        as ``.config``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str,
                        help='path to the configuration file')
    # Forwarding argv lets callers (and tests) supply arguments explicitly
    # instead of depending on the process command line.
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: load the configuration file given on the command
    # line, build the training data pipeline, create timestamped output
    # directories and hand everything over to the Trainer.
    args = argparser()

    conf = ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed; an empty result means the given path
    # is unusable, so stop here with a clear message instead of failing later
    # with an unrelated "no section" error.
    if not conf.read(args.config):
        raise SystemExit(
            'Could not read configuration file: {}'.format(args.config))

    # Record the device choice in the config so that it travels with the
    # rest of the settings passed to the Trainer.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conf.set('model', 'device', device)

    input_resize = conf.getint("preprocessing", "resize")

    # Per-channel mean / std handed to Normalize
    # (presumably computed on the training images - TODO confirm).
    norm_mean = [0.39954964, 0.3988817, 0.41280591]
    norm_std = [0.23269807, 0.2355513, 0.23580605]

    # Augmentation: resize to 64x64, take a random 54x54 crop, then resize
    # to the input size configured under [preprocessing] resize.
    train_transforms = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.RandomCrop(54),
        transforms.Resize((input_resize, input_resize)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    train_data = SVHNDataset(
        metadata_path=conf.get("paths", "train_metadata"),
        data_dir=conf.get("paths", "data_dir"),
        crop_percent=conf.getfloat("preprocessing", "crop_percent"),
        transform=train_transforms)

    train_loader = DataLoader(
        train_data,
        batch_size=conf.getint("model", "batch_size"),
        shuffle=True,
        num_workers=4,
        # Page-locked host memory is only useful when batches are copied to
        # a CUDA device, so skip it on CPU-only runs.
        pin_memory=(device == 'cuda'))

    # Each run writes to <configured base>/<model name>/<timestamp>; the
    # config entries are rewritten in place to point at the run-specific
    # directories, which are created if missing.
    datetime_str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    model_name = conf.get("model", "name")
    for path_key in ("results", "checkpoints"):
        run_dir = os.path.join(conf.get("paths", path_key),
                               model_name,
                               datetime_str)
        conf.set("paths", path_key, run_dir)
        os.makedirs(run_dir, exist_ok=True)

    trainer = Trainer(conf)
    trainer.full_train_model(train_loader)