-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
65 lines (43 loc) · 1.33 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
import os, tarfile, random
import sys
from six.moves import urllib
from torchvision.transforms import Compose, CenterCrop, RandomCrop, ToTensor, Scale, RandomHorizontalFlip
from dataset import DatasetFromFolder
from PIL import Image
# Side length, in pixels, of the square high-resolution training crop.
# Written as 8 * 32: LR_transform resizes to crop_size // 8 (= 32 px), so keep
# this divisible by 8 for the // 8, // 4 and // 2 resize targets to be exact.
CROP_SIZE = 8 * 32
class RandomRotate(object):
    """Rotate a PIL image by a random multiple of 90 degrees.

    Each of 0, 90, 180 and 270 degrees is chosen with equal probability.
    ``expand=True`` lets a non-square image swap its width and height instead
    of being cropped/padded back to its original size.
    """

    def __call__(self, img):
        # randint is inclusive at both ends: randint(0, 4) would also produce
        # 360 degrees, making "no rotation" twice as likely as the others.
        quarter_turns = random.randint(0, 3)
        return img.rotate(90 * quarter_turns, expand=True)
class RandomScale(object):
    """Randomly rescale a PIL image so that a ``crop_size`` crop still fits.

    The scale factor is drawn uniformly from ``[crop_size / min(w, h), 1]``.
    Large images are therefore randomly shrunk, but their shorter side never
    drops below ``crop_size``. An image whose shorter side is already smaller
    than ``crop_size`` is instead upscaled so that it exactly fits.

    Args:
        crop_size: side length of the crop applied after this transform.
            ``None`` (the default) means the module-level ``CROP_SIZE``,
            looked up at call time.
    """

    def __init__(self, crop_size=None):
        self.crop_size = crop_size

    def __call__(self, img):
        crop_size = CROP_SIZE if self.crop_size is None else self.crop_size
        width, height = img.size
        # float() forces true division on Python 2 as well (the file imports
        # six, so Python 2 may still be a target); integer division would
        # make the lower bound 0 for images smaller than 2 * crop_size.
        min_scale = float(crop_size) / min(width, height)
        # Never shrink below the crop; when min_scale > 1 this collapses to an
        # exact upscale by min_scale.
        scale = random.uniform(min_scale, max(min_scale, 1.0))
        # round() instead of int(): int() can truncate e.g. 255.99999 to 255,
        # leaving the shorter side one pixel short of crop_size.
        new_size = (int(round(width * scale)), int(round(height * scale)))
        return img.resize(new_size, Image.BICUBIC)


def LR_transform(crop_size):
    """Return the low-resolution input transform.

    Resizes the image so its shorter side is ``crop_size // 8`` and converts it
    to a tensor.
    """
    return Compose([
        Scale(crop_size // 8),
        ToTensor(),
    ])


def HR_2_transform(crop_size):
    """Return the transform for the 2x level (side ``crop_size // 4``).

    That is twice the ``crop_size // 8`` side produced by ``LR_transform``.
    The output is a tensor.
    """
    return Compose([
        Scale(crop_size // 4),
        ToTensor(),
    ])


def HR_4_transform(crop_size):
    """Return the transform for the 4x level (side ``crop_size // 2``).

    That is four times the ``crop_size // 8`` side produced by
    ``LR_transform``. The output is a tensor.
    """
    return Compose([
        Scale(crop_size // 2),
        ToTensor(),
    ])


def HR_8_transform(crop_size):
    """Return the full-resolution (``crop_size``) augmentation pipeline.

    The pipeline randomly rescales the image, takes a random
    ``crop_size`` x ``crop_size`` crop, and randomly flips it horizontally.
    No ToTensor is applied, so the result stays a PIL image.
    """
    return Compose([
        # Scale *before* cropping: after RandomCrop the image is already
        # crop_size square, so RandomScale would always pick a factor of 1.
        RandomScale(crop_size),
        RandomCrop(crop_size),
        # RandomRotate(),  # disabled in the original pipeline
        RandomHorizontalFlip(),
    ])
def get_training_set(train_dir=None):
    """Build the training dataset over the images in ``train_dir``.

    Constructs ``DatasetFromFolder`` with ``LR_transform(CROP_SIZE)`` and
    ``HR_8_transform(CROP_SIZE)`` passed as its ``LR_transform`` and
    ``HR_8_transform`` keyword arguments.
    """
    low_res = LR_transform(CROP_SIZE)
    full_res = HR_8_transform(CROP_SIZE)
    return DatasetFromFolder(
        train_dir,
        LR_transform=low_res,
        HR_8_transform=full_res,
    )