diff --git a/anylearn-run-dev677.py b/anylearn-run-dev677.py index da61ecc..f514bbc 100644 --- a/anylearn-run-dev677.py +++ b/anylearn-run-dev677.py @@ -9,12 +9,6 @@ algorithm_force_update=True, entrypoint="python train.py", output="results", - dataset_id="DSETf6c6f7584cfb9b01d69f25bb9bc7", - dataset_hyperparam_name='datadir', - model_id="MODE885079a5405185639e8de08f5ffb", - model_hyperparam_name='modeldir', - pretrain_task_id="TRAId8be326641f78134edaa51385cae", - pretrain_hyperparam_name="checkpointdir", hyperparams={ 'epochs': 12, 'model': 'BiT-M-R50x1', diff --git a/train.py b/train.py index 602d759..ee5febc 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ import argparse import os +import anylearn import numpy as np import torch import torchvision as tv @@ -22,21 +23,21 @@ print(f"Args: {args}") print("Environ: {}".format({k: v for k, v in os.environ.items()})) -datadir = args.datadir +datadir = anylearn.get_dataset("yhuang/CIFAR-100").download() print(f"Loading dataset from {datadir}") train_set = tv.datasets.CIFAR100(datadir, train=True, download=False) valid_set = tv.datasets.CIFAR100(datadir, train=False, download=False) print(f"Using a training set with {len(train_set)} images.") print(f"Using a validation set with {len(valid_set)} images.") -modeldir = args.modeldir +modeldir = anylearn.get_model("yhuang/BiT-pretrained").download() model_path = os.path.join(modeldir, f"{args.model}.npz") print(f"Loading model from {model_path}") model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes), zero_head=True) model.load_from(np.load(model_path)) print(f"Loaded model: {model._get_name()}") -checkpointdir = args.checkpointdir +checkpointdir = anylearn.get_task_output('TRAIe5554d5c49c9bf326ff04ff9e1af').download() weigth_path = os.path.join(checkpointdir, "weights_resnet.pkl") print(f"Loading weights from {weigth_path}") with open(weigth_path, 'rb') as f: