forked from chengtan9907/Co-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
103 lines (91 loc) · 3.89 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import ast
import os.path as osp
import re
import shutil
import sys
import tempfile
from importlib import import_module

from addict import Dict
'''
Thanks to the code from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py written by Open-MMLab.
The `Config` class here reuses some parts of that reference implementation.
'''
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Ensure ``filename`` refers to an existing regular file.

    Args:
        filename: Path checked with ``os.path.isfile``.
        msg_tmpl: Template for the error message; its ``{}`` placeholder is
            filled with ``filename``.

    Raises:
        FileNotFoundError: If ``filename`` is not an existing file.
    """
    if osp.isfile(filename):
        return
    message = msg_tmpl.format(filename)
    raise FileNotFoundError(message)
class Config:
    """Settings container populated from a dict or a Python config file.

    The settings are stored in the ``_cfg_dict`` attribute and the path the
    config was loaded from (or ``None``) in ``_filename``.
    """

    def __init__(self, cfg_dict=None, filename=None):
        """Build a config from ``cfg_dict`` or by loading ``filename``.

        Args:
            cfg_dict: Dict of settings. ``None`` means an empty dict.
            filename: Optional path to a ``.py`` config file. When given,
                the file is parsed (with predefined-variable substitution
                enabled) and its contents replace ``cfg_dict``.

        Raises:
            TypeError: If ``cfg_dict`` is neither ``None`` nor a ``dict``.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        if filename is not None:
            cfg_dict = self._file2dict(filename, True)
        # Attributes are written through the base class's __setattr__.
        super(Config, self).__setattr__('_cfg_dict', cfg_dict)
        super(Config, self).__setattr__('_filename', filename)

    @staticmethod
    def _validate_py_syntax(filename):
        """Parse ``filename`` with ``ast`` and fail early on syntax errors.

        Raises:
            SyntaxError: If the file is not valid Python. The original
                error is chained as ``__cause__``.
        """
        # Python source is UTF-8 by default, independent of the locale.
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}: {e}') from e

    @staticmethod
    def _substitute_predefined_vars(filename, temp_config_name):
        """Copy ``filename`` to ``temp_config_name`` expanding ``{{ var }}``.

        Supported variables are ``fileDirname``, ``fileBasename``,
        ``fileBasenameNoExtension`` and ``fileExtname``, all derived from
        ``filename``.

        Args:
            filename: Path of the config file to read.
            temp_config_name: Path the substituted text is written to.
        """
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname)
        with open(filename, 'r', encoding='utf-8') as f:
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
            # Forward slashes keep Windows paths from being read as escape
            # sequences, both by re.sub and inside the config's strings.
            value = value.replace('\\', '/')
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _file2dict(filename, use_predefined_variables=True):
        """Execute a ``.py`` config file and return its top-level names.

        The file is copied into a temporary directory (optionally with the
        predefined ``{{ var }}`` placeholders expanded), imported as a
        module from there, and every module attribute whose name does not
        start with ``__`` is collected.

        Args:
            filename: Path to the config file; ``~`` is expanded.
            use_predefined_variables: Whether to expand ``{{ var }}``
                placeholders before importing.

        Returns:
            dict: Mapping of names defined by the config to their values.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            IOError: If the file extension is not ``.py``.
            SyntaxError: If the file is not valid Python.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in ['.py']:
            raise IOError('Only py type are supported now!')
        # Validate before touching sys.path so a bad file leaves no trace.
        Config._validate_py_syntax(filename)

        with tempfile.TemporaryDirectory() as temp_config_dir:
            # Only a unique name is needed here. The handle is closed right
            # away (delete=False keeps the file) because an open
            # NamedTemporaryFile cannot be reopened by name on Windows; the
            # directory cleanup removes the file afterwards.
            with tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix=fileExtname,
                    delete=False) as temp_config_file:
                temp_config_path = temp_config_file.name

            if use_predefined_variables:
                Config._substitute_predefined_vars(filename,
                                                   temp_config_path)
            else:
                shutil.copyfile(filename, temp_config_path)

            temp_module_name = osp.splitext(
                osp.basename(temp_config_path))[0]
            sys.path.insert(0, temp_config_dir)
            try:
                mod = import_module(temp_module_name)
            finally:
                # Undo the global side effects even if the import failed.
                if temp_config_dir in sys.path:
                    sys.path.remove(temp_config_dir)
                sys.modules.pop(temp_module_name, None)

            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith('__')
            }
        return cfg_dict

    @staticmethod
    def fromfile(filename, use_predefined_variables=True):
        """Load a :class:`Config` from a Python config file.

        Args:
            filename: Path to the ``.py`` config file.
            use_predefined_variables: Whether to expand ``{{ var }}``
                placeholders before importing the file.

        Returns:
            Config: Config holding the file's settings, with ``_filename``
            set to ``filename`` as passed in.
        """
        cfg_dict = Config._file2dict(filename, use_predefined_variables)
        # Build from the dict and attach the filename afterwards: passing
        # filename to __init__ would execute the file a second time and
        # always with substitution enabled, ignoring the caller's flag.
        cfg = Config(cfg_dict)
        super(Config, cfg).__setattr__('_filename', filename)
        return cfg