train.py
import torch
import torch.backends.cudnn as cudnn
from torch import optim, nn

import models
from trainer import Trainer
from datasets import load_data
from utils import load_embeddings, load_checkpoint, parse_opt

cudnn.benchmark = True  # set to True only if model inputs are fixed-size; otherwise it adds a lot of computational overhead
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_trainer(config):
    # load a checkpoint
    if config.checkpoint is not None:
        # load data
        train_loader = load_data(config, 'train', False)
        model, optimizer, word_map, start_epoch = load_checkpoint(config.checkpoint, device)
        print('\nLoaded checkpoint from epoch %d.\n' % (start_epoch - 1))
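        # the logged epoch (start_epoch - 1) is the one the checkpoint was
        # saved at; training resumes from start_epoch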
    # or initialize model
    else:
        start_epoch = 0

        # load data
        train_loader, embeddings, emb_size, word_map, n_classes, vocab_size = load_data(config, 'train', True)

        model = models.make(
            config = config,
            n_classes = n_classes,
            vocab_size = vocab_size,
            embeddings = embeddings,
            emb_size = emb_size
        )

        optimizer = optim.Adam(
            params = filter(lambda p: p.requires_grad, model.parameters()),
            lr = config.lr
        )
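        # only parameters with requires_grad=True are handed to Adam, so any
        # weights frozen by the model (e.g. non-fine-tuned pretrained
        # embeddings -- an assumption about models.make, not shown here)
        # are left untouched during training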

    # loss function
    loss_function = nn.CrossEntropyLoss()

    # move to device
    model = model.to(device)
    loss_function = loss_function.to(device)

    trainer = Trainer(
        num_epochs = config.num_epochs,
        start_epoch = start_epoch,
        train_loader = train_loader,
        model = model,
        model_name = config.model_name,
        loss_function = loss_function,
        optimizer = optimizer,
        lr_decay = config.lr_decay,
        dataset_name = config.dataset,
        word_map = word_map,
        grad_clip = config.grad_clip,
        print_freq = config.print_freq,
        checkpoint_path = config.checkpoint_path,
        checkpoint_basename = config.checkpoint_basename,
        tensorboard = config.tensorboard,
        log_dir = config.log_dir
    )

    return trainer
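
# A minimal sketch of the config fields this script reads (values are
# illustrative assumptions; the real config object comes from parse_opt()):
#
#   from argparse import Namespace
#   config = Namespace(
#       checkpoint = None,          # path to a saved checkpoint, or None to train from scratch
#       lr = 1e-3,                  # Adam learning rate
#       lr_decay = 0.9,             # learning-rate decay factor passed to Trainer
#       num_epochs = 10,
#       model_name = 'textcnn',     # hypothetical key understood by models.make
#       dataset = 'ag_news',        # hypothetical dataset name for load_data
#       grad_clip = 5.0,            # gradient clipping threshold, or None
#       print_freq = 100,           # batches between training log lines
#       checkpoint_path = 'checkpoints',
#       checkpoint_basename = 'model',
#       tensorboard = False,
#       log_dir = 'logs'
#   )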

if __name__ == '__main__':
    config = parse_opt()
    trainer = set_trainer(config)
    trainer.run_train()
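
# Usage sketch (assumption: parse_opt() reads a config file path from the
# command line, e.g. via a --config flag, as is common in repos laid out
# like this one):
#
#   python train.py --config configs/example.yaml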