-
Notifications
You must be signed in to change notification settings - Fork 113
/
Copy pathmodel.py
18 lines (15 loc) · 850 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, GlobalAveragePooling2D
# generic model design
def model_fn(actions, input_shape=(32, 32, 3), num_classes=10):
    """Build a four-layer convolutional classifier from a flat list of choices.

    The four Conv2D layers use strides (2, 2), (1, 1), (2, 2), (1, 1) in that
    order. All of them use ``padding='same'`` and ReLU activation. They are
    followed by global average pooling and a softmax Dense output layer.

    Args:
        actions: Iterable of exactly 8 values, interleaved as
            ``kernel_1, filters_1, kernel_2, filters_2, ..., kernel_4,
            filters_4``. Each ``kernel_i`` is used as a square
            ``(kernel_i, kernel_i)`` kernel size. Each ``filters_i`` is the
            number of output channels of conv layer ``i``.
        input_shape: Shape passed to ``Input`` (default ``(32, 32, 3)``,
            matching the original hard-coded value).
        num_classes: Number of units in the final softmax layer (default 10,
            matching the original hard-coded value).

    Returns:
        An uncompiled ``keras.models.Model`` mapping the input to class
        probabilities.

    Raises:
        ValueError: If ``actions`` does not contain exactly 8 values.
    """
    # Materialize first so generators/iterators are accepted, as the
    # original tuple-unpacking did.
    actions = list(actions)
    if len(actions) != 8:
        raise ValueError(
            "model_fn expects 8 actions (kernel, filters) x 4 layers, "
            "got %d" % len(actions)
        )

    kernels = actions[0::2]
    filters = actions[1::2]
    # Downsample on the 1st and 3rd conv layers only.
    layer_strides = ((2, 2), (1, 1), (2, 2), (1, 1))

    ip = Input(shape=input_shape)
    x = ip
    for kernel, n_filters, strides in zip(kernels, filters, layer_strides):
        x = Conv2D(n_filters, (kernel, kernel), strides=strides,
                   padding='same', activation='relu')(x)

    x = GlobalAveragePooling2D()(x)
    x = Dense(num_classes, activation='softmax')(x)

    model = Model(ip, x)
    return model