Skip to content

Commit e9e598d

Browse files
authored
Update train_superpixels_graph_classification.py
1 parent 82ced67 commit e9e598d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

semisupervised_MNIST_CIFAR10/pre-training/train/train_superpixels_graph_classification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def train_epoch(model, optimizer, device, data_loader, epoch, drop_percent, temp
2121
t0 = time.time()
2222

2323
for iter, (batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
24-
24+
aug_batch_graphs = dgl.unbatch(batch_graphs)
2525
aug_list1, aug_list2 = aug.aug_double(aug_batch_graphs, aug_type)
2626
batch_graphs, batch_snorm_n, batch_snorm_e= aug.collate_batched_graph(aug_list1)
2727
aug_batch_graphs, aug_batch_snorm_n, aug_batch_snorm_e= aug.collate_batched_graph(aug_list2)
@@ -74,4 +74,4 @@ def train_epoch(model, optimizer, device, data_loader, epoch, drop_percent, temp
7474
epoch_loss /= (iter + 1)
7575
print('Epoch: [{:>2d}] Loss: [{:.4f}]'.format(epoch + 1, epoch_loss))
7676

77-
return epoch_loss, optimizer
77+
return epoch_loss, optimizer

0 commit comments

Comments
 (0)