-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
89 lines (80 loc) · 3.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import sys
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.utils.data.distributed import DistributedSampler
def get_network(args):
    """Build and return the network selected by ``args.net``.

    The constructors are imported lazily from ``models.MS_ResNet`` so only
    the requested architecture is imported.

    Args:
        args: Object exposing a ``net`` attribute. Supported values are
            ``'resnet18'``, ``'resnet34'`` and ``'resnet104'``.

    Returns:
        The object returned by the matching constructor, which is called
        with no arguments.

    Raises:
        SystemExit: If ``args.net`` is not one of the supported names. The
            exit status is 1 so that calling scripts can detect the failure.
    """
    if args.net == 'resnet18':
        from models.MS_ResNet import resnet18
        net = resnet18()
    elif args.net == 'resnet34':
        from models.MS_ResNet import resnet34
        net = resnet34()
    elif args.net == 'resnet104':
        from models.MS_ResNet import resnet104
        net = resnet104()
    else:
        print('the network name you have entered is not supported yet')
        # A bare sys.exit() reports success (status 0); an unsupported
        # network name is an error, so exit with a non-zero status.
        sys.exit(1)
    return net
def get_training_dataloader(traindir,
                            sampler=None,
                            batch_size=16,
                            num_workers=2,
                            shuffle=True):
    """Return a ``DataLoader`` over the image-folder dataset at ``traindir``.

    Each image goes through ``RandomResizedCrop(224)``, ``AutoAugment()``,
    ``ToTensor()`` and a per-channel normalization with the mean/std values
    defined below.

    Args:
        traindir: Root directory handed to ``datasets.ImageFolder``.
        sampler: Acts as a flag. When it is not ``None`` a new
            ``DistributedSampler`` is built over the dataset and used; the
            value passed in is not itself used as the sampler.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Forwarded to ``DataLoader`` only when no sampler is
            requested. It is ignored on the sampler path because
            ``DataLoader`` rejects ``shuffle=True`` combined with a sampler.

    Returns:
        The configured ``DataLoader`` (``pin_memory=True`` in both cases).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.AutoAugment(),
        transforms.ToTensor(),
        normalize,
    ])
    ImageNet_training = datasets.ImageFolder(traindir, train_transform)

    if sampler is not None:
        # `sampler` and `shuffle=True` are mutually exclusive in DataLoader
        # (it raises ValueError), so ordering is left entirely to the
        # DistributedSampler and shuffle is forced off here.
        return DataLoader(
            ImageNet_training,
            shuffle=False,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=DistributedSampler(ImageNet_training))

    return DataLoader(ImageNet_training,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      batch_size=batch_size,
                      pin_memory=True)
def get_test_dataloader(valdir,
                        sampler=None,
                        batch_size=16,
                        num_workers=2,
                        shuffle=False):
    """Return a ``DataLoader`` over the image-folder dataset at ``valdir``.

    Each image goes through ``Resize(256)``, ``CenterCrop(224)``,
    ``ToTensor()`` and a per-channel normalization with the mean/std values
    defined below.

    Args:
        valdir: Root directory handed to ``datasets.ImageFolder``.
        sampler: Acts as a flag. When it is not ``None`` a new
            ``DistributedSampler`` is built over the dataset and used; the
            value passed in is not itself used as the sampler.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Forwarded to ``DataLoader`` only when no sampler is
            requested. It is ignored on the sampler path because
            ``DataLoader`` rejects ``shuffle=True`` combined with a sampler.

    Returns:
        The configured ``DataLoader``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    eval_transform = transforms.Compose([
        transforms.Resize(256),  # 320
        transforms.CenterCrop(224),  # 288
        transforms.ToTensor(),
        normalize,
    ])
    ImageNet_test = datasets.ImageFolder(valdir, eval_transform)

    if sampler is not None:
        # `sampler` and `shuffle=True` are mutually exclusive in DataLoader
        # (it raises ValueError), so shuffle is forced off on this path.
        return DataLoader(
            ImageNet_test,
            shuffle=False,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=DistributedSampler(ImageNet_test))

    # NOTE(review): pin_memory is not set on this branch, unlike the sampler
    # branch above -- confirm whether the asymmetry is intentional.
    return DataLoader(ImageNet_test,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      batch_size=batch_size)