Skip to content

Commit

Permalink
Feat: adapt to new Anylearn SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
phamour committed Mar 17, 2023
1 parent 2b40b64 commit 35b3488
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
6 changes: 0 additions & 6 deletions anylearn-run-dev677.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@
algorithm_force_update=True,
entrypoint="python train.py",
output="results",
dataset_id="DSETf6c6f7584cfb9b01d69f25bb9bc7",
dataset_hyperparam_name='datadir',
model_id="MODE885079a5405185639e8de08f5ffb",
model_hyperparam_name='modeldir',
pretrain_task_id="TRAId8be326641f78134edaa51385cae",
pretrain_hyperparam_name="checkpointdir",
hyperparams={
'epochs': 12,
'model': 'BiT-M-R50x1',
Expand Down
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import os

import anylearn
import numpy as np
import torch
import torchvision as tv
Expand All @@ -22,21 +23,21 @@
print(f"Args: {args}")
print("Environ: {}".format({k: v for k, v in os.environ.items()}))

datadir = args.datadir
datadir = anylearn.get_dataset("yhuang/CIFAR-100").download()
print(f"Loading dataset from {datadir}")
train_set = tv.datasets.CIFAR100(datadir, train=True, download=False)
valid_set = tv.datasets.CIFAR100(datadir, train=False, download=False)
print(f"Using a training set with {len(train_set)} images.")
print(f"Using a validation set with {len(valid_set)} images.")

modeldir = args.modeldir
modeldir = anylearn.get_model("yhuang/BiT-pretrained").download()
model_path = os.path.join(modeldir, f"{args.model}.npz")
print(f"Loading model from {model_path}")
model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes), zero_head=True)
model.load_from(np.load(model_path))
print(f"Loaded model: {model._get_name()}")

checkpointdir = args.checkpointdir
checkpointdir = anylearn.get_task_output('TRAIe5554d5c49c9bf326ff04ff9e1af').download()
weigth_path = os.path.join(checkpointdir, "weights_resnet.pkl")
print(f"Loading weights from {weigth_path}")
with open(weigth_path, 'rb') as f:
Expand Down

0 comments on commit 35b3488

Please # to comment.