-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain.py
73 lines (50 loc) · 2.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import pdb
from utils import calc_cls_measures, move_to
def train(model, optimizer, train_loader, criterion, entropy_loss_func, opts):
    """Train ``model`` for a single epoch over ``train_loader``.

    Each batch is unpacked as ``(x_low, x_high, label)`` and moved to
    ``opts.device`` with ``move_to``. The model is called as
    ``model(x_low, x_high)`` and must return a 4-tuple. The first element is
    the prediction ``y`` and the second is an attention map; the last two
    elements are ignored here.

    The per-batch objective is
    ``criterion(y, label) - entropy_loss_func(attention_map)``. The entropy
    term is *subtracted*, so minimizing the objective pushes
    ``entropy_loss_func(attention_map)`` upward. Gradients are clipped to a
    total norm of ``opts.clipnorm`` before every optimizer step.

    Args:
        model: Network to train; put in training mode via ``model.train()``.
        optimizer: Optimizer over the model's parameters (``zero_grad`` /
            ``step`` are called once per batch).
        train_loader: Iterable of ``(x_low, x_high, label)`` batches. Its
            ``.dataset.CLASSES`` sequence gives the number of classes.
        criterion: Callable ``criterion(y, label)``. Combined with the
            entropy term, it must yield a single-element tensor
            (``backward()`` and ``item()`` are called on it).
        entropy_loss_func: Callable applied to the attention map; its result
            is subtracted from the criterion loss.
        opts: Options object; ``opts.device`` and ``opts.clipnorm`` are read.

    Returns:
        tuple: ``(train_loss_epoch, metrics)``.
            - ``train_loss_epoch`` is the mean per-batch loss rounded to 4
              decimals (``nan`` if the loader yields no batches).
            - ``metrics`` is whatever ``calc_cls_measures`` returns for the
              collected softmax probabilities and labels.
    """
    num_classes = len(train_loader.dataset.CLASSES)
    # Per-batch results are collected in lists and concatenated once at the
    # end. Concatenating inside the loop re-copies everything seen so far,
    # which makes the epoch quadratic in the number of batches.
    prob_batches = []
    label_batches = []
    losses = []

    # Put model in training mode
    model.train()
    for x_low, x_high, label in tqdm(train_loader):
        x_low, x_high, label = move_to([x_low, x_high, label], opts.device)
        optimizer.zero_grad()
        # Only the prediction and the attention map are used below.
        y, attention_map, _patches, _ = model(x_low, x_high)
        entropy_loss = entropy_loss_func(attention_map)
        loss = criterion(y, label) - entropy_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opts.clipnorm)
        optimizer.step()

        losses.append(loss.item())
        # Softmax over dim 1, the class axis: the rows must have
        # `num_classes` columns to concatenate with the seed array below.
        y_prob = F.softmax(y, dim=1)
        prob_batches.append(y_prob.detach().cpu().numpy())
        label_batches.append(label.cpu().numpy())

    # `np.float` / `np.int` were removed in NumPy 1.24. The builtins `float` /
    # `int` are the documented drop-in replacements. The empty seed arrays
    # keep the result dtypes, and the shapes for an empty loader, identical
    # to the previous incremental concatenation.
    y_probs = np.concatenate(
        [np.zeros((0, num_classes), dtype=float)] + prob_batches
    )
    y_trues = np.concatenate([np.zeros(0, dtype=int)] + label_batches)
    train_loss_epoch = np.round(np.mean(losses), 4)
    metrics = calc_cls_measures(y_probs, y_trues)
    return train_loss_epoch, metrics
def evaluate(model, test_loader, criterion, entropy_loss_func, opts):
    """Evaluate ``model`` for a single epoch over ``test_loader``.

    This mirrors the training loop without any parameter updates. Each batch
    is unpacked as ``(x_low, x_high, label)``, moved to ``opts.device``, and
    passed through ``model(x_low, x_high)``. Only the first two outputs (the
    prediction ``y`` and the attention map) are used.

    The reported loss is the same objective used in training:
    ``criterion(y, label) - entropy_loss_func(attention_map)``.

    Args:
        model: Network to evaluate; put in eval mode via ``model.eval()``.
        test_loader: Iterable of ``(x_low, x_high, label)`` batches. Its
            ``.dataset.CLASSES`` sequence gives the number of classes.
        criterion: Callable ``criterion(y, label)``. Combined with the
            entropy term, it must yield a single-element tensor
            (``item()`` is called on it).
        entropy_loss_func: Callable applied to the attention map; its result
            is subtracted from the criterion loss.
        opts: Options object; ``opts.device`` is read.

    Returns:
        tuple: ``(test_loss_epoch, metrics)``.
            - ``test_loss_epoch`` is the mean per-batch loss rounded to 4
              decimals (``nan`` if the loader yields no batches).
            - ``metrics`` is whatever ``calc_cls_measures`` returns for the
              collected softmax probabilities and labels.
    """
    num_classes = len(test_loader.dataset.CLASSES)
    # Collect per-batch results and concatenate once. This avoids the
    # quadratic cost of growing the arrays inside the loop.
    prob_batches = []
    label_batches = []
    losses = []

    # Put model in eval mode. Note that eval mode does NOT disable autograd,
    # hence the explicit no_grad below.
    model.eval()
    # No gradients are needed here. Without no_grad, every forward pass would
    # build and retain an autograd graph for nothing.
    with torch.no_grad():
        for x_low, x_high, label in tqdm(test_loader):
            x_low, x_high, label = move_to([x_low, x_high, label], opts.device)
            # Only the prediction and the attention map are used below.
            y, attention_map, _patches, _ = model(x_low, x_high)
            entropy_loss = entropy_loss_func(attention_map)
            loss = criterion(y, label) - entropy_loss
            losses.append(loss.item())
            # Softmax over dim 1, the class axis.
            y_prob = F.softmax(y, dim=1)
            prob_batches.append(y_prob.cpu().numpy())
            label_batches.append(label.cpu().numpy())

    # `np.float` / `np.int` were removed in NumPy 1.24. Use the builtin
    # `float` / `int`. The empty seed arrays preserve the previous dtypes and
    # the shapes for an empty loader.
    y_probs = np.concatenate(
        [np.zeros((0, num_classes), dtype=float)] + prob_batches
    )
    y_trues = np.concatenate([np.zeros(0, dtype=int)] + label_batches)
    test_loss_epoch = np.round(np.mean(losses), 4)
    metrics = calc_cls_measures(y_probs, y_trues)
    return test_loss_epoch, metrics