home > examples > fixmatch

Semi-Supervised Image Classification with FixMatch

Setup

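Required environment variables

Run every command below from the repository root: `PYTHONPATH` adds the current directory (`.`) and `SSL_PATH` is the relative path `examples/fixmatch`. Replace the `ML_DATA` placeholder with the directory where the datasets should be saved.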

# Both "." (appended to PYTHONPATH) and SSL_PATH are relative paths,
# so run these from the repository root.
export PYTHONPATH="${PYTHONPATH}:."
# Placeholder: replace with the directory that should hold the datasets.
export ML_DATA="path to where you want the datasets saved"
export PROJECT="ObjaxSSL"
export SSL_PATH="examples/fixmatch"

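Data preparation

The commands below download the datasets and write the labeled and unlabeled TFRecord files under `$ML_DATA/$PROJECT`.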

# Download datasets.
# An empty CUDA_VISIBLE_DEVICES hides all GPUs from these preprocessing scripts.
# Paths are quoted so that an ML_DATA containing spaces still works.
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_datasets.py"
# svhnx (SVHN train + extra, built below) reuses the regular SVHN test set.
cp "$ML_DATA/$PROJECT/svhn-test.tfrecord" "$ML_DATA/$PROJECT/svhnx-test.tfrecord"

# Create unlabeled datasets (in parallel).
# Note: a bare `wait` returns 0 even if a background job failed, so check each
# job's output for errors before continuing.
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_unlabeled.py" "$ML_DATA/$PROJECT/SSL/cifar10" "$ML_DATA/$PROJECT/cifar10-train.tfrecord" &
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_unlabeled.py" "$ML_DATA/$PROJECT/SSL/cifar100" "$ML_DATA/$PROJECT/cifar100-train.tfrecord" &
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_unlabeled.py" "$ML_DATA/$PROJECT/SSL/stl10" "$ML_DATA/$PROJECT/stl10-train.tfrecord" "$ML_DATA/$PROJECT/stl10-unlabeled.tfrecord" &
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_unlabeled.py" "$ML_DATA/$PROJECT/SSL/svhn" "$ML_DATA/$PROJECT/svhn-train.tfrecord" &
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_unlabeled.py" "$ML_DATA/$PROJECT/SSL/svhnx" "$ML_DATA/$PROJECT/svhn-train.tfrecord" "$ML_DATA/$PROJECT/svhn-extra.tfrecord" &
wait

# Create semi-supervised subsets: for each seed, one labeled split per size.
# All splits of one seed run in parallel; `wait` finishes them before the next seed.
for seed in 0 1 2 3 4 5; do
    for size in 40 100 250 1000 4000; do
        CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed="$seed" --size="$size" "$ML_DATA/$PROJECT/SSL/cifar10" "$ML_DATA/$PROJECT/cifar10-train.tfrecord" &
        CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed="$seed" --size="$size" "$ML_DATA/$PROJECT/SSL/svhn" "$ML_DATA/$PROJECT/svhn-train.tfrecord" &
        CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed="$seed" --size="$size" "$ML_DATA/$PROJECT/SSL/svhnx" "$ML_DATA/$PROJECT/svhn-train.tfrecord" "$ML_DATA/$PROJECT/svhn-extra.tfrecord" &
    done
    for size in 400 1000 2500 10000; do
        CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed="$seed" --size="$size" "$ML_DATA/$PROJECT/SSL/cifar100" "$ML_DATA/$PROJECT/cifar100-train.tfrecord" &
    done
    CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed="$seed" --size=1000 "$ML_DATA/$PROJECT/SSL/stl10" "$ML_DATA/$PROJECT/stl10-train.tfrecord" "$ML_DATA/$PROJECT/stl10-unlabeled.tfrecord" &
    wait
done
# STL-10 split with 5000 labels (seed 1). It uses the same $ML_DATA/$PROJECT
# layout as every other command here (the SSL/ output dir and the
# $PROJECT-scoped tfrecords).
CUDA_VISIBLE_DEVICES= "$SSL_PATH/scripts/create_split.py" --seed=1 --size=5000 "$ML_DATA/$PROJECT/SSL/stl10" "$ML_DATA/$PROJECT/stl10-train.tfrecord" "$ML_DATA/$PROJECT/stl10-unlabeled.tfrecord"

Training

Train FixMatch on one of the CIFAR-10 labeled splits created above:

# FixMatch
# NOTE(review): the --dataset spec presumably reads <name>.<seed>@<labels>-<valid>
# (here: the seed-3, 250-label CIFAR-10 split) — confirm against fixmatch.py.
# $SSL_PATH is quoted so the command survives paths containing spaces.
python "$SSL_PATH/fixmatch.py" --dataset=cifar10.3@250-0 --unlabeled=cifar10 --uratio=5 --augment='CTA(sm,sm,sm)'

TensorBoard

# Serve the ./experiments directory (relative to the current directory) on port 6006.
# --logdir is the standard flag for a single directory; --logdir_spec is meant for
# comma/colon-separated named runs and TensorBoard discourages it.
tensorboard --port 6006 --logdir=experiments