|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved.
|
| 2 | +import copy |
2 | 3 | import os
|
3 | 4 | from tempfile import NamedTemporaryFile, TemporaryDirectory
|
| 5 | +from typing import Any |
4 | 6 |
|
5 | 7 | import mmcv
|
6 | 8 | import numpy as np
|
|
37 | 39 | img = np.random.rand(*img_shape, 3)
|
38 | 40 |
|
39 | 41 |
|
40 |
| -def test_init_pytorch_model(): |
| 42 | +@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) |
| 43 | +def test_init_pytorch_model(from_mmrazor: Any): |
41 | 44 | from mmcls.models.classifiers.base import BaseClassifier
|
42 |
| - model = task_processor.init_pytorch_model(None) |
| 45 | + if from_mmrazor is False: |
| 46 | + _task_processor = task_processor |
| 47 | + else: |
| 48 | + _model_cfg_path = 'tests/test_codebase/test_mmcls/data/' \ |
| 49 | + 'mmrazor_model.py' |
| 50 | + _model_cfg = load_config(_model_cfg_path)[0] |
| 51 | + _model_cfg.algorithm.architecture.model.type = 'mmcls.ImageClassifier' |
| 52 | + _model_cfg.algorithm.architecture.model.backbone = dict( |
| 53 | + type='SearchableShuffleNetV2', widen_factor=1.0) |
| 54 | + _deploy_cfg = copy.deepcopy(deploy_cfg) |
| 55 | + _deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor |
| 56 | + _task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu') |
| 57 | + |
| 58 | + if not isinstance(from_mmrazor, bool): |
| 59 | + with pytest.raises( |
| 60 | + TypeError, |
| 61 | + match='`from_mmrazor` attribute must be ' |
| 62 | + 'boolean type! ' |
| 63 | + f'but got: {from_mmrazor}'): |
| 64 | + _ = _task_processor.from_mmrazor |
| 65 | + return |
| 66 | + assert from_mmrazor == _task_processor.from_mmrazor |
| 67 | + |
| 68 | + model = _task_processor.init_pytorch_model(None) |
43 | 69 | assert isinstance(model, BaseClassifier)
|
44 | 70 |
|
45 | 71 |
|
|
0 commit comments