[Fix] eval hook resume best acc #33

Open · wants to merge 7 commits into base: main
Changes from 6 commits
3 changes: 2 additions & 1 deletion mmfewshot/classification/apis/train.py
@@ -123,7 +123,8 @@ def train_model(model: Union[MMDataParallel, MMDistributedDataParallel],
                 all_data_loader,
                 num_test_tasks=meta_test_cfg['num_episodes'],
                 **eval_cfg),
-            priority='LOW')
+            # make eval hook (45) run before checkpoint saver hook (50)
+            priority=45)
 
     # user-defined hooks
     if cfg.get('custom_hooks', None):
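Context for the priority value (a sketch, not part of the diff): mmcv runs hooks in ascending priority order, CheckpointHook is registered at NORMAL (50), and the previous 'LOW' level maps to 70, so the eval hook used to fire after the checkpoint had already been saved. A minimal check of that ordering, using mmcv's get_priority helper:

from mmcv.runner import get_priority

# mmcv executes hooks in ascending priority order (lower runs earlier).
assert get_priority('NORMAL') == 50  # CheckpointHook's default level
assert get_priority('LOW') == 70     # the eval hook's old level
# With priority=45 the eval hook's after_train_epoch fires before
# CheckpointHook's, so 'best_score' is already in runner.meta['hook_msgs']
# when the checkpoint meta is written.
assert 45 < get_priority('NORMAL') < get_priority('LOW')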
4 changes: 4 additions & 0 deletions mmfewshot/classification/core/evaluation/eval_hooks.py
@@ -86,6 +86,10 @@ def before_run(self, runner: Runner) -> None:
             warnings.warn('runner.meta is None. Creating an empty one.')
             runner.meta = dict()
         runner.meta.setdefault('hook_msgs', dict())
+        if runner.meta['hook_msgs'].get('best_score', False):
+            self.best_score = runner.meta['hook_msgs']['best_score']
+            runner.logger.info(
+                f'Previous best score is: {self.best_score}.')
         self.best_ckpt_path = runner.meta['hook_msgs'].get(
             'best_ckpt', None)
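How the resumed value reaches before_run: runner.resume() restores runner.meta from the checkpoint, where the eval hook had recorded 'best_score' under hook_msgs, and the new branch seeds self.best_score from it instead of leaving the initial value in place. A standalone sketch of that path (ToyEvalHook and the SimpleNamespace runner are illustrative stand-ins, not mmfewshot code):

import logging
from types import SimpleNamespace

class ToyEvalHook:
    """Stand-in mirroring the patched before_run logic."""
    best_score = float('-inf')  # value before any evaluation has run

    def before_run(self, runner):
        runner.meta.setdefault('hook_msgs', dict())
        # Restore the score recorded before the interruption, if present.
        # Note: a stored best_score of 0.0 is falsy and would be skipped,
        # matching the patch's `.get('best_score', False)` check.
        if runner.meta['hook_msgs'].get('best_score', False):
            self.best_score = runner.meta['hook_msgs']['best_score']
            runner.logger.info(f'Previous best score is: {self.best_score}.')

# meta as runner.resume() would restore it from a saved checkpoint
runner = SimpleNamespace(
    meta={'hook_msgs': {'best_score': 99.0}},
    logger=logging.getLogger())
hook = ToyEvalHook()
hook.before_run(runner)
assert hook.best_score == 99.0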
45 changes: 45 additions & 0 deletions tests/test_classification_runtime/test_meta_test_eval_hook.py
@@ -171,3 +171,48 @@ def test_epoch_eval_hook():
             max_epochs=1)
         runner.register_hook(eval_hook)
         runner.run([loader], [('train', 1)], 1)
+
+
+def test_resume_eval_hook():
+    test_set_loader = DataLoader(
+        toy_meta_test_dataset(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    query_loader = DataLoader(
+        toy_meta_test_dataset().query(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    support_loader = DataLoader(
+        toy_meta_test_dataset().support(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    model = ExampleModel()
+    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+    optimizer = obj_from_dict(optim_cfg, torch.optim,
+                              dict(params=model.parameters()))
+    test_dataset = ExampleDataset()
+    loader = DataLoader(test_dataset, batch_size=1)
+    # test MetaTestEvalHook resuming a previously recorded best score
+    with tempfile.TemporaryDirectory() as tmpdir:
+        eval_hook = MetaTestEvalHook(
+            support_loader,
+            query_loader,
+            test_set_loader,
+            num_test_tasks=10,
+            meta_test_cfg=dict(support={}, query={}))
+        runner = mmcv.runner.EpochBasedRunner(
+            model=model,
+            optimizer=optimizer,
+            work_dir=tmpdir,
+            logger=logging.getLogger(),
+            max_epochs=1)
+        runner.register_hook(eval_hook)
+        runner.meta = {'hook_msgs': {'best_score': 99.0}}
+        runner.run([loader], [('train', 1)], 1)
+        assert eval_hook.best_score == 99.0
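
The new test can be exercised in isolation with pytest (assuming a standard checkout of the repository):

pytest tests/test_classification_runtime/test_meta_test_eval_hook.py::test_resume_eval_hook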