|
8 | 8 | from mmrazor.models.mutables.mutable_module import MutableModule
|
9 | 9 | from .base_mutator import MUTABLE_TYPE
|
10 | 10 |
|
| 11 | +if sys.version_info < (3, 8): |
| 12 | + from typing_extensions import Protocol |
| 13 | +else: |
| 14 | + from typing import Protocol |
| 15 | + |
11 | 16 |
|
12 | 17 | class GroupMixin():
|
13 | 18 | """A mixin for :class:`BaseMutator`, which can group mutables by
|
@@ -259,3 +264,66 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
|
259 | 264 | f'When a mutable is set alias attribute :{alias_key},'
|
260 | 265 | f'the corresponding module name {mutable_name} should '
|
261 | 266 | f'not be used in `custom_group` {custom_group}.')
|
| 267 | + |
| 268 | + |
| 269 | +class MutatorProtocol(Protocol): # pragma: no cover |
| 270 | + |
| 271 | + @property |
| 272 | + def mutable_class_type(self) -> Type[BaseMutable]: |
| 273 | + ... |
| 274 | + |
| 275 | + @property |
| 276 | + def search_groups(self) -> Dict: |
| 277 | + ... |
| 278 | + |
| 279 | + |
| 280 | +class OneShotSampleMixin: |
| 281 | + """Sample mixin for one-shot mutators.""" |
| 282 | + |
| 283 | + def sample_choices(self: MutatorProtocol) -> Dict: |
| 284 | + """Sample choices for each group in search_groups.""" |
| 285 | + random_choices = dict() |
| 286 | + for group_id, modules in self.search_groups.items(): |
| 287 | + random_choices[group_id] = modules[0].sample_choice() |
| 288 | + |
| 289 | + return random_choices |
| 290 | + |
| 291 | + def set_choices(self: MutatorProtocol, choices: Dict) -> None: |
| 292 | + """Set choices for each group in search_groups.""" |
| 293 | + for group_id, modules in self.search_groups.items(): |
| 294 | + choice = choices[group_id] |
| 295 | + for module in modules: |
| 296 | + module.current_choice = choice |
| 297 | + |
| 298 | + |
| 299 | +class DynamicSampleMixin(OneShotSampleMixin): |
| 300 | + |
| 301 | + def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: |
| 302 | + """Sample choices for each group in search_groups.""" |
| 303 | + random_choices = dict() |
| 304 | + for group_id, modules in self.search_groups.items(): |
| 305 | + if kind == 'max': |
| 306 | + random_choices[group_id] = modules[0].max_choice |
| 307 | + elif kind == 'min': |
| 308 | + random_choices[group_id] = modules[0].min_choice |
| 309 | + else: |
| 310 | + random_choices[group_id] = modules[0].sample_choice() |
| 311 | + return random_choices |
| 312 | + |
| 313 | + @property |
| 314 | + def max_choice(self: MutatorProtocol) -> Dict: |
| 315 | + """Get max choices for each group in search_groups.""" |
| 316 | + max_choice = dict() |
| 317 | + for group_id, modules in self.search_groups.items(): |
| 318 | + max_choice[group_id] = modules[0].max_choice |
| 319 | + |
| 320 | + return max_choice |
| 321 | + |
| 322 | + @property |
| 323 | + def min_choice(self: MutatorProtocol) -> Dict: |
| 324 | + """Get min choices for each group in search_groups.""" |
| 325 | + min_choice = dict() |
| 326 | + for group_id, modules in self.search_groups.items(): |
| 327 | + min_choice[group_id] = modules[0].min_choice |
| 328 | + |
| 329 | + return min_choice |
0 commit comments