Skip to content

Commit 202b59f

Browse files
committed
Removed hardcoded label columns for CelebA.
1 parent 5be9c36 commit 202b59f

File tree

1 file changed

+6
-50
lines changed

1 file changed

+6
-50
lines changed

datasets/celeba/dataloader.py

+6-50
Original file line numberDiff line numberDiff line change
@@ -7,67 +7,25 @@
77

88
from utils.image_utils import load_image
99

10-
CELEBA_LABEL_COLUMNS = [
11-
"5_o_Clock_Shadow",
12-
"Arched_Eyebrows",
13-
"Attractive",
14-
"Bags_Under_Eyes",
15-
"Bald",
16-
"Bangs",
17-
"Big_Lips",
18-
"Big_Nose",
19-
"Black_Hair",
20-
"Blond_Hair",
21-
"Blurry",
22-
"Brown_Hair",
23-
"Bushy_Eyebrows",
24-
"Chubby",
25-
"Double_Chin",
26-
"Eyeglasses",
27-
"Goatee",
28-
"Gray_Hair",
29-
"Heavy_Makeup",
30-
"High_Cheekbones",
31-
"Male",
32-
"Mouth_Slightly_Open",
33-
"Mustache",
34-
"Narrow_Eyes",
35-
"No_Beard",
36-
"Oval_Face",
37-
"Pale_Skin",
38-
"Pointy_Nose",
39-
"Receding_Hairline",
40-
"Rosy_Cheeks",
41-
"Sideburns",
42-
"Smiling",
43-
"Straight_Hair",
44-
"Wavy_Hair",
45-
"Wearing_Earrings",
46-
"Wearing_Hat",
47-
"Wearing_Lipstick",
48-
"Wearing_Necklace",
49-
"Wearing_Necktie",
50-
"Young"
51-
]
52-
5310
class DataSequence(Sequence):
5411
"""
5512
Keras Sequence object to train a model on larger-than-memory data.
5613
"""
57-
def __init__(self, df, data_root, batch_size, label_columns=CELEBA_LABEL_COLUMNS, resize_size=(64, 64), flip_augment=True, mode='train'):
14+
def __init__(self, df, data_root, batch_size, resize_size=(64, 64), flip_augment=True, mode='train'):
5815
self.df = df
5916
self.batch_size = batch_size
6017
self.mode = mode
6118
self.resize_size = resize_size
6219
self.crop_pt_1 = (45, 25)
6320
self.crop_pt_2 = (173, 153)
6421
self.flip_augment = flip_augment
65-
self.label_columns = label_columns
22+
# extract columns from df columns
23+
self.label_columns = self.df.columns[1:].tolist()
6624

6725
# Take labels and a list of image locations in memory
6826
self.labels = self.df[self.label_columns].values
6927
self.im_list = self.df['Image_Name'].apply(lambda x: os.path.join(data_root, x)).tolist()
70-
28+
# Trigger a shuffle
7129
self.on_epoch_end()
7230

7331
def __len__(self):
@@ -84,21 +42,19 @@ def get_batch_labels(self, idx):
8442
return self.labels[idx]
8543

8644
def get_batch_features(self, idx):
87-
8845
images = []
8946
for im_idx in idx:
9047
im = self.im_list[im_idx]
9148
loaded_image = load_image(im, self.resize_size, self.crop_pt_1, self.crop_pt_2)
9249
if self.flip_augment and random.random() < 0.5:
9350
loaded_image = np.flip(loaded_image, 1)
9451
images.append(loaded_image)
95-
9652
# Fetch a batch of inputs
9753
return np.array(images)
9854

9955
def __getitem__(self, index):
10056
idx = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
101-
57+
# get the actual data
10258
batch_x = self.get_batch_features(idx)
10359
batch_y = np.clip(self.get_batch_labels(idx).astype(np.float32), 0, 1)
104-
return (batch_x, batch_y), batch_y
60+
return (batch_x, batch_y), batch_y

0 commit comments

Comments
 (0)