Encountering error with number of batches per epoch #164

Open
@sivaramakrishnan-rajaraman

Description

I am using TensorFlow 2.7 and trying to reproduce the results of the code at https://modal-python.readthedocs.io/en/latest/content/examples/Keras_integration.html. Because importing KerasClassifier from the Keras wrappers throws a deprecation warning, I installed SciKeras and imported KerasClassifier from there instead. The code is as follows:

from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from scikeras.wrappers import KerasClassifier

# build function for the Keras' scikit-learn API
def create_keras_model():
    """
    This function compiles and returns a Keras model.
    Should be passed to KerasClassifier in the Keras scikit-learn API.
    """

    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

    return model

classifier = KerasClassifier(create_keras_model)

import numpy as np
from tensorflow.keras.datasets import mnist

# read training data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(10000, 28, 28, 1).astype('float32') / 255
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# assemble initial data
n_initial = 100
initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)
X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

# generate the pool
# remove the initial data from the training dataset
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

from modAL.models import ActiveLearner

# initialize ActiveLearner
learner = ActiveLearner(
    estimator=classifier,
    X_training=X_initial, y_training=y_initial,
    verbose=1
)

As shown above, I took 100 initial data samples. Running the code produced the output below: the progress bar shows the data divided into 4 batches, whereas in the documentation example the data appears to be divided into "n_initial" batches.

4/4 [==============================] - 11s 32ms/step - loss: 2.2845 - accuracy: 0.1200
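For reference, a minimal sketch of the arithmetic that may explain the 4/4 counter. It assumes the Keras default batch size of 32 is in effect (no batch_size is passed anywhere in the code above) and that the progress bar counts batches rather than samples. The helper name steps_per_epoch is hypothetical, used only for illustration. It is also possible, though not confirmed here, that the documentation output came from an older Keras version whose progress bar counted samples.

```python
import math


def steps_per_epoch(n_samples: int, batch_size: int = 32) -> int:
    """Number of batches needed to cover n_samples once.

    Hypothetical helper: batch_size=32 mirrors the Keras default,
    which is an assumption about the setup in this issue.
    """
    return math.ceil(n_samples / batch_size)


# 100 initial samples with the assumed default batch size
print(steps_per_epoch(100))                # 4

# a batch size of 1 would give one step per sample
print(steps_per_epoch(100, batch_size=1))  # 100
```

If this reading is right, the 4/4 counter is a display difference and not a sign that fewer samples were used for training.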

Then I ran this part of the code, where I wanted to run 10 queries, taking 100 instances each time.

n_queries = 10
for idx in range(n_queries):
    print('Query no. %d' % (idx + 1))
    query_idx, query_instance = learner.query(X_pool, n_instances=100, verbose=0)
    learner.teach(
        X=X_pool[query_idx], y=y_pool[query_idx], only_new=True,
        verbose=1
    )
    # remove queried instance from pool
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)

And I got the following output. The data is divided into 4 batches for each query, but in the documentation it was divided into 100 batches (100/100).

Query no. 1
4/4 [==============================] - 0s 6ms/step - loss: 2.3089 - accuracy: 0.0700
Query no. 2
4/4 [==============================] - 1s 7ms/step - loss: 2.2615 - accuracy: 0.2400
Query no. 3
4/4 [==============================] - 0s 6ms/step - loss: 2.3040 - accuracy: 0.0800
Query no. 4
4/4 [==============================] - 0s 7ms/step - loss: 2.2629 - accuracy: 0.1300
Query no. 5
4/4 [==============================] - 0s 6ms/step - loss: 2.3116 - accuracy: 0.1300
Query no. 6
4/4 [==============================] - 0s 6ms/step - loss: 2.3290 - accuracy: 0.0900
Query no. 7
4/4 [==============================] - 0s 6ms/step - loss: 2.3691 - accuracy: 0.0300
Query no. 8
4/4 [==============================] - 0s 6ms/step - loss: 2.2598 - accuracy: 0.2500
Query no. 9
4/4 [==============================] - 0s 7ms/step - loss: 2.2914 - accuracy: 0.1000
Query no. 10
4/4 [==============================] - 0s 6ms/step - loss: 2.2839 - accuracy: 0.1500

Requesting assistance in this regard.
