Skip to content

ViTベースの手法による画像分類のサンプルコード

Notifications You must be signed in to change notification settings

SyunkiTakase/ViT_Classification_Sample

Repository files navigation

ViT_Classification_Sample

ViTベースの手法による画像分類のサンプルコード

動作環境

ライブラリのバージョン
  • cuda 12.1
  • python 3.6.9
  • torch 1.8.1+cu111
  • torchaudio 0.8.1
  • torchinfo 1.5.4
  • torchmetrics 0.8.2
  • torchsummary 1.5.1
  • torchvision 0.9.1+cu111
  • timm 0.5.4
  • tlt 0.1.0
  • numpy 1.19.5
  • Pillow 8.4.0
  • scikit-image 0.17.2
  • scikit-learn 0.24.2
  • tqdm 4.64.0
  • opencv-python 4.5.1.48
  • opencv-python-headless 4.6.0.66
  • scipy 1.5.4
  • matplotlib 3.3.4
  • mmcv 1.7.1

ファイル&フォルダ一覧

学習用コード等
ファイル名 説明
vit_train.py ViTを学習するコード.
mae_train.py ViTを学習するコード(Masked Autoencoder(MAE)で事前学習したTransformer Encoderを使用).
cait_train.py CaiTを学習するコード.
deit_train.py DeiTを学習するコード.
trainer.py 学習ループのコード.
vis_att.py Attention Mapを可視化するコード.
vis_class_att.py Class AttentionのAttention Mapを可視化するコード.
attention_rollout.py Attention RolloutでAttention Mapを可視化するコード.
make_graph.py 学習曲線を可視化するコード.

実行手順

環境設定

先述の環境を整えてください.

学習済みモデルのダウンロード

MAEで学習したTransformer Encoderをファインチューニングする場合は学習済みのモデルをダウンロードしてください.

学習済みのモデル

MAE:https://github.com/facebookresearch/mae

学習

ハイパーパラメータは適宜調整してください.

※ ImageNetなどの大きなデータセットで学習する場合はRandAugment,MixUp,CutMix,Random ErasingなどのData Augmentationの追加やWarmUp Epoch,Label Smoothing,Stochastic Depthなどを導入してください.

ViT,MAE,DeiT,CaiTのファインチューニング(CIFAR-10)

ViTの学習

python3 vit_train.py --epoch 10 --batch_size 128 --amp --dataset cifar10 --warmup_t 0 --warmup_lr_init 0

ViT(MAEで学習済みのTransformer Encoder)の学習

python3 mae_train.py --epoch 10 --batch_size 128 --amp --dataset cifar10 --warmup_t 0 --warmup_lr_init 0

DeiTの学習

python3 deit_train.py --epoch 10 --batch_size 128 --amp --dataset cifar10 --warmup_t 0 --warmup_lr_init 0

CaiTの学習

python3 cait_train.py --epoch 10 --batch_size 128 --amp --dataset cifar10 --warmup_t 0 --warmup_lr_init 0
ViT,MAE,DeiT,CaiTのファインチューニング(CIFAR-100)

ViTの学習

python3 vit_train.py --epoch 10 --batch_size 128 --amp --dataset cifar100 --warmup_t 0 --warmup_lr_init 0

ViT(MAEで学習済みのTransformer Encoder)の学習

python3 mae_train.py --epoch 10 --batch_size 128 --amp --dataset cifar100 --warmup_t 0 --warmup_lr_init 0

DeiTの学習

python3 deit_train.py --epoch 10 --batch_size 128 --amp --dataset cifar100 --warmup_t 0 --warmup_lr_init 0

CaiTの学習

python3 cait_train.py --epoch 10 --batch_size 128 --amp --dataset cifar100 --warmup_t 0 --warmup_lr_init 0

アテンションマップの可視化

ViT,MAE,DeiTの場合
python3 vis_att.py 
Attention Rolloutの場合
python3 attention_rollout.py 
CaiTの場合
python3 vis_class_att.py 

参考文献

About

ViTベースの手法による画像分類のサンプルコード

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages