A PyTorch implementation of the NIPS 2017 paper "Prototypical Networks for Few-shot Learning". Some settings, such as the learning rates or regression, may differ from the original paper.
I have written a detailed post (in Korean) about the paper and the code on my blog, so if you are interested, please take a look!
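For readers new to the method: the paper represents each class by a prototype, the mean of its embedded support examples, and classifies each query by a softmax over negative squared Euclidean distances to the prototypes. The sketch below shows that per-episode computation on already-embedded tensors. It is a minimal illustration, not this repo's code; the embedding network is omitted, and the function name `prototypical_loss` and its arguments are hypothetical.

```python
import torch
import torch.nn.functional as F


def prototypical_loss(support, support_labels, query, query_labels, n_way):
    """Loss and accuracy for one episode (hypothetical helper).

    support:        (n_support, D) embeddings of the labelled support set
    support_labels: (n_support,) episode labels in 0..n_way-1
    query:          (n_query, D) embeddings of the query set
    query_labels:   (n_query,) episode labels in 0..n_way-1
    """
    # Prototype of each class = mean of its support embeddings -> (n_way, D)
    prototypes = torch.stack(
        [support[support_labels == c].mean(dim=0) for c in range(n_way)]
    )
    # Squared Euclidean distance from every query to every prototype -> (n_query, n_way)
    dists = ((query.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)
    # Softmax over negative distances gives log class probabilities
    log_p = F.log_softmax(-dists, dim=1)
    loss = F.nll_loss(log_p, query_labels)
    acc = (log_p.argmax(dim=1) == query_labels).float().mean()
    return loss, acc, prototypes
```

In a training loop, `support` and `query` would be the outputs of the embedding network (e.g. `protonet` or `resnet`) for the images of one episode, and `loss.backward()` would update that network.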
First, move into the project directory:

`cd prototypical`

The command below trains the model. You can choose the model and the dataset: the model defaults to `protonet` and can also be set to `resnet`; the dataset defaults to `omniglot` and can also be set to `miniImagenet`.

`python train.py -d omniglot -m protonet`
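As in the paper, training is episodic: each step samples N classes ("ways"), K labelled support examples per class ("shots"), and some query examples from the same classes. The sketch below shows one way to draw such an episode from a list of integer class labels. It is an illustration under that assumption, not this repo's sampler; the function name `sample_episode` and its parameters are hypothetical.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, n_query, rng=None):
    """Draw one N-way K-shot episode (hypothetical helper).

    Returns (support, query): lists of (dataset_index, episode_label) pairs,
    where episode_label is in 0..n_way-1.
    """
    if rng is None:
        rng = random.Random()
    # Group dataset indices by their class label
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with enough examples for both support and query are usable
    eligible = sorted(c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query)
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        # Disjoint support and query examples for this class
        picked = rng.sample(by_class[c], k_shot + n_query)
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query
```

Relabelling the chosen classes as 0..N-1 inside each episode is what lets the same loss work regardless of which original classes were drawn.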

Once trained models exist, you can evaluate them. The command below tests the models saved in `runs/exp_name`:

`python eval.py`

Training logs, saved models, and configuration data are stored in `runs/exp_name`. Logs are written with `tensorboard`, so to inspect the training metrics in more detail, run:

`tensorboard --logdir=runs`
All parameters are defined in `arguments.py`. To adjust them, edit that file and run the code again.
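For illustration only, here is a hypothetical sketch of how the `-d` and `-m` options described above could be declared with Python's `argparse`. Only the flags, defaults, and choices come from this README; `build_parser` and the `dataset`/`model` attribute names are assumptions, and the real `arguments.py` may be organized differently.

```python
import argparse


def build_parser():
    # Hypothetical parser: flags, defaults and choices taken from this README
    parser = argparse.ArgumentParser(description="Train a few-shot model (sketch).")
    parser.add_argument("-d", dest="dataset", default="omniglot",
                        choices=["omniglot", "miniImagenet"])
    parser.add_argument("-m", dest="model", default="protonet",
                        choices=["protonet", "resnet"])
    return parser
```

With this sketch, `build_parser().parse_args(["-d", "miniImagenet", "-m", "resnet"])` selects the miniImagenet dataset and the ResNet backbone, while an empty argument list falls back to `omniglot` and `protonet`.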
Accuracy compared with the reference paper:

Setting | Reference Paper | This Repo |
---|---|---|
Omniglot 5-way 1-shot | 98.8% | 98.8 ± 0.4% |
Omniglot 5-way 5-shot | 99.7% | 99.5 ± 0.2% |
Omniglot 20-way 1-shot | 96.0% | 95.4 ± 0.5% |
Omniglot 20-way 5-shot | 98.9% | 98.6 ± 0.2% |
miniImagenet 5-way 1-shot | 49.42 ± 0.78% | 43.5 ± 2% |
miniImagenet 5-way 5-shot | 68.20 ± 0.66% | 63.7 ± 1.8% |
miniImagenet 5-way 1-shot (resnet) | - | coming soon |
miniImagenet 5-way 5-shot (resnet) | - | coming soon |
The miniImagenet runs with `resnet` were tried purely out of interest.
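The reference paper reports accuracy averaged over many test episodes together with a 95% confidence interval. This README does not state how its own ± values are computed, so the sketch below simply assumes the same convention: the mean of per-episode accuracies plus a normal-approximation half-width of 1.96 standard errors. The helper name `mean_and_ci95` is hypothetical.

```python
import math
import statistics


def mean_and_ci95(episode_accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% CI (hypothetical helper)."""
    n = len(episode_accuracies)
    mean = statistics.mean(episode_accuracies)
    std = statistics.stdev(episode_accuracies)  # sample standard deviation (n - 1)
    half_width = 1.96 * std / math.sqrt(n)      # 1.96 standard errors
    return mean, half_width
```

Given a list `accs` of per-episode accuracies, `mean, ci = mean_and_ci95(accs)` followed by `print(f"{100 * mean:.2f} ± {100 * ci:.2f}%")` prints a value in the same format as the table.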