diff --git a/mmfewshot/classification/core/evaluation/eval_hooks.py b/mmfewshot/classification/core/evaluation/eval_hooks.py
index 877e36e..e088be5 100644
--- a/mmfewshot/classification/core/evaluation/eval_hooks.py
+++ b/mmfewshot/classification/core/evaluation/eval_hooks.py
@@ -86,6 +86,10 @@ def before_run(self, runner: Runner) -> None:
                 warnings.warn('runner.meta is None. Creating an empty one.')
                 runner.meta = dict()
             runner.meta.setdefault('hook_msgs', dict())
+            if runner.meta['hook_msgs'].get('best_score', False):
+                self.best_score = runner.meta['hook_msgs']['best_score']
+                runner.logger.info(
+                    f'Previous best score is: {self.best_score}.')
             self.best_ckpt_path = runner.meta['hook_msgs'].get(
                 'best_ckpt', None)

diff --git a/tests/test_classification_runtime/test_meta_test_eval_hook.py b/tests/test_classification_runtime/test_meta_test_eval_hook.py
index c5bb8d3..c2c7039 100644
--- a/tests/test_classification_runtime/test_meta_test_eval_hook.py
+++ b/tests/test_classification_runtime/test_meta_test_eval_hook.py
@@ -171,3 +171,48 @@ def test_epoch_eval_hook():
             max_epochs=1)
         runner.register_hook(eval_hook)
         runner.run([loader], [('train', 1)], 1)
+
+
+def test_resume_eval_hook():
+    test_set_loader = DataLoader(
+        toy_meta_test_dataset(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    query_loader = DataLoader(
+        toy_meta_test_dataset().query(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    support_loader = DataLoader(
+        toy_meta_test_dataset().support(),
+        batch_size=1,
+        sampler=None,
+        num_workers=0,
+        shuffle=False)
+    model = ExampleModel()
+    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+    optimizer = obj_from_dict(optim_cfg, torch.optim,
+                              dict(params=model.parameters()))
+    test_dataset = ExampleDataset()
+    loader = DataLoader(test_dataset, batch_size=1)
+    # test EvalHook
+    with tempfile.TemporaryDirectory() as tmpdir:
+        eval_hook = MetaTestEvalHook(
+            support_loader,
+            query_loader,
+            test_set_loader,
+            num_test_tasks=10,
+            meta_test_cfg=dict(support={}, query={}))
+        runner = mmcv.runner.EpochBasedRunner(
+            model=model,
+            optimizer=optimizer,
+            work_dir=tmpdir,
+            logger=logging.getLogger(),
+            max_epochs=1)
+        runner.register_hook(eval_hook)
+        runner.meta = {'hook_msgs': {'best_score': 99.0}}
+        runner.run([loader], [('train', 1)], 1)
+        assert eval_hook.best_score == 99.0
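
For context, here is a minimal self-contained sketch of the resume flow this patch addresses. It assumes only that mmcv checkpoints persist and restore `runner.meta` on resume; `ToyEvalHook` and `after_eval` below are hypothetical stand-ins for the real `MetaTestEvalHook` machinery, not mmcv API.

# Sketch: why before_run must read best_score back out of runner.meta.
# ToyEvalHook / after_eval are hypothetical stand-ins, not mmcv API.
class ToyEvalHook:

    def __init__(self):
        self.best_score = 0.0  # default before any evaluation

    def before_run(self, runner_meta):
        # Without the patch, best_score always restarts at its default,
        # so a resumed run would treat its first eval result as a fresh
        # best and could overwrite the previous best checkpoint.
        msgs = runner_meta.setdefault('hook_msgs', {})
        if msgs.get('best_score', False):
            self.best_score = msgs['best_score']

    def after_eval(self, runner_meta, score):
        # hook_msgs is saved inside checkpoints, so best_score survives
        # an interrupted run and comes back when the runner resumes.
        if score > self.best_score:
            self.best_score = score
            runner_meta['hook_msgs']['best_score'] = score


# Mirrors the new test: meta as restored from a checkpoint on resume.
meta = {'hook_msgs': {'best_score': 99.0}}
hook = ToyEvalHook()
hook.before_run(meta)
assert hook.best_score == 99.0
hook.after_eval(meta, 42.0)     # worse than the resumed best
assert hook.best_score == 99.0  # previous best is kept

Note that the truthiness check mirrors the patch's `.get('best_score', False)`: a stored best score of exactly 0.0 would not be restored, which is harmless for accuracy-style metrics but worth keeping in mind.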