
Commit fdca28b

Added logs with Tensorboard
1 parent 3a5d37a commit fdca28b
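
Summary: this commit makes CUDA optional in OneShotBuilder (the network and batches are moved to the GPU only when torch.cuda.is_available() is true, with cudnn.benchmark and a fixed CUDA seed set in that case), switches the validation and test Variables to volatile=True for inference, adds a small logger.py wrapper around tensorboard_logger, and has main.py log train/val/test loss and accuracy to TensorBoard once per epoch.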

File tree (3 files changed, +79 −25 lines):

experiments/OneShotBuilder.py
logger.py
main.py


experiments/OneShotBuilder.py (+32 −14)
@@ -1,4 +1,5 @@
 import torch
+import torch.backends.cudnn as cudnn
 import tqdm
 from models.MatchingNetwork import MatchingNetwork
 from torch.autograd import Variable
@@ -38,7 +39,11 @@ def build_experiment(self, batch_size, classes_per_set, samples_per_class, chann
         self.current_lr = 1e-03
         self.lr_decay = 1e-6
         self.wd = 1e-4
-        self.matchingNet.cuda()
+        self.isCudaAvailable = torch.cuda.is_available()
+        if self.isCudaAvailable:
+            cudnn.benchmark = True
+            torch.cuda.manual_seed_all(0)
+            self.matchingNet.cuda()

     def run_training_epoch(self, total_train_batches):
         """
@@ -75,7 +80,12 @@ def run_training_epoch(self, total_train_batches):
                 x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
                 size = x_target.size()
                 x_target = x_target.view(size[0], size[3], size[1], size[2])
-                acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(), x_target.cuda(), y_target.cuda())
+                if self.isCudaAvailable:
+                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
+                                                         x_target.cuda(), y_target.cuda())
+                else:
+                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
+                                                         x_target, y_target)

                 # Before the backward pass, use the optimizer object to zero all of the
                 # gradients for the variables it will update (which are the learnable weights
@@ -122,10 +132,10 @@ def run_validation_epoch(self, total_val_batches):
                 x_support_set, y_support_set, x_target, y_target = \
                     self.data.get_batch(str_type='val', rotate_flag=False)

-                x_support_set = Variable(torch.from_numpy(x_support_set), requires_grad=False).float()
-                y_support_set = Variable(torch.from_numpy(y_support_set), requires_grad=False).long()
-                x_target = Variable(torch.from_numpy(x_target), requires_grad=False).float()
-                y_target = Variable(torch.from_numpy(y_target), requires_grad=False).long()
+                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
+                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
+                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
+                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                 # y_support_set: Add extra dimension for the one_hot
                 y_support_set = torch.unsqueeze(y_support_set, 2)
@@ -141,8 +151,12 @@ def run_validation_epoch(self, total_val_batches):
                 x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
                 size = x_target.size()
                 x_target = x_target.view(size[0], size[3], size[1], size[2])
-                acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
-                                                     x_target.cuda(), y_target.cuda())
+                if self.isCudaAvailable:
+                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
+                                                         x_target.cuda(), y_target.cuda())
+                else:
+                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
+                                                         x_target, y_target)

                 iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                 pbar.set_description(iter_out)
@@ -170,10 +184,10 @@ def run_testing_epoch(self, total_test_batches):
                 x_support_set, y_support_set, x_target, y_target = \
                     self.data.get_batch(str_type='test', rotate_flag=False)

-                x_support_set = Variable(torch.from_numpy(x_support_set), requires_grad=False).float()
-                y_support_set = Variable(torch.from_numpy(y_support_set), requires_grad=False).long()
-                x_target = Variable(torch.from_numpy(x_target), requires_grad=False).float()
-                y_target = Variable(torch.from_numpy(y_target), requires_grad=False).long()
+                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
+                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
+                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
+                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                 # y_support_set: Add extra dimension for the one_hot
                 y_support_set = torch.unsqueeze(y_support_set, 2)
@@ -189,8 +203,12 @@ def run_testing_epoch(self, total_test_batches):
                 x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
                 size = x_target.size()
                 x_target = x_target.view(size[0], size[3], size[1], size[2])
-                acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
-                                                     x_target.cuda(), y_target.cuda())
+                if self.isCudaAvailable:
+                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
+                                                         x_target.cuda(), y_target.cuda())
+                else:
+                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
+                                                         x_target, y_target)

                 iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                 pbar.set_description(iter_out)
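
Note: in the 0.3-era PyTorch API used here, volatile=True marks inference-only Variables so no autograd graph is built, which is why the validation and test inputs switch away from requires_grad=False. The same CUDA dispatch now appears in the training, validation, and test loops; a minimal sketch of how it could be factored into one helper on the builder (the to_device name is hypothetical, not part of this commit):

    def to_device(self, *tensors):
        # Move inputs to the GPU only when CUDA is available; otherwise return them unchanged.
        if self.isCudaAvailable:
            return tuple(t.cuda() for t in tensors)
        return tensors

Each loop's call would then collapse to acc, c_loss_value = self.matchingNet(*self.to_device(x_support_set, y_support_set_one_hot, x_target, y_target)).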

logger.py (new file, +28)
@@ -0,0 +1,28 @@
+import os
+from tensorboard_logger import configure, log_value
+
+class Logger(object):
+    def __init__(self, log_dir):
+        # clean previous logged data under the same directory name
+        self._remove(log_dir)
+
+        # configure the project
+        configure(log_dir)
+
+        self.global_step = 0
+
+    def log_value(self, name, value):
+        log_value(name, value, self.global_step)
+        return self
+
+    def step(self):
+        self.global_step += 1
+
+    @staticmethod
+    def _remove(path):
+        """ param <path> could either be relative or absolute. """
+        if os.path.isfile(path):
+            os.remove(path)  # remove the file
+        elif os.path.isdir(path):
+            import shutil
+            shutil.rmtree(path)  # remove dir and all contains
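
The wrapper keeps one global step that the caller advances once per epoch, and the constructor wipes any existing directory of the same name before calling configure(). A minimal usage sketch (the run directory name is illustrative, not from this commit):

    from logger import Logger

    logger = Logger('logs/run-example')  # illustrative directory; recreated from scratch
    for epoch in range(3):
        logger.log_value('train_loss', 1.0 / (epoch + 1))  # scalar recorded at the current global step
        logger.step()  # advance the shared step once per epoch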

main.py (+19 −11)
@@ -7,33 +7,33 @@
 ## This source code is licensed under the MIT-style license found in the
 ## LICENSE file in the root directory of this source tree
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-
 from datasets import omniglotNShot
 from option import Options
-
-#Dummy test
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
 from experiments.OneShotBuilder import OneShotBuilder
 import tqdm
+from logger import Logger

 # Experiment Setup
 batch_size = 32
 fce = True
 classes_per_set = 20
 samples_per_class = 1
 channels = 1
-epochs = 200
-logs_path = "one_shot_outputs/"
-experiment_name = "one_shot_learning_embedding_{}_{}".format(samples_per_class, classes_per_set)
-
+# Training setup
 total_epochs = 300
 total_train_batches = 1000
 total_val_batches = 100
 total_test_batches = 250
-
+# Parse other options
 args = Options().parse()
+
+LOG_DIR = args.log_dir + '/run-batchSize_{}-fce_{}-classes_per_set{}-samples_per_class{}-channels{}' \
+                         .format(batch_size,fce,classes_per_set,samples_per_class,channels)
+
+# create logger
+logger = Logger(LOG_DIR)
+
+
 data = omniglotNShot.OmniglotNShotDataset(dataroot=args.dataroot, batch_size = batch_size,
                                           classes_per_set=classes_per_set,
                                           samples_per_class=samples_per_class)
@@ -50,13 +50,21 @@
             total_val_batches=total_val_batches)
         print("Epoch {}: val_loss: {}, val_accuracy: {}".format(e, total_val_c_loss, total_val_accuracy))

+        logger.log_value('train_loss', total_c_loss)
+        logger.log_value('train_acc', total_accuracy)
+        logger.log_value('val_loss', total_val_c_loss)
+        logger.log_value('val_acc', total_val_accuracy)
+
         if total_val_accuracy >= best_val: # if new best val accuracy -> produce test statistics
             best_val = total_val_accuracy
             total_test_c_loss, total_test_accuracy = obj_oneShotBuilder.run_testing_epoch(
                 total_test_batches=total_test_batches)
             print("Epoch {}: test_loss: {}, test_accuracy: {}".format(e, total_test_c_loss, total_test_accuracy))
+            logger.log_value('test_loss', total_test_c_loss)
+            logger.log_value('test_acc', total_test_accuracy)
         else:
             total_test_c_loss = -1
             total_test_accuracy = -1

         pbar_e.update(1)
+        logger.step()
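
Assuming the tensorboard_logger package and a TensorBoard installation, the recorded scalars can then be viewed by launching TensorBoard with --logdir pointed at the directory supplied as args.log_dir; each run appears under its own run-batchSize_... subdirectory built from the experiment settings above.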
