-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
102 lines (75 loc) · 3.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.utils import *
from utils.dataloader import *
# Sequence length used for BOTH the conditioning frames and the frames to
# predict (passed as n_frames_input and n_frames_output to MovingMNIST below).
N_STEPS = 10
def load_data(args, path, num_objects=None):
    """Build the MovingMNIST train/validation datasets.

    Args:
        args: parsed CLI namespace, forwarded to the MovingMNIST constructor.
        path: root directory of the MovingMNIST data.
        num_objects: list of digit counts per sequence; defaults to [2],
            matching the original hard-coded behavior. Kept as ``None``
            sentinel to avoid a mutable default argument.

    Returns:
        Tuple ``(train_data, val_data)`` of MovingMNIST dataset objects,
        each producing N_STEPS input frames and N_STEPS target frames.
    """
    if num_objects is None:
        num_objects = [2]
    train_data = MovingMNIST(args, is_train=True, root=path, n_frames_input=N_STEPS, n_frames_output=N_STEPS, num_objects=num_objects)
    val_data = MovingMNIST(args, is_train=False, root=path, n_frames_input=N_STEPS, n_frames_output=N_STEPS, num_objects=num_objects)
    return train_data, val_data
def main(args):
    """Train the model selected by ``args.model`` on MovingMNIST.

    Runs ``args.epochs`` epochs of MSE training with Adam, validates every
    10th epoch (and on epoch 1), tracks the best validation loss, and writes
    checkpoints: the latest weights every epoch and the best weights whenever
    validation improves.

    Args:
        args: parsed CLI namespace (lr, batch_size, epochs, model,
            num_layers, reload, ...). ``args.gpu_num`` is set here as a
            side effect for downstream consumers.
    """
    start_epoch = 1
    path = "/data2/jjlee_datasets/MovingMNIST/"  # NOTE(review): hard-coded data root; consider a --data_path flag
    args.gpu_num = torch.cuda.device_count()
    best_loss = float('inf')  # any first validation loss becomes the new best
    lr = args.lr
    model = get_model(args)
    ckpt_path = f'./model_ckpt/{args.model}_layer{args.num_layers}_model.pth'
    ckpt_best_path = f'./model_ckpt/{args.model}_layer{args.num_layers}_best_model.pth'
    if args.reload:
        # Resume: restores model weights in place; optimizer state is applied
        # below, after the optimizer is constructed.
        start_epoch, lr, optimizer_state_dict = load_checkpoint(model, args, ckpt_path)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.cuda()
    train_data, val_data = load_data(args, path)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch_size)
    val_loader = DataLoader(val_data, shuffle=False, batch_size=args.batch_size)
    loss_fn = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    if args.reload:
        optimizer.load_state_dict(optimizer_state_dict)
    for epoch in tqdm(range(start_epoch, args.epochs + 1), position=0):
        model.train()
        tq_train = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", total=len(train_loader), leave=False, position=1)
        for idx, (x, y) in enumerate(tq_train):
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            tq_train.set_postfix({'loss': '{:.03f}'.format(loss.item())})
        # Validate on the first epoch and every 10th epoch thereafter.
        if epoch % 10 == 0 or epoch == 1:
            test_loss_avg = Averager()
            model.eval()
            tq_val = tqdm(val_loader, desc=f"Validation", total=len(val_loader), leave=False)
            # Disable autograd during evaluation: the original code built
            # gradient graphs for every validation forward pass, wasting
            # GPU memory for no benefit.
            with torch.no_grad():
                for idx, (x, y) in enumerate(tq_val):
                    x, y = x.cuda(), y.cuda()
                    logits = model(x)
                    loss = loss_fn(logits, y)
                    tq_val.set_postfix(val_loss=f'{loss.item():.03f}')
                    test_loss_avg.add(loss.item())
            if best_loss > test_loss_avg.item():
                best_loss = test_loss_avg.item()
                print(f"Epoch: {epoch}, Best loss: {best_loss:.4f}")
                save_checkpoint(model, optimizer, epoch, ckpt_best_path)
        # Always persist the latest weights so training can be resumed.
        save_checkpoint(model, optimizer, epoch, ckpt_path)
if __name__ == '__main__':
    # CLI entry point: collect hyperparameters and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--epochs', default=50, type=int, help='number of epochs to train')
    parser.add_argument('--hidden_dim', default=64, type=int, help='number of hidden dim for ConvLSTM layers')
    parser.add_argument('--input_dim', default=1, type=int, help='input channels')
    parser.add_argument('--model', default='convlstm', type=str, help='name of the model')
    parser.add_argument('--num_layers', default=4, type=int, help='number of layers')
    parser.add_argument('--frame_num', default=10, type=int, help='number of frames')
    parser.add_argument('--img_size', default=64, type=int, help='image size')
    parser.add_argument('--reload', action='store_true', help='reload model')
    main(parser.parse_args())