
Implementation of Swin Transformers in TensorFlow along with converted pre-trained models, code for off-the-shelf classification and fine-tuning.



sayakpaul/swin-transformers-tf


Swin for win!

TensorFlow 2.8 Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different Swin Transformer [1, 2] variants by Liu et al. and Chen et al. It also provides TensorFlow / Keras models populated with the original Swin pre-trained parameters available from [3, 4]. These models are not black-box SavedModels: they can be fully expanded into tf.keras.Model objects, and you can call all the usual utility functions on them (for example, .summary()).

Refer to the "Using the models" section to get started.

I find Swin Transformers interesting because they build hierarchical representations by using shifted windows. Multi-scale representations like these are crucial for good performance on tasks such as object detection and segmentation.
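
To make the shifted-window idea concrete, here is a minimal NumPy sketch, not the code in swins/models.py. The hypothetical helpers `window_partition` and `cyclic_shift` split a feature map into non-overlapping windows; rolling the map by half a window before partitioning makes the new windows straddle the old window borders, which is what lets information flow between neighboring windows across consecutive blocks.

```python
import numpy as np


def window_partition(x: np.ndarray, window_size: int) -> np.ndarray:
    """Split an (H, W, C) feature map into non-overlapping square windows.

    Returns an array of shape (num_windows, window_size, window_size, C).
    """
    h, w, c = x.shape
    assert h % window_size == 0 and w % window_size == 0
    x = x.reshape(h // window_size, window_size, w // window_size, window_size, c)
    # Bring the two window-grid axes next to each other, then flatten them.
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, window_size, window_size, c)


def cyclic_shift(x: np.ndarray, shift: int) -> np.ndarray:
    """Roll the map up and to the left by `shift` pixels (wrapping around)."""
    return np.roll(x, shift=(-shift, -shift), axis=(0, 1))


# Toy 8x8 single-channel map whose values are the flat pixel indices.
feat = np.arange(64).reshape(8, 8, 1)
regular = window_partition(feat, window_size=4)
# Shift by half a window (4 // 2 = 2), as Swin does in every other block.
shifted = window_partition(cyclic_shift(feat, shift=2), window_size=4)

print(regular.shape)        # (4, 4, 4, 1)
print(regular[0, :, :, 0])  # the top-left 4x4 window
print(shifted[0, :, :, 0])  # now covers pixels from all four original windows
```

The actual Swin blocks additionally apply an attention mask in the shifted configuration so that tokens that wrapped around the border do not attend to each other; this sketch leaves that part out.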

"Swin for win!" however doesn't portray my architecture bias -- I found it cool and hence kept it.


Conversion

The TensorFlow / Keras implementations are available in swins/models.py, and all model configurations are in swins/model_configs.py. The conversion utilities are in convert.py. To run them, first install the dependencies listed in requirements.txt. Additionally, install timm from source:

pip install -q git+https://github.com/rwightman/pytorch-image-models
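
Porting PyTorch checkpoints to Keras mostly comes down to copying arrays into layers whose weight layouts differ. The sketch below shows the two standard layout changes: a PyTorch Linear weight is (out, in) while a Keras Dense kernel is (in, out), and a PyTorch Conv2d weight is (out, in, kH, kW) while a Keras Conv2D kernel is (kH, kW, in, out). It illustrates the general technique and is not a copy of convert.py; the helper names are hypothetical.

```python
import numpy as np


def linear_torch_to_keras(weight: np.ndarray) -> np.ndarray:
    """nn.Linear stores (out_features, in_features); Dense expects (in, out)."""
    return weight.T


def conv_torch_to_keras(weight: np.ndarray) -> np.ndarray:
    """nn.Conv2d stores (out, in, kH, kW); Conv2D expects (kH, kW, in, out)."""
    return weight.transpose(2, 3, 1, 0)


# Swin's patch embedding is a 4x4, stride-4 convolution from 3 input channels
# to embed_dim channels; 96 is used here only as an example embed_dim.
torch_patch_embed = np.random.randn(96, 3, 4, 4).astype("float32")
keras_kernel = conv_torch_to_keras(torch_patch_embed)
print(keras_kernel.shape)  # (4, 4, 3, 96)

# A linear layer, for example a 96 -> 384 MLP projection (mlp_ratio of 4).
torch_fc = np.random.randn(384, 96).astype("float32")
print(linear_torch_to_keras(torch_fc).shape)  # (96, 384)
```

LayerNorm scale and bias vectors need no reordering. Once converted, the arrays can be assigned to the matching Keras layers with set_weights().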

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/swin/1. You can fully inspect the architecture of the TF-Hub models like so:

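import tensorflow as tf

# Load the uncompressed SavedModel straight from the TF-Hub GCS bucket.
model_gcs_path = "gs://tfhub-modules/sayakpaul/swin_tiny_patch4_window7_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

# Call the model once on a dummy batch of two 224x224 RGB images.
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)

# summary() prints directly and returns None, so don't wrap it in print().
model.summary(expand_nested=True)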
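
The example above feeds 224x224 inputs because the tiny model was trained at that resolution. The model names appear to encode their settings, so a small parser like the hypothetical `parse_swin_name` below can help pick the right dummy input shape, for example 384x384 for the *_window12_384 models. It assumes that the trailing number in a name is the square input resolution.

```python
import re


def parse_swin_name(name: str) -> dict:
    """Extract patch size, window size and input resolution from a model name.

    patch_size / window_size are None when the name does not encode them
    (e.g. the swin_s3_* models).
    """
    patch = re.search(r"patch(\d+)", name)
    window = re.search(r"window(\d+)", name)
    return {
        "patch_size": int(patch.group(1)) if patch else None,
        "window_size": int(window.group(1)) if window else None,
        # Assumption: the trailing number is the square input resolution.
        "resolution": int(name.rsplit("_", 1)[-1]),
    }


for name in ["swin_tiny_patch4_window7_224", "swin_large_patch4_window12_384", "swin_s3_small_224"]:
    info = parse_swin_name(name)
    side = info["resolution"]
    print(name, info, "dummy input shape:", (2, side, side, 3))
```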

Results

The table below summarizes the performance of the converted models on the ImageNet-1k validation set. top1_acc and top5_acc are the top-1 and top-5 accuracies of the TensorFlow models; orig_top1_acc is the top-1 accuracy reported for the original models.

model_name                        top1_acc (%)  top5_acc (%)  orig_top1_acc (%)
swin_base_patch4_window7_224      85.134        97.48         85.2
swin_large_patch4_window7_224     86.252        97.878        86.3
swin_s3_base_224                  83.958        96.532        84
swin_s3_small_224                 83.648        96.358        83.7
swin_s3_tiny_224                  82.034        95.864        82.1
swin_small_patch4_window7_224     83.178        96.24         83.2
swin_tiny_patch4_window7_224      81.184        95.512        81.2
swin_base_patch4_window12_384     86.428        98.042        86.4
swin_large_patch4_window12_384    87.272        98.242        87.3
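
As a quick sanity check of the conversion, the gap between the TensorFlow ports and the originally reported top-1 accuracies can be computed directly from the table (numbers copied verbatim, nothing else assumed):

```python
# Top-1 accuracies copied from the table: (model_name, TF port, original).
results = [
    ("swin_base_patch4_window7_224", 85.134, 85.2),
    ("swin_large_patch4_window7_224", 86.252, 86.3),
    ("swin_s3_base_224", 83.958, 84.0),
    ("swin_s3_small_224", 83.648, 83.7),
    ("swin_s3_tiny_224", 82.034, 82.1),
    ("swin_small_patch4_window7_224", 83.178, 83.2),
    ("swin_tiny_patch4_window7_224", 81.184, 81.2),
    ("swin_base_patch4_window12_384", 86.428, 86.4),
    ("swin_large_patch4_window12_384", 87.272, 87.3),
]

for name, ported, original in results:
    # A positive gap means the TF port scores higher than the reported number.
    print(f"{name:32s} {ported - original:+.3f}")

max_gap = max(abs(ported - original) for _, ported, original in results)
print(f"largest absolute gap: {max_gap:.3f} percentage points")  # 0.066
```

Since the original scores are reported with one decimal place, part of these sub-0.1-point gaps may simply come from rounding.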

The base and large models were first pre-trained on the ImageNet-22k dataset and then fine-tuned on the ImageNet-1k dataset.

The in1k-eval directory provides details on how these numbers were generated. Original scores for all models except the s3 ones were gathered from here; scores for the s3 models were gathered from here.
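
For reference, top-1 accuracy counts a prediction as correct when the true label is the highest-scoring class, and top-5 accuracy when it is among the five highest-scoring classes. The NumPy sketch below computes top-k accuracy on a toy batch; it only illustrates the metric and is not the evaluation code in in1k-eval.

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Percentage of samples whose true label is among the k highest scores."""
    # argsort sorts ascending, so the last k columns hold the k largest scores.
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(100.0 * hits.mean())


logits = np.array([
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest: class 0, second: class 1
    [0.2, 0.3, 0.5],  # highest: class 2
    [0.6, 0.1, 0.3],  # highest: class 0, second: class 2
])
labels = np.array([1, 1, 2, 2])
print(top_k_accuracy(logits, labels, k=1))  # 50.0
print(top_k_accuracy(logits, labels, k=2))  # 100.0
```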

Using the models

Pre-trained models:

When doing transfer learning, try using the models that were pre-trained on the ImageNet-22k dataset. All the base and large models listed here were pre-trained on ImageNet-22k. Refer to the model collection page on TF-Hub to learn more.

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to obtain the attention maps for a given image.
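
One way to turn window attention weights into a spatial map is sketched below. It assumes each block's attention has the shape (batch * num_windows, num_heads, N, N) with N = window_size**2, as in timm-style window attention; the notebook may use a different procedure, and `window_reverse` / `attention_to_map` are hypothetical helper names. The random array only stands in for real attention weights.

```python
import numpy as np


def window_reverse(windows: np.ndarray, window_size: int, h: int, w: int) -> np.ndarray:
    """Stitch per-window maps (B * num_windows, ws, ws) back into (B, H, W)."""
    nh, nw = h // window_size, w // window_size
    b = windows.shape[0] // (nh * nw)
    x = windows.reshape(b, nh, nw, window_size, window_size)
    x = x.transpose(0, 1, 3, 2, 4)  # -> (B, nh, ws, nw, ws)
    return x.reshape(b, h, w)


def attention_to_map(attn: np.ndarray, window_size: int, h: int, w: int) -> np.ndarray:
    """Turn window attention (B * num_windows, heads, N, N) into a (B, H, W) map.

    Each position gets the attention it receives as a key, averaged over
    heads and queries.
    """
    received = attn.mean(axis=(1, 2))  # (B * num_windows, N)
    received = received.reshape(-1, window_size, window_size)
    return window_reverse(received, window_size, h, w)


# Random stand-in for a first-stage block at 224px input: 56x56 tokens, 7x7 windows.
rng = np.random.default_rng(0)
batch, heads, ws, h, w = 2, 4, 7, 56, 56
num_windows = (h // ws) * (w // ws)
attn = rng.random((batch * num_windows, heads, ws * ws, ws * ws))
attn /= attn.sum(axis=-1, keepdims=True)  # each query row sums to 1, like a softmax
attention_map = attention_to_map(attn, ws, h, w)
print(attention_map.shape)  # (2, 56, 56)
```

For blocks that use shifted windows, the resulting map would also need to be rolled back by the shift amount before overlaying it on the image.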

Randomly initialized models:

import tensorflow as tf

from swins import SwinTransformer

# Swin-B configuration: 4x4 patches, 7x7 attention windows.
cfg = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),  # Transformer blocks in each of the four stages
    num_heads=(4, 8, 16, 32),  # attention heads in each stage
)

swin_base_patch4_window7_224 = SwinTransformer(
    name="swin_base_patch4_window7_224", **cfg
)
print("Model instantiated, attempting predictions...")

# Forward pass on a random batch of two 224x224 RGB images.
random_tensor = tf.random.normal((2, 224, 224, 3))
outputs = swin_base_patch4_window7_224(random_tensor, training=False)
print(outputs.shape)

# Number of parameters, in millions.
print(swin_base_patch4_window7_224.count_params() / 1e6)
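
The configuration above also fixes the shape of each stage. Assuming the standard Swin design from [1], where patch merging halves the token grid and doubles the channel count between stages, the per-stage numbers for a 224x224 input work out as follows (plain arithmetic, no call into swins):

```python
# Values taken from the Swin-B cfg above, at the 224x224 input used in the example.
image_size, patch_size, window_size, embed_dim = 224, 4, 7, 128
depths, num_heads = (2, 2, 18, 2), (4, 8, 16, 32)

stages = []
side = image_size // patch_size  # token grid side after patch embedding: 56
for i, (depth, heads) in enumerate(zip(depths, num_heads)):
    dim = embed_dim * 2**i                # channels double after each patch merging
    windows = (side // window_size) ** 2  # non-overlapping windows per image
    stages.append((side, dim, depth, heads, dim // heads, windows))
    side //= 2                            # patch merging halves the resolution

for n, (s, dim, depth, heads, head_dim, windows) in enumerate(stages, start=1):
    print(f"stage {n}: {s}x{s} tokens, dim={dim}, blocks={depth}, "
          f"heads={heads} (head dim {head_dim}), windows={windows}")
```

Every stage ends up with 32-dimensional attention heads, and by the last stage the 7x7 token grid fits in a single window.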

To initialize a network with, say, 5 output classes, do:

cfg = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
    num_classes=5,
)

swin_base_patch4_window7_224 = SwinTransformer(
    name="swin_base_patch4_window7_224", **cfg
)
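
Changing num_classes should only affect the classification head. Assuming the head is a single Dense layer on the pooled 1024-dimensional final features (128 channels doubled over three patch-merging steps), as in the original Swin design, its size can be worked out by hand; `head_params` is a hypothetical helper:

```python
def head_params(num_features: int, num_classes: int) -> int:
    """Weights plus biases of a single Dense classification layer."""
    return num_features * num_classes + num_classes


num_features = 128 * 2**3  # embed_dim doubled over three patch merges = 1024
print(head_params(num_features, 1000))  # 1025000
print(head_params(num_features, 5))     # 5125
```

Under that assumption, a 1000-class head and a 5-class head differ by 1,019,875 parameters, while the backbone stays the same.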

To view different model configurations, refer to swins/model_configs.py.
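
For orientation, the four standard variants described in [1] can be written in the same keyword style as the cfg dictionaries above. Treat this as a sketch: swins/model_configs.py remains the authoritative source, and the window size (7 or 12) and input resolution come from the checkpoint name.

```python
# Swin-T/S/B/L hyperparameters as given in the Swin paper [1].
SWIN_VARIANTS = {
    "tiny": dict(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)),
    "small": dict(embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)),
    "base": dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)),
    "large": dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)),
}

for name, variant in SWIN_VARIANTS.items():
    # Stage widths double after each patch-merging step.
    dims = [variant["embed_dim"] * 2**i for i in range(4)]
    head_dims = {dim // heads for dim, heads in zip(dims, variant["num_heads"])}
    print(f"{name}: blocks={sum(variant['depths'])}, stage dims={dims}, head dims={head_dims}")
```

All four variants keep a 32-dimensional attention head in every stage; they differ in base width (embed_dim) and in the depth of the third stage.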

References

[1] Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

[2] Chen et al., Searching the Search Space of Vision Transformer.

[3] Swin Transformer GitHub repository.

[4] AutoFormerV2 GitHub repository.

Acknowledgements

  • timm library source code for the awesome codebase. I've copied and modified a large chunk of code from it and credited it in the respective scripts.
  • Willi Gierke for helping with a non-trivial model serialization hack.
  • ML-GDE program for providing GCP credits that supported my experiments.
