-
Notifications
You must be signed in to change notification settings - Fork 61
/
utils_cifar.py
28 lines (20 loc) · 1 KB
/
utils_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torchvision
import torchvision.transforms as transforms
# Short display names for the ten CIFAR-10 categories.
# NOTE(review): the order presumably mirrors the dataset's integer label
# indices (0 -> 'plane', ..., 9 -> 'truck') -- confirm before reordering.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
def _get_transform():
    """Build the preprocessing pipeline shared by the train and test loaders.

    Returns:
        A ``transforms.Compose`` that applies ``ToTensor`` first and then
        ``Normalize`` with mean 0.5 and standard deviation 0.5 on each of
        three channels.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    # Order matters: Normalize is applied to the output of ToTensor.
    return transforms.Compose([to_tensor, normalize])
def get_train_data_loader(*, batch_size=4, num_workers=2, root='./data',
                          download=True):
    """Return a shuffling ``DataLoader`` over the CIFAR-10 training split.

    Called without arguments this behaves exactly as the earlier hard-coded
    version: batches of 4, two loader workers, data stored under ``./data``
    and downloading enabled. All parameters are keyword-only.

    Args:
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        root: Directory passed to ``torchvision.datasets.CIFAR10`` as ``root``.
        download: Passed to ``CIFAR10`` as ``download``; set to ``False`` to
            rely on data that is already present under ``root``.

    Returns:
        A ``torch.utils.data.DataLoader`` built with ``shuffle=True``.
    """
    transform = _get_transform()
    trainset = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=download,
        transform=transform,
    )
    return torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
def get_test_data_loader(download, *, batch_size=4, num_workers=2,
                         root='./data'):
    """Return a non-shuffling ``DataLoader`` over the CIFAR-10 test split.

    With only ``download`` supplied this behaves exactly as the earlier
    hard-coded version: batches of 4, two loader workers and data stored
    under ``./data``. The additional parameters are keyword-only.

    Args:
        download: Passed to ``torchvision.datasets.CIFAR10`` as ``download``.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        root: Directory passed to ``CIFAR10`` as ``root``.

    Returns:
        A ``torch.utils.data.DataLoader`` built with ``shuffle=False``, so
        samples are yielded in the dataset's own order.
    """
    transform = _get_transform()
    testset = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        download=download,
        transform=transform,
    )
    return torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )