7
7
8
8
from utils .image_utils import load_image
9
9
10
- CELEBA_LABEL_COLUMNS = [
11
- "5_o_Clock_Shadow" ,
12
- "Arched_Eyebrows" ,
13
- "Attractive" ,
14
- "Bags_Under_Eyes" ,
15
- "Bald" ,
16
- "Bangs" ,
17
- "Big_Lips" ,
18
- "Big_Nose" ,
19
- "Black_Hair" ,
20
- "Blond_Hair" ,
21
- "Blurry" ,
22
- "Brown_Hair" ,
23
- "Bushy_Eyebrows" ,
24
- "Chubby" ,
25
- "Double_Chin" ,
26
- "Eyeglasses" ,
27
- "Goatee" ,
28
- "Gray_Hair" ,
29
- "Heavy_Makeup" ,
30
- "High_Cheekbones" ,
31
- "Male" ,
32
- "Mouth_Slightly_Open" ,
33
- "Mustache" ,
34
- "Narrow_Eyes" ,
35
- "No_Beard" ,
36
- "Oval_Face" ,
37
- "Pale_Skin" ,
38
- "Pointy_Nose" ,
39
- "Receding_Hairline" ,
40
- "Rosy_Cheeks" ,
41
- "Sideburns" ,
42
- "Smiling" ,
43
- "Straight_Hair" ,
44
- "Wavy_Hair" ,
45
- "Wearing_Earrings" ,
46
- "Wearing_Hat" ,
47
- "Wearing_Lipstick" ,
48
- "Wearing_Necklace" ,
49
- "Wearing_Necktie" ,
50
- "Young"
51
- ]
52
-
53
10
class DataSequence (Sequence ):
54
11
"""
55
12
Keras Sequence object to train a model on larger-than-memory data.
56
13
"""
57
- def __init__ (self , df , data_root , batch_size , label_columns = CELEBA_LABEL_COLUMNS , resize_size = (64 , 64 ), flip_augment = True , mode = 'train' ):
14
+ def __init__ (self , df , data_root , batch_size , resize_size = (64 , 64 ), flip_augment = True , mode = 'train' ):
58
15
self .df = df
59
16
self .batch_size = batch_size
60
17
self .mode = mode
61
18
self .resize_size = resize_size
62
19
self .crop_pt_1 = (45 , 25 )
63
20
self .crop_pt_2 = (173 , 153 )
64
21
self .flip_augment = flip_augment
65
- self .label_columns = label_columns
22
+ # extract columns from df columns
23
+ self .label_columns = self .df .columns [1 :].tolist ()
66
24
67
25
# Take labels and a list of image locations in memory
68
26
self .labels = self .df [self .label_columns ].values
69
27
self .im_list = self .df ['Image_Name' ].apply (lambda x : os .path .join (data_root , x )).tolist ()
70
-
28
+ # Trigger a shuffle
71
29
self .on_epoch_end ()
72
30
73
31
def __len__ (self ):
@@ -84,21 +42,19 @@ def get_batch_labels(self, idx):
84
42
return self .labels [idx ]
85
43
86
44
def get_batch_features (self , idx ):
87
-
88
45
images = []
89
46
for im_idx in idx :
90
47
im = self .im_list [im_idx ]
91
48
loaded_image = load_image (im , self .resize_size , self .crop_pt_1 , self .crop_pt_2 )
92
49
if self .flip_augment and random .random () < 0.5 :
93
50
loaded_image = np .flip (loaded_image , 1 )
94
51
images .append (loaded_image )
95
-
96
52
# Fetch a batch of inputs
97
53
return np .array (images )
98
54
99
55
def __getitem__ (self , index ):
100
56
idx = self .indexes [index * self .batch_size :(index + 1 )* self .batch_size ]
101
-
57
+ # get the actual data
102
58
batch_x = self .get_batch_features (idx )
103
59
batch_y = np .clip (self .get_batch_labels (idx ).astype (np .float32 ), 0 , 1 )
104
- return (batch_x , batch_y ), batch_y
60
+ return (batch_x , batch_y ), batch_y
0 commit comments