From f9d3bf3a04e109f9f231064ad72dcff86c517925 Mon Sep 17 00:00:00 2001 From: Jintao Lin <528557675@qq.com> Date: Thu, 15 Oct 2020 22:29:24 +0800 Subject: [PATCH] import_modules_from_strings when loading cfg from file (#606) * import_modules_from_strings when loading cfg from file * add unittest to tell whether the feature is enabled as expected * minor * set an environment variable instead of writing a file * use 'shutil' instead of 'os.system' --- mmcv/utils/config.py | 7 ++++++- tests/data/config/q.py | 3 +++ tests/data/config/r.py | 3 +++ tests/test_utils/test_config.py | 15 +++++++++++++++ 4 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 tests/data/config/q.py create mode 100644 tests/data/config/r.py diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py index 1d95d46dd9..76739c361a 100644 --- a/mmcv/utils/config.py +++ b/mmcv/utils/config.py @@ -12,6 +12,7 @@ from addict import Dict from yapf.yapflib.yapf_api import FormatCode +from .misc import import_modules_from_strings from .path import check_file_exist if platform.system() == 'Windows': @@ -208,9 +209,13 @@ def _merge_a_into_b(a, b): return b @staticmethod - def fromfile(filename, use_predefined_variables=True): + def fromfile(filename, + use_predefined_variables=True, + import_custom_modules=True): cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) + if import_custom_modules and cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) return Config(cfg_dict, cfg_text=cfg_text, filename=filename) @staticmethod diff --git a/tests/data/config/q.py b/tests/data/config/q.py new file mode 100644 index 0000000000..cd9d1f606c --- /dev/null +++ b/tests/data/config/q.py @@ -0,0 +1,3 @@ +custom_imports = dict( + imports=['r'], + allow_failed_imports=False) diff --git a/tests/data/config/r.py b/tests/data/config/r.py new file mode 100644 index 0000000000..9360128d57 --- /dev/null +++ b/tests/data/config/r.py @@ -0,0 +1,3 @@ +import os + +os.environ["TEST_VALUE"] = 'test' diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py index 0f669e8bb9..e13daff122 100644 --- a/tests/test_utils/test_config.py +++ b/tests/test_utils/test_config.py @@ -1,7 +1,9 @@ # Copyright (c) Open-MMLab. All rights reserved. import argparse import json +import os import os.path as osp +import shutil import tempfile import pytest @@ -140,6 +142,19 @@ def test_fromfile(): assert cfg.text == osp.abspath(osp.expanduser(cfg_file)) + '\n' + \ open(cfg_file, 'r').read() + # test custom_imports for Config.fromfile + cfg_file = osp.join(data_path, 'config', 'q.py') + imported_file = osp.join(data_path, 'config', 'r.py') + target_pkg = osp.join(osp.dirname(__file__), 'r.py') + + # Since the imported config will be regarded as a tmp file + # it should be copied to the directory at the same level + shutil.copy(imported_file, target_pkg) + Config.fromfile(cfg_file, import_custom_modules=True) + + assert os.environ.pop('TEST_VALUE') == 'test' + os.remove(target_pkg) + with pytest.raises(FileNotFoundError): Config.fromfile('no_such_file.py') with pytest.raises(IOError):