
Greedy Soup selects only best individual model #10

Open
ivovdongen opened this issue Apr 24, 2023 · 11 comments

Comments

@ivovdongen

ivovdongen commented Apr 24, 2023

Dear M. Wortsman,

I am experimenting with Model Soups for four-class brain tumor classification. I use ViT-B32 with AdamW and CategoricalCrossentropy (with label_smoothing). I randomly created 12 model configurations from the hyperparameter grid below. From my 12 models, the best and worst models have a validation accuracy of 91.964% and 84.226%, respectively. The Uniform Soup has a validation accuracy of 88.393%. My Greedy Soup, however, only includes the best individual model (i.e. no combination of weights yields accuracy > 91.964%). What can I do to have my Greedy Soup outperform the best individual model, besides creating a bigger model pool?

Many thanks in advance.

learning_rate = [3e-5, 1e-5, 5e-4]
weight_decay = [1e-6, 1e-7, 1e-8]
epochs = [12, 16, 20]
img_aug = [img_aug_low, img_aug_medium, img_aug_high]
label_smoothing = [0.1, 0.2, 0.3]

Where the different data augmentation intensities are defined as:

def img_aug_low(image, label):
    image = tf.image.random_flip_left_right(image)
    return image, label 

def img_aug_medium(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, 0.1)
    return image, label 

def img_aug_high(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, 0.1)
    image = tf.image.random_saturation(image, 0.7, 1.3)
    return image, label  
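
For reference, the uniform soup above is a plain average of the weights of all 12 fine-tuned models. A minimal TF/Keras sketch of that averaging step, assuming the create_model() builder and the list of saved weight paths shared later in this thread, could look like this (an illustration, not the code actually used):

import numpy as np

def make_uniform_soup(model_paths):
    """Average the weights of all saved models into a single 'uniform soup' model."""
    soup_model = create_model()   # assumed: rebuilds the identical architecture
    summed = None
    for path in model_paths:
        soup_model.load_weights(path)
        weights = soup_model.get_weights()
        if summed is None:
            summed = [np.asarray(w, dtype=np.float64) for w in weights]
        else:
            summed = [s + w for s, w in zip(summed, weights)]
    soup_model.set_weights([s / len(model_paths) for s in summed])
    return soup_model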
@Djoels

Djoels commented Apr 25, 2023

I'm also having trouble getting a greedy soup that outperforms the individual models.
Even on the original datasets this doesn't seem to work properly:

My greedy soup ends up containing only the 3 best models:

Testing model 3 of 72
[0% 0/102]	Acc: 88.67	Data (t) 4.460	Batch (t) 6.201
[20% 20/102]	Acc: 84.80	Data (t) 0.025	Batch (t) 1.575
[39% 40/102]	Acc: 84.48	Data (t) 0.025	Batch (t) 1.557
[59% 60/102]	Acc: 87.37	Data (t) 0.025	Batch (t) 1.579
[78% 80/102]	Acc: 88.98	Data (t) 0.025	Batch (t) 1.550
[98% 100/102]	Acc: 89.93	Data (t) 0.025	Batch (t) 1.585
Potential greedy soup val acc 0.9017692307692308, best so far 0.8996538461538461.
Adding to soup. New soup is ['model_18', 'model_30', 'model_55']

From then onwards, no more models get added.

This is what my final plot looks like:
[figure: final accuracy plot]

@sorobedio

I have read this paper. From what I can tell, the greedy soup is built sequentially, starting from the top best model. Each remaining model is tried in turn: if adding it to the current soup (by averaging the weights) increases validation performance, it is kept; otherwise it is discarded, and the process repeats until all models have been tried. In other words, if no combination containing the top model is better than the top model alone, the greedy soup is just the top model by itself; otherwise it is the set of models, found by this greedy procedure, whose average performs best among the combinations that include the top model.
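
As a rough illustration of that procedure (a sketch, not the repository's implementation), a greedy-soup loop could look like the following, where evaluate_soup is a hypothetical helper that averages the listed checkpoints and returns their held-out validation accuracy:

def greedy_soup(model_paths, valid_scores, evaluate_soup):
    # Rank candidates by decreasing held-out validation accuracy
    order = sorted(range(len(model_paths)), key=lambda i: valid_scores[i], reverse=True)
    soup = [model_paths[order[0]]]          # start from the single best model
    best_acc = valid_scores[order[0]]
    for i in order[1:]:
        candidate = soup + [model_paths[i]]
        acc = evaluate_soup(candidate)      # accuracy of the averaged candidate soup
        if acc >= best_acc:                 # keep the model only if it does not hurt
            soup, best_acc = candidate, acc
    return soup, best_acc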

@ivovdongen
Author

Dear @Djoels

Since you are using a different codebase and a different dataset, and are experiencing a different issue, it might be best to create a separate issue in this repository yourself.

@mitchellnw
Contributor

@Djoels hmm.. you're getting that plot following the steps in the repository? Or doing something custom? Because when we follow the steps we get the figure you see here https://github.com/mlfoundations/model-soups

@mitchellnw
Contributor

@ivovdongen would love to help if possible! can you share some more detail about what your task is, what your network is, and how you are fine-tuning?

one thing it could be is introducing new params when fine-tuning. we always start from a shared zero-shot model, or model with a shared learned linear-probe head. we believe this helps.

concretely, if your model does not already have a classification head, good to first learn a linear probe then initialize with that when doing fine-tuning (as in the LP-FT paper). when fine-tuning multiple times for model soups, best to start out with the same learned linear probe head.
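
A rough TF/Keras sketch of that suggestion (the file name and hyperparameter values are placeholders, and create_model() refers to the builder shared later in this thread, which already freezes the ViT featurizer): train the new head once as a linear probe, save it, and start every fine-tuning run from those weights.

probe = create_model()                      # featurizer is frozen inside create_model()
probe.compile(optimizer="adam",
              loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
              metrics=["accuracy"])
probe.fit(train_dataset, validation_data=valid_dataset, epochs=5)
probe.save_weights("models/shared_probe_head.h5")

# Every soup ingredient then starts from the same learned head instead of a fresh random one.
ingredient = create_model()
ingredient.load_weights("models/shared_probe_head.h5")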

@ivovdongen
Author

Dear @mitchellnw,

The task is to classify four different brain tumor types on MRI scans. The dataset contains 7023 samples in total. However, my experiment described above was executed on only 35% of the dataset and had a model pool of size 12. My network is created as follows:

def create_model():
    """
    Returns pre-trained ViT-B32 model  
    Args: None
    """

    ## Load ViT-B32 model
    feature_extractor = vit.vit_b32(
        image_size = img_size,
        activation = 'softmax',
        pretrained = True,
        include_top = False,
        pretrained_top = False,
        classes = 4)
    
    for layer in feature_extractor.layers:
      layer.trainable = False
    
    initializer = tf.keras.initializers.GlorotNormal(seed=2)
    
    vit_b32 = Sequential([
        layers.Input(shape=(224,224,3), name='input_image'),
        feature_extractor,
        layers.Dropout(0.2),
        layers.Dense(128, activation='gelu', kernel_initializer=initializer),
        layers.Dense(4, activation='softmax', kernel_initializer=initializer)
    ], name='vit_b32')
    
    return vit_b32

I create random hyperparameter configurations on the grid in my initial post and train these using the following function. For callbacks, I use EarlyStopping and ReduceLROnPlateau from Keras.
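
For completeness, the callbacks_list referenced in train_model() below is not shown in this thread; a plausible definition using those two Keras callbacks (the arguments here are illustrative assumptions, not the values actually used) would be:

callbacks_list = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
]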

def train_model(train_dataset,
                valid_dataset,
                learning_rate, 
                weight_decay,
                epochs,
                img_aug,
                label_smoothing,
                save_dir = "models/"):
    """
    Returns saved trained model's path and validation evaluation scores
    Args:
    train_ds : Obj, Training set.
    test_ds : Obj, Validation set.
    learning_rate : Float, Learning rate.
    weight_decay : Float, AdamW weight decay.
    epochs : Int, Number of epochs.
    img_aug: Function, Mapping chosen data augmentation intensity.
    label_smoothing: Float, label smoothing for loss function.
    save_dir : Str, Model save directory
    """
    # Create directory for saving models
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    
    # Apply chosen data augmentation intensity
    train_dataset = train_dataset.unbatch().map(img_aug).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    # Build model
    tf.keras.backend.clear_session()
    vit_b32 = create_model()
    
    # Compile model
    vit_b32.compile(
            optimizer = tfa.optimizers.AdamW(learning_rate = learning_rate, weight_decay = weight_decay),
            loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing = label_smoothing),
            metrics = ["accuracy"]
        )
    
    
    # Train model
    vit_b32.fit(
        train_dataset,
        validation_data = valid_dataset,
        epochs = epochs,
        callbacks = callbacks_list
    )
    
    # Evaluate model
    val_loss, val_acc = vit_b32.evaluate(valid_dataset)
    
    # Save model; if a model with the same hyperparameters already exists,
    # append a random suffix so it is not overwritten
    model_save_path = save_dir + "model-lr" + str(learning_rate) + "_wd" + str(weight_decay) + "_ep" + str(epochs) + "_" + img_aug.__name__ + "_ls" + str(label_smoothing) + ".h5"
    if os.path.isfile(model_save_path):
        model_save_path = save_dir + "model-lr" + str(learning_rate) + "_wd" + str(weight_decay) + "_ep" + str(epochs) + "_" + img_aug.__name__ + "_ls" + str(label_smoothing) + "_" + str(random.choice(np.arange(0, 10))) + ".h5"
    vit_b32.save_weights(model_save_path)
        
    # Clear GPU memory
    del vit_b32
    gc.collect()
    return model_save_path, val_acc

model_paths = []
valid_scores = []

for config in tqdm(parameters):
    model_save_path, valid_acc = train_model(train_dataset,
                                              valid_dataset,
                                              config["learning_rate"],
                                              config["weight_decay"],
                                              config["epochs"],
                                              config["img_aug"],
                                              config["label_smoothing"],
                                              save_dir = "models/")
    
    model_paths.append(model_save_path)
    valid_scores.append(valid_acc)

I am not familiar with zero-shot or linear-probe models, but I can try to look into this. More of my code for the experiment above can be found in this Kaggle repository. I am currently running an experiment that fine-tunes 24 models instead of 12 and uses 100% of the dataset (75% for training) instead of 35%. In addition, I included the dropout probability of the dropout layer in my network ([0.1, 0.2, 0.3]) as an additional hyperparameter in my grid.

Thanks in advance.

@mitchellnw
Contributor

Thanks! Ok, as suspected, I think the issue with model soups performance is that you're introducing new parameters.

Looking at your model definition:

    vit_b32 = Sequential([
        layers.Input(shape=(224,224,3), name='input_image'),
        feature_extractor,
        layers.Dropout(0.2),
        layers.Dense(128, activation='gelu', kernel_initializer=initializer),
        layers.Dense(4, activation='softmax', kernel_initializer=initializer)
    ], name='vit_b32')

It seems that you have two new layers, layers.Dense(128, ...) and layers.Dense(4, ...).

If you run the following experiment my guess is things would work again:

  1. First, fine-tune the model once to get a starting point; let's call this model A.
  2. Now, use model A as your initialization for your future fine-tunings. Now averaging things should work (a rough sketch follows below).
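
A rough way to realize those two steps in this thread's TF/Keras setup (reusing create_model() and the dataset names from above; the path and hyperparameter values are placeholders) is to save all of model A's weights once and load them at the start of every later run:

# Step 1: fine-tune once with any reasonable hyperparameters -> model A.
model_a = create_model()
model_a.compile(optimizer=tfa.optimizers.AdamW(learning_rate=1e-5, weight_decay=1e-7),
                loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
                metrics=["accuracy"])
model_a.fit(train_dataset, validation_data=valid_dataset, epochs=5)
model_a.save_weights("models/model_A.h5")

# Step 2: every subsequent fine-tuning run starts from model A's weights, so all
# soup ingredients share one initialization (including the two new dense layers).
new_run = create_model()
new_run.load_weights("models/model_A.h5")
# ...then compile with the sampled hyperparameters and fit as before.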

@ivovdongen
Author

@mitchellnw Thank you for the fast reply. I am afraid that I don't understand what you mean.

"1. First, fine-tune the model once to get a starting point, let's call this model A."

Do you mean that I can use the model architecture that I already had and fine-tune it once with an arbitrary hyperparameter setting? For example:

def create_model(w_init=None, b_init=None):
    """
    Returns pre-trained ViT-B32 model  
    Args:
    w_init: Matrix of weights to be initialized in the first dense layer.
    b_init: Vector of biases to be initialized in the first dense layer.
    """

    # Load ViT-B32 model
    feature_extractor = vit.vit_b32(
        image_size = img_size,
        activation = 'softmax',
        pretrained = True,
        include_top = False,
        pretrained_top = False,
        classes = 4)
    
    for layer in feature_extractor.layers:
      layer.trainable = False
    
    vit_b32 = Sequential([
        layers.Input(shape=(224,224,3), name='input_image'),
        feature_extractor,
        layers.Dropout(0.2),
        layers.Dense(128, activation='gelu', kernel_initializer=w_init, bias_initializer=b_init),
        layers.Dense(4, activation='softmax')
    ], name='vit_b32')
    
    return vit_b32

First-time fine-tuning, with arbitrary params...

model = create_model()

model.compile( ... )

model_history = model.fit( ... )
val_loss, val_acc = model.evaluate( ... )

Now should I save the weights and biases of the first dense layer after this first-time fine-tuning (Model A)?

shared_w = model.layers[2].get_weights()[0]    #get w's of layers.Dense(128, ...)
shared_b = model.layers[2].get_weights()[1]    #get b's of layers.Dense(128, ...)

"2. Now, use model A as your initialization of your future fine-tunings. Now averaging things should work."

model = create_model(shared_w, shared_b)

This should load the weights and biases of the first dense layer, "layers.Dense(128, ...)", of Model A into the network at that same layer, using the function above. Or should I perhaps get the weights and biases of both dense layers in my network and initialize them in their corresponding layers?

I would love to hear if I understand you correctly. Unfortunately, I have no experience with PyTorch, which is why I am trying to implement this in Tensorflow/Keras. Thanks in advance.

@mitchellnw
Contributor

I'm not familiar with TF/Keras but at a high level looks good! Just trying to make your experimental setting more similar to what we consider.

Concretely, in the paper we use LP-FT (https://arxiv.org/abs/2202.10054), so we first train a linear probe, then use [featurizer, linear_probe] as the common initialization and fine-tune end-to-end (so even changing the weights of the featurizer).

Basically, the "soup starter" = the initialization you use when you start fine-tuning with different seeds, etc., should already be a good model. We accomplish this through LP-FT.

@ivovdongen
Author

@mitchellnw Thanks, I will try to implement this. By the way, is the initialization model (or 'soup starter') also part of the model pool on which the soups are constructed, or is this excluded?

@mitchellnw
Contributor

great! should be fine to include, this is typically what we do.
