Skip to content

[Feature] Resume from the latest checkpoint automatically. #6727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mmdet/apis/train.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import get_root_logger
from mmdet.utils import find_latest_checkpoint, get_root_logger


def init_random_seed(seed=None, device='cuda'):
@@ -196,6 +196,12 @@ def train_detector(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
7 changes: 6 additions & 1 deletion mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint

__all__ = ['get_root_logger', 'collect_env']
__all__ = [
'get_root_logger',
'collect_env',
'find_latest_checkpoint',
]
38 changes: 38 additions & 0 deletions mmdet/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings


def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    If ``latest.{suffix}`` exists in ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file in ``path`` is inspected and the one
    whose name carries the largest integer counter is returned. The counter
    is the text after the last ``_`` and before the next ``.`` of the file
    name, e.g. ``12`` for ``epoch_12.pth`` or ``8000`` for ``iter_8000.pth``.
    Files without such a counter (e.g. ``best.pth``) are ignored.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension.
            Defaults to pth.

    Returns:
        latest_path (str | None): File path of the latest checkpoint, or
            None when the path does not exist or holds no usable checkpoint.

    References:
        .. [1]  https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    # Escape the directory so glob metacharacters in it (``[``, ``*``, ``?``)
    # are matched literally; only the file-name part is a pattern.
    checkpoints = glob.glob(osp.join(glob.escape(path), f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        counter = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(counter)
        except ValueError:
            # Not an epoch/iteration checkpoint (e.g. ``best.pth``): skip it
            # rather than aborting the whole search.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path
41 changes: 41 additions & 0 deletions tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile

import numpy as np
import pytest
import torch
@@ -7,6 +9,7 @@
from mmdet.core.mask.structures import BitmapMasks, PolygonMasks
from mmdet.core.utils import (center_of_mass, filter_scores_and_topk,
flip_tensor, mask2ndarray, select_single_mlvl)
from mmdet.utils import find_latest_checkpoint


def dummy_raw_polygon_masks(size):
@@ -160,3 +163,41 @@ def test_filter_scores_and_topk():
assert keep_idxs.allclose(torch.tensor([1, 2, 1, 3]))
assert results['bbox_pred'].allclose(
torch.tensor([[0.4, 0.7], [0.1, 0.1], [0.4, 0.7], [0.5, 0.1]]))


def test_find_latest_checkpoint():
    """Exercise ``find_latest_checkpoint`` on empty, missing and filled dirs."""
    with tempfile.TemporaryDirectory() as work_dir:
        # There are no checkpoints in the path.
        assert find_latest_checkpoint(work_dir) is None
        # The path does not exist.
        assert find_latest_checkpoint(work_dir + '/none') is None

    # A ``latest.pth`` file is returned as-is.
    with tempfile.TemporaryDirectory() as work_dir:
        with open(work_dir + '/latest.pth', 'w') as f:
            f.write('latest')
        found = find_latest_checkpoint(work_dir)
        assert found == work_dir + '/latest.pth'

    # Each case: ((file name, file content), ...) followed by the file name
    # that is expected to be reported as the latest checkpoint.
    numbered_cases = (
        ((('/iter_4000.pth', 'iter_4000'), ('/iter_8000.pth', 'iter_8000')),
         '/iter_8000.pth'),
        ((('/epoch_1.pth', 'epoch_1'), ('/epoch_2.pth', 'epoch_2')),
         '/epoch_2.pth'),
    )
    for files, expected in numbered_cases:
        with tempfile.TemporaryDirectory() as work_dir:
            for file_name, content in files:
                with open(work_dir + file_name, 'w') as f:
                    f.write(content)
            found = find_latest_checkpoint(work_dir)
            assert found == work_dir + expected
5 changes: 5 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,10 @@ def parse_args():
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@@ -104,6 +108,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else: