diff --git a/cnn_class2/fashion.py b/cnn_class2/fashion.py index db845f8a..96141013 100644 --- a/cnn_class2/fashion.py +++ b/cnn_class2/fashion.py @@ -52,17 +52,17 @@ def y2indicator(Y): model.add(Conv2D(input_shape=(28, 28, 1), filters=32, kernel_size=(3, 3))) model.add(BatchNormalization()) model.add(Activation('relu')) -model.add(MaxPooling2D()) +model.add(MaxPool2D(pool_size=(2,2))) model.add(Conv2D(filters=64, kernel_size=(3, 3))) model.add(BatchNormalization()) model.add(Activation('relu')) -model.add(MaxPooling2D()) +model.add(MaxPool2D(pool_size=(2,2))) model.add(Conv2D(filters=128, kernel_size=(3, 3))) model.add(BatchNormalization()) model.add(Activation('relu')) -model.add(MaxPooling2D()) +model.add(MaxPool2D(pool_size=(2,2))) model.add(Flatten()) model.add(Dense(units=300))