-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
90 lines (76 loc) · 3.24 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import numpy as np
from paddle.io import Dataset
from PIL import Image
def has_file_allowed_extension(filename, extensions):
return filename.lower().endswith(extensions)
def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
instances.append(item)
return instances
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
class ImageFolder(Dataset):
def __init__(self, root, transform=None, localization_test=False):
extensions = ('.jpg', '.jpeg', '.png')
self.root = root
self.transform = transform
classes, class_to_idx = self._find_classes(self.root, localization_test)
samples = make_dataset(self.root, class_to_idx, extensions)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = pil_loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir, localization_test=False):
if localization_test:
classes = [d.name for d in os.scandir(dir) if d.is_dir() and d.name != "good"]
else:
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
sample = np.array(sample)
sample = np.transpose(sample, [2, 0, 1]).astype('float32')
sample /= 255.0
return sample, target
def __len__(self):
return len(self.samples)