diff --git a/experitur/configurators/__init__.py b/experitur/configurators/__init__.py index 15688ea..6f4b879 100644 --- a/experitur/configurators/__init__.py +++ b/experitur/configurators/__init__.py @@ -9,6 +9,7 @@ Grid, MultiplicativeConfiguratorChain, RandomGrid, + Clear, ) from .conditions import Conditions @@ -23,6 +24,7 @@ "MultiplicativeConfiguratorChain", "Prune", "RandomGrid", + "Clear", ] try: diff --git a/experitur/core/configurators.py b/experitur/core/configurators.py index ae9af04..093f14f 100644 --- a/experitur/core/configurators.py +++ b/experitur/core/configurators.py @@ -1,4 +1,5 @@ import collections.abc +import fnmatch import itertools import random import warnings @@ -13,6 +14,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Type, @@ -123,13 +125,13 @@ def parameter_values(self) -> Mapping[str, Container]: # pragma: no cover """Information about the values that every parameter configured here can assume.""" return {} - def __add__(self, other) -> "BaseConfigurator": + def __add__(self, other) -> "AdditiveConfiguratorChain": if not isinstance(other, BaseConfigurator): # pragma: no cover return NotImplemented return AdditiveConfiguratorChain(self, other) - def __mul__(self, other) -> "BaseConfigurator": + def __mul__(self, other) -> "MultiplicativeConfiguratorChain": if not isinstance(other, BaseConfigurator): # pragma: no cover return NotImplemented @@ -174,7 +176,7 @@ def contains_subset_of( exclude = set() # The parameter values that should contain the supplied configuration, without `exclude`ed - # If the values contain , a trial can assume any value. + # If the values contain , the respective parameter of a trial can assume any value. parameter_values = { k: v for k, v in self.configurator.parameter_values.items() @@ -185,7 +187,7 @@ def contains_subset_of( # Check if all configured parameters are contained if any( - k not in conf_parameters or conf_parameters[k] not in v + (k not in conf_parameters) or (conf_parameters[k] not in v) for k, v in parameter_values.items() ): return False @@ -294,19 +296,6 @@ def combine_parameter_values(v_left, v_right, drop_unset=False): ) # pragma: no cover -def replace_parameter_values( - left: Dict[str, Container], right: Mapping[str, Container] -) -> None: - """Replace values in `left` with matching in `right`.""" - for k, v_right in right.items(): - if unset in v_right: - left[k] = combine_parameter_values( - left.get(k, (unset,)), v_right, drop_unset=True - ) - else: - left[k] = v_right - - def extend_parameter_values( left: Dict[str, Container], right: Mapping[str, Container] ) -> None: @@ -320,7 +309,7 @@ def extend_parameter_values( left[k] = combine_parameter_values(left[k], v_right) -class MultiplicativeConfiguratorChain(Configurator): +class MultiplicativeConfiguratorChain(BaseConfigurator): """ Multiplicative configurator chain. The result is the cross-product of all contained configurators. @@ -332,22 +321,21 @@ def __init__(self, *configurators: BaseConfigurator) -> None: def build_sampler( self, parent: Optional["BaseConfigurationSampler"] = None ) -> BaseConfigurationSampler: - if parent is None: - parent = _RootSampler() + sampler = _RootSampler() if parent is None else parent for c in self.configurators: - parent = c.build_sampler(parent) + sampler = c.build_sampler(sampler) - return parent + return sampler @property def parameter_values(self) -> Mapping[str, Container]: - parameter_values = {} + parameter_values = ParameterSpace() for c in self.configurators: - replace_parameter_values(parameter_values, c.parameter_values) + parameter_values.update_multiplicative(c.parameter_values) return parameter_values - def __mul__(self, other) -> "BaseConfigurator": + def __mul__(self, other) -> "MultiplicativeConfiguratorChain": if not isinstance(other, BaseConfigurator): # pragma: no cover return NotImplemented @@ -357,7 +345,129 @@ def __str__(self): return "(" + (" * ".join(str(c) for c in self.configurators)) + ")" -class AdditiveConfiguratorChain(Configurator): +class GenerativeContainer(collections.abc.Container): + def __init__(self, values=None) -> None: + super().__init__() + + self.values = [] + self.children = [] + + if values is not None: + self.update(values) + + def update(self, other): + if isinstance(other, (list, tuple, set)): + for elm in other: + self.add(elm) + elif isinstance(other, GenerativeContainer): + for elm in other.values: + self.add(elm) + self.children.extend(other.children) + else: + self.children.append(other) + + def add(self, value): + if value in self.values: + return + self.values.append(value) + + def discard(self, value): + self.values = [v for v in self.values if v != value] + + def __contains__(self, x) -> bool: + return x in self.values or any(x in c for c in self.children) + + def __str__(self) -> str: + parts = [f"{v}" for v in self.values] + for c in self.children: + parts.append(f"*{c}") + return "{" + (", ".join(parts)) + "}" + + def __repr__(self) -> str: + return f"{type(self)}({self})" + + def __eq__(self, other: object) -> bool: + if isinstance(other, (list, tuple, set)): + if self.children: + return False + + if len(self.values) != len(other): + return False + + for elm in self.values: + if elm not in other: + return False + + return True + + if isinstance(other, GenerativeContainer): + if len(self.values) != len(other.values): + return False + + for elm in self.values: + if elm not in other.values: + return False + + if len(self.children) != len(other.children): + return False + + for child in self.children: + if child not in other.children: + return False + + return True + + return NotImplemented + + def is_invariant(self): + return len(self.values) < 2 and not self.children + + +class ParameterSpace(Dict[str, GenerativeContainer]): + def update(self, *args, **kwargs): + raise NotImplementedError("Use update_additive or update_multiplicative") + + def update_additive( + self, *parameter_spaces: Mapping[str, Container], **kwargs: Container + ): + parameter_spaces = parameter_spaces + (kwargs,) + + for p_space in parameter_spaces: + if isinstance(p_space, _UnsetParameterSpace): + continue + + for k, v in p_space.items(): + v = GenerativeContainer(v) + if k not in self: + self[k] = v + else: + self[k].update(v) + + def update_multiplicative( + self, *parameter_spaces: Mapping[str, Container], **kwargs: Container + ): + parameter_spaces = parameter_spaces + (kwargs,) + + for p_space in parameter_spaces: + if isinstance(p_space, _UnsetParameterSpace): + for k in self.keys(): + if k in p_space: + self[k] = GenerativeContainer((unset,)) + continue + + for k, v in p_space.items(): + v = GenerativeContainer(v) + if k not in self: + self[k] = v + else: + if unset not in v: + self[k] = v + else: + v.discard(unset) + self[k].update(v) + + +class AdditiveConfiguratorChain(BaseConfigurator): """ Additive configurator chain. The result is the concatenation of all contained configurators. @@ -384,13 +494,19 @@ def build_sampler( @property def parameter_values(self) -> Mapping[str, Container]: - parameter_values = {} + parameter_values = ParameterSpace() + + # Update values for c in self.configurators: - extend_parameter_values(parameter_values, c.parameter_values) + parameter_values.update_additive(c.parameter_values) + + # Add unset for all parameters that are missing in one of the child configurators for c in self.configurators: - for k in parameter_values.keys(): - if k not in c.parameter_values and unset not in parameter_values[k]: - parameter_values[k] = parameter_values[k] + (unset,) + for missing_key in set(parameter_values.keys()) - ( + c.parameter_values.keys() + ): + parameter_values[missing_key].add(unset) + return parameter_values def __add__(self, other) -> "BaseConfigurator": @@ -619,7 +735,9 @@ def validate_configurators(configurators) -> List[BaseConfigurator]: return [] if isinstance(configurators, Mapping): return [Grid(configurators)] - if isinstance(configurators, Iterable): + if isinstance(configurators, Iterable) and not isinstance( + configurators, (str, bytes) + ): return sum((validate_configurators(c) for c in configurators), []) if isinstance(configurators, BaseConfigurator): return [configurators] @@ -630,6 +748,13 @@ def validate_configurators(configurators) -> List[BaseConfigurator]: def is_invariant(configured_values: Any): """Return True if not more than one single value is configured.""" + try: + is_invariant = configured_values.is_invariant + except AttributeError: + pass + else: + return is_invariant() + if hasattr(configured_values, "__len__"): return len(configured_values) < 2 @@ -655,9 +780,124 @@ class _Sampler(ConfigurationSampler): def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: for parent_configuration in self.parent.sample(exclude=exclude): - parameters = parent_configuration.get("parameters") + parameters = parent_configuration.get("parameters", {}) if not self.configurator.filter_func(parameters): continue yield parent_configuration + + +class _UnsetParameterSpace(Mapping[str, Container]): + def __init__(self, patterns: Sequence[str]) -> None: + self.patterns = patterns + + def __contains__(self, key): + return any(fnmatch.fnmatchcase(key, n) for n in self.patterns) + + def __getitem__(self, key: str) -> Container: + if key in self: + return {unset} + raise KeyError(key) + + def __iter__(self) -> Iterator[str]: + raise NotImplementedError(repr(self) + ".__iter__") + + def __len__(self) -> int: + raise NotImplementedError(repr(self) + ".__len__") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.patterns})" + + +class Clear(Configurator): + """ + Clear parameters matching the given patterns. + + Args: + *names (str): Parameter names to remove from the configuration. + """ + + __str_attrs__ = ("names",) + + def __init__(self, *patterns: str): + self.patterns = patterns + + @property + def parameter_values(self): + return _UnsetParameterSpace(self.patterns) + + class _Sampler(ConfigurationSampler): + configurator: "Clear" + + def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: + for parent_configuration in self.parent.sample(exclude=exclude): + # Make a copy of the parent configuration + parent_configuration = dict(parent_configuration) + + # Remove parameters that match any of the provided names + parent_configuration["parameters"] = { + k: v + if not any( + fnmatch.fnmatchcase(k, n) for n in self.configurator.patterns + ) + else unset + for k, v in parent_configuration.get("parameters", {}).items() + } + + yield parent_configuration + + def contains_subset_of( + self, configuration: Mapping, exclude: Optional[Set] = None + ) -> bool: + """ + Return True if there exists a sample that is a subset of `configuration`. + + A configuration matches if it does not contain any keys matched by the provided patterns + """ + + if exclude is None: + exclude = set() + + conf_parameters = configuration.get("parameters", {}) + + if any( + fnmatch.fnmatchcase(k, n) + for n in self.configurator.patterns + for k, v in conf_parameters.items() + if v != unset + ): + return False + + for k, v in conf_parameters.items(): + if any(fnmatch.fnmatchcase(k, n) for n in self.configurator.patterns): + exclude.add(k) + + print(f"Parents...") + + # Let parents check the rest of the configuration + return self.parent.contains_subset_of(configuration, exclude=exclude) + + def contains_superset_of(self, configuration: Mapping) -> bool: + """ + Return True if there exists a sample that is a superset of `configuration`. + + - `configuration` does not match if it contains additional keys not produced by the sampler (or its parents). + - `configuration` matches if it lacks keys produced by the sampler. + - `configuration` does not match if values for existing keys are different. + """ + + values = self.configurator.parameter_values + + own_params, parent_params = split_dict( + configuration.get("parameters", {}), values + ) + + # Check parameters configured here + if any(v not in values[k] for k, v in own_params.items()): + return False + + # Let parents check the rest of the configuration + return self.parent.contains_superset_of( + dict(configuration, parameters=parent_params) + ) diff --git a/experitur/testing/configurators.py b/experitur/testing/configurators.py index 682435c..1344c03 100644 --- a/experitur/testing/configurators.py +++ b/experitur/testing/configurators.py @@ -1,7 +1,13 @@ from experitur.helpers.merge_dicts import merge_dicts from typing import Mapping, Optional import itertools -from experitur.core.configurators import BaseConfigurationSampler +from experitur.core.configurators import ( + BaseConfigurationSampler, + GenerativeContainer, + ParameterSpace, +) + +from experitur.util import unset def all_parameter_subsets(configuration: Mapping): @@ -62,3 +68,18 @@ def assert_sampler_contains_superset_of_all_samples( assert sampler.contains_superset_of( conf ), f"No superset of {conf} in sampler" + + +def sampler_parameter_values(sampler: BaseConfigurationSampler): + parameter_values = ParameterSpace() + + for conf in sampler: + for k, v in conf.get("parameters", {}).items(): + parameter_values.setdefault(k, GenerativeContainer()).add(v) + + for conf in sampler: + for k in parameter_values.keys(): + if k not in conf.get("parameters", {}): + parameter_values[k].add(unset) + + return parameter_values diff --git a/tests/core/test_configurators.py b/tests/core/test_configurators.py index dd05fb3..eab6d6d 100644 --- a/tests/core/test_configurators.py +++ b/tests/core/test_configurators.py @@ -1,16 +1,17 @@ from typing import Union import pytest - from experitur.core.configurators import ( Const, Grid, RandomGrid, + Clear, parameter_product, ) from experitur.testing.configurators import ( assert_sampler_contains_subset_of_all_samples, assert_sampler_contains_superset_of_all_samples, + sampler_parameter_values, ) from experitur.util import unset @@ -169,6 +170,39 @@ def test_AdditiveConst(): # Assert correct behavior of "parameter_values" assert configurator.parameter_values == { - "a": (1, 2, 3), - "b": (unset, 1), + "a": {1, 2, 3}, + "b": {unset, 1}, } + + +def test_Unset(): + # FIXME: Unset is currently inherently broken. + # The problem is that Unset("b") does not have access to the outer parameter b + # during Configurator.parameter_values + # Solution: Shift parameter_values to sampler, so it has access to the parent's parameter_values. + # This way, all the parameter_values building logic is moved from a global to the specific sampler. + + configurator = Const(a=1, b=2, c=3) * (Const() + (Const(a=2) * Clear("b"))) + # configurator = Const(a=1, b=2) * Unset("b") + + # Test __str__ + str(configurator) + + # # Assert correct behavior of "parameter_values" + # assert configurator.parameter_values == { + # "a": (1, 2), + # "b": (2, unset), + # "c": (3,), + # } + + sampler = configurator.build_sampler() + + parameter_values_expected = sampler_parameter_values(sampler) + + print("parameter_values_expected", parameter_values_expected) + + assert configurator.parameter_values == parameter_values_expected + + # Test contains_subset_of and contains_superset_of + assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4}) + assert_sampler_contains_superset_of_all_samples(sampler)