diff --git a/experitur/configurators/pruning.py b/experitur/configurators/pruning.py index 8bf662f..5c355da 100644 --- a/experitur/configurators/pruning.py +++ b/experitur/configurators/pruning.py @@ -42,10 +42,6 @@ def __init__( "min_count": min_count, } - @property - def parameter_values(self): - return {} - class _Sampler(ConfigurationSampler): configurator: "Prune" @@ -58,3 +54,7 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: ) yield merge_dicts(parent_configuration, pruning_config=pruning_config) + + @property + def parameter_values(self): + return self.parent.parameter_values diff --git a/experitur/configurators/random.py b/experitur/configurators/random.py index 4665447..425e110 100644 --- a/experitur/configurators/random.py +++ b/experitur/configurators/random.py @@ -15,7 +15,11 @@ from scipy.stats import distributions from experitur import get_current_context -from experitur.core.configurators import ConfigurationSampler, Configurator +from experitur.core.configurators import ( + ConfigurationSampler, + Configurator, + merge_parameter_values, +) from experitur.helpers.merge_dicts import merge_dicts @@ -73,18 +77,10 @@ def __init__(self, distributions: Dict[str, Union[List, Any]], n_iter: int): self.distributions = distributions self.n_iter = n_iter - @property - def parameter_values(self) -> Mapping[str, Container]: - return { - k: tuple(v) if isinstance(v, Iterable) else _DistWrapper(v) - for k, v in self.distributions.items() - } - class _Sampler(ConfigurationSampler): configurator: "Random" def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: - distributions, exclude = self.prepare_values_exclude( self.configurator.distributions, exclude ) @@ -137,3 +133,13 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: break yield merge_dicts(parent_configuration, parameters=params) + + @property + def parameter_values(self) -> Mapping[str, Container]: + return merge_parameter_values( + self.parent.parameter_values, + { + k: tuple(v) if isinstance(v, Iterable) else _DistWrapper(v) + for k, v in self.configurator.distributions.items() + }, + ) diff --git a/experitur/core/configurators.py b/experitur/core/configurators.py index 9d7d939..47be77e 100644 --- a/experitur/core/configurators.py +++ b/experitur/core/configurators.py @@ -35,10 +35,15 @@ class BaseConfigurationSampler(metaclass=ABCMeta): def __iter__(self): return self.sample() + @abstractproperty + def parameter_values(self) -> Mapping[str, Container]: # pragma: no cover + """Information about the values that every parameter can assume.""" + ... + @abstractmethod def sample( self, exclude: Optional[Set] = None - ) -> Iterator[Mapping]: # pragma: no cover + ) -> Iterator[Dict]: # pragma: no cover """Sample trial configurations.""" while False: yield {} @@ -75,6 +80,10 @@ def __init__(self) -> None: def sample(self, exclude=None) -> Iterator[Mapping]: yield {"parameters": {}} + @property + def parameter_values(self) -> Mapping[str, Container]: + return {} + def contains_subset_of( self, configuration: Mapping, exclude: Optional[Set] = None ) -> bool: @@ -120,10 +129,10 @@ def build_sampler( ) -> BaseConfigurationSampler: # pragma: no cover pass - @abstractproperty - def parameter_values(self) -> Mapping[str, Container]: # pragma: no cover - """Information about the values that every parameter configured here can assume.""" - return {} + @property + @final + def parameter_values(self): + raise NotImplementedError() def __add__(self, other) -> "AdditiveConfiguratorChain": if not isinstance(other, BaseConfigurator): # pragma: no cover @@ -179,7 +188,7 @@ def contains_subset_of( # If the values contain , the respective parameter of a trial can assume any value. parameter_values = { k: v - for k, v in self.configurator.parameter_values.items() + for k, v in self.parameter_values.items() if (k not in exclude) and (unset not in v) } @@ -197,7 +206,7 @@ def contains_subset_of( } # Let parents check the rest of the configuration - exclude = exclude.union(self.configurator.parameter_values.keys()) + exclude = exclude.union(self.parameter_values.keys()) return self.parent.contains_subset_of( dict(configuration, parameters=parent_params), exclude=exclude ) @@ -211,7 +220,7 @@ def contains_superset_of(self, configuration: Mapping) -> bool: - `configuration` does not match if values for existing keys are different. """ - values = self.configurator.parameter_values + values = self.parameter_values own_params, parent_params = split_dict( configuration.get("parameters", {}), values @@ -262,6 +271,10 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: parent_configuration, parameters={"foo": self.configurator.foo} ) + + @property + def parameter_values(self) -> Mapping[str, Container]: + return self.parent.parameter_values + ... """ _Sampler: Type[ConfigurationSampler] @@ -328,13 +341,6 @@ def build_sampler( return sampler - @property - def parameter_values(self) -> Mapping[str, Container]: - parameter_values = ParameterSpace() - for c in self.configurators: - parameter_values.update_multiplicative(c.parameter_values) - return parameter_values - def __mul__(self, other) -> "MultiplicativeConfiguratorChain": if not isinstance(other, BaseConfigurator): # pragma: no cover return NotImplemented @@ -423,48 +429,26 @@ def is_invariant(self): return len(self.values) < 2 and not self.children -class ParameterSpace(Dict[str, GenerativeContainer]): - def update(self, *args, **kwargs): - raise NotImplementedError("Use update_additive or update_multiplicative") +def merge_parameter_values( + *parameter_spaces: Mapping[str, Container], **kwargs: Container +) -> Mapping[str, Container]: + parameter_spaces = parameter_spaces + (kwargs,) - def update_additive( - self, *parameter_spaces: Mapping[str, Container], **kwargs: Container - ): - parameter_spaces = parameter_spaces + (kwargs,) + result: Dict[str, GenerativeContainer] = {} - for p_space in parameter_spaces: - if isinstance(p_space, _UnsetParameterSpace): - continue - - for k, v in p_space.items(): - v = GenerativeContainer(v) - if k not in self: - self[k] = v + for p_space in parameter_spaces: + for k, v in p_space.items(): + v = GenerativeContainer(v) + if k not in result: + result[k] = v + else: + if unset not in v: + result[k] = v else: - self[k].update(v) + v.discard(unset) + result[k].update(v) - def update_multiplicative( - self, *parameter_spaces: Mapping[str, Container], **kwargs: Container - ): - parameter_spaces = parameter_spaces + (kwargs,) - - for p_space in parameter_spaces: - if isinstance(p_space, _UnsetParameterSpace): - for k in self.keys(): - if k in p_space: - self[k] = GenerativeContainer((unset,)) - continue - - for k, v in p_space.items(): - v = GenerativeContainer(v) - if k not in self: - self[k] = v - else: - if unset not in v: - self[k] = v - else: - v.discard(unset) - self[k].update(v) + return result class AdditiveConfiguratorChain(BaseConfigurator): @@ -492,23 +476,6 @@ def build_sampler( shuffle=self.shuffle, ) - @property - def parameter_values(self) -> Mapping[str, Container]: - parameter_values = ParameterSpace() - - # Update values - for c in self.configurators: - parameter_values.update_additive(c.parameter_values) - - # Add unset for all parameters that are missing in one of the child configurators - for c in self.configurators: - for missing_key in set(parameter_values.keys()) - ( - c.parameter_values.keys() - ): - parameter_values[missing_key].add(unset) - - return parameter_values - def __add__(self, other) -> "BaseConfigurator": if not isinstance(other, BaseConfigurator): # pragma: no cover return NotImplemented @@ -520,14 +487,14 @@ def __str__(self): class _Sampler(BaseConfigurationSampler): def __init__( - self, samplers: Iterable[BaseConfigurationSampler], shuffle: bool + self, children: Iterable[BaseConfigurationSampler], shuffle: bool ) -> None: - self.samplers = samplers + self.children = children self.shuffle = shuffle def sample(self, exclude=None) -> Iterator[Mapping]: if self.shuffle: - generators = [s.sample(exclude) for s in self.samplers] + generators = [child.sample(exclude) for child in self.children] while generators: g = random.choice(generators) try: @@ -535,8 +502,28 @@ def sample(self, exclude=None) -> Iterator[Mapping]: except StopIteration: generators.remove(g) - for s in self.samplers: - yield from s.sample(exclude) # pylint: disable=protected-access + for child in self.children: + yield from child.sample(exclude) # pylint: disable=protected-access + + @property + def parameter_values(self) -> Mapping[str, Container]: + result: Dict[str, GenerativeContainer] = {} + + # Merge the parameter_values of each child + for child in self.children: + for k, v in child.parameter_values.items(): + v = GenerativeContainer(v) + if k not in result: + result[k] = v + else: + result[k].update(v) + + # Add unset for all parameters that are missing in one of the child configurators + for child in self.children: + for missing_key in set(result.keys()) - (child.parameter_values.keys()): + result[missing_key].add(unset) + + return result def contains_subset_of( self, configuration: Mapping, exclude: Optional[Set] = None @@ -544,13 +531,13 @@ def contains_subset_of( """Return True if there exists a sample that is a subset of `configuration`, i.e. if there is one in a child.""" return any( - s.contains_subset_of(configuration, exclude) for s in self.samplers + s.contains_subset_of(configuration, exclude) for s in self.children ) def contains_superset_of(self, configuration: Mapping) -> bool: """Return True if there exists a sample that is a superset of `configuration`, i.e. if there is one in a child.""" - return any(s.contains_superset_of(configuration) for s in self.samplers) + return any(s.contains_superset_of(configuration) for s in self.children) def split_dict(mapping: Mapping, indicator): @@ -593,10 +580,6 @@ def __init__(self, values: Optional[Mapping[str, Any]] = None, **kwargs): else: self.values = {**values, **kwargs} - @property - def parameter_values(self) -> Mapping[str, Container]: - return {k: (v,) for k, v in self.values.items()} - class _Sampler(ConfigurationSampler): configurator: "Const" @@ -609,6 +592,13 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: # print(self.configurator, ":", values) yield merge_dicts(parent_configuration, parameters=values) + @property + def parameter_values(self) -> Mapping[str, Container]: + return merge_parameter_values( + self.parent.parameter_values, + {k: (v,) for k, v in self.configurator.values.items()}, + ) + class ZeroConfigurator(Configurator): """ @@ -619,10 +609,6 @@ class ZeroConfigurator(Configurator): __str_attrs__ = tuple() - @property - def parameter_values(self) -> Mapping[str, Container]: - return {} - class _Sampler(ConfigurationSampler): configurator: "ZeroConfigurator" @@ -630,6 +616,10 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: while False: yield {} + @property + def parameter_values(self) -> Mapping[str, Container]: + return self.parent.parameter_values + def parameter_product(p: Mapping[str, Iterable]): """Iterate over the points in the grid.""" @@ -693,10 +683,6 @@ def _validate_grid(self, grid: Mapping): if not isinstance(v, collections.abc.Iterable): raise ValueError(f"Value {v!r} for parameter {k} is not Iterable") - @property - def parameter_values(self) -> Mapping[str, Container]: - return {k: tuple(v) for k, v in self.grid.items()} - class _Sampler(ConfigurationSampler): configurator: "Grid" @@ -712,6 +698,13 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: for values in grid_product: yield merge_dicts(parent_configuration, parameters=values) + @property + def parameter_values(self) -> Mapping[str, Container]: + return merge_parameter_values( + self.parent.parameter_values, + {k: tuple(v) for k, v in self.configurator.grid.items()}, + ) + class RandomGrid(Grid): """ @@ -771,10 +764,6 @@ class FilterConfig(Configurator): def __init__(self, filter_func: Callable): self.filter_func = filter_func - @property - def parameter_values(self): - return {} - class _Sampler(ConfigurationSampler): configurator: "FilterConfig" @@ -787,6 +776,10 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: yield parent_configuration + @property + def parameter_values(self): + return self.parent.parameter_values + class _UnsetParameterSpace(Mapping[str, Container]): def __init__(self, patterns: Sequence[str]) -> None: @@ -823,10 +816,6 @@ class Clear(Configurator): def __init__(self, *patterns: str): self.patterns = patterns - @property - def parameter_values(self): - return _UnsetParameterSpace(self.patterns) - class _Sampler(ConfigurationSampler): configurator: "Clear" @@ -837,16 +826,25 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]: # Remove parameters that match any of the provided names parent_configuration["parameters"] = { - k: v - if not any( - fnmatch.fnmatchcase(k, n) for n in self.configurator.patterns - ) - else unset + k: v if not self._matches(k) else unset for k, v in parent_configuration.get("parameters", {}).items() } yield parent_configuration + @property + def parameter_values(self): + return { + k: v + for k, v in self.parent.parameter_values.items() + if not self._matches(k) + } + + def _matches(self, name): + return any( + fnmatch.fnmatchcase(name, pat) for pat in self.configurator.patterns + ) + def contains_subset_of( self, configuration: Mapping, exclude: Optional[Set] = None ) -> bool: @@ -861,20 +859,13 @@ def contains_subset_of( conf_parameters = configuration.get("parameters", {}) - if any( - fnmatch.fnmatchcase(k, n) - for n in self.configurator.patterns - for k, v in conf_parameters.items() - if v != unset - ): + if any(self._matches(k) for k, v in conf_parameters.items() if v != unset): return False - for k, v in conf_parameters.items(): - if any(fnmatch.fnmatchcase(k, n) for n in self.configurator.patterns): + for k in conf_parameters.keys(): + if self._matches(k): exclude.add(k) - print(f"Parents...") - # Let parents check the rest of the configuration return self.parent.contains_subset_of(configuration, exclude=exclude) @@ -882,22 +873,10 @@ def contains_superset_of(self, configuration: Mapping) -> bool: """ Return True if there exists a sample that is a superset of `configuration`. - - `configuration` does not match if it contains additional keys not produced by the sampler (or its parents). + - `configuration` does not match if it contains additional keys not produced by the parents. - `configuration` matches if it lacks keys produced by the sampler. - `configuration` does not match if values for existing keys are different. """ - values = self.configurator.parameter_values - - own_params, parent_params = split_dict( - configuration.get("parameters", {}), values - ) - - # Check parameters configured here - if any(v not in values[k] for k, v in own_params.items()): - return False - # Let parents check the rest of the configuration - return self.parent.contains_superset_of( - dict(configuration, parameters=parent_params) - ) + return self.parent.contains_superset_of(configuration) diff --git a/experitur/core/experiment.py b/experitur/core/experiment.py index 04934e2..fb4d91c 100644 --- a/experitur/core/experiment.py +++ b/experitur/core/experiment.py @@ -24,6 +24,7 @@ import click from experitur.core.configurators import ( + BaseConfigurationSampler, BaseConfigurator, Configurable, MultiplicativeConfiguratorChain, @@ -195,7 +196,7 @@ def __init__( name: Optional[str] = None, parameters=None, configurator=None, - parent: "Experiment" = None, + parent: Optional["Experiment"] = None, meta: Optional[Mapping] = None, active: bool = True, volatile: Optional[bool] = None, @@ -432,30 +433,8 @@ def prepend_configurator(self, configurator: BaseConfigurator) -> None: def _configurators(self) -> List[BaseConfigurator]: return self._base_configurators + self._own_configurators - @property - def configurator(self) -> BaseConfigurator: - return MultiplicativeConfiguratorChain(*self._configurators) - - @property - def parameters(self) -> List[str]: - """Names of the configured parameters.""" - return sorted(self.configurator.parameter_values.keys()) - - @property - def varying_parameters(self) -> List[str]: - """Names of varying parameters, i.e. parameters that can assume more than one value.""" - return sorted( - k - for k, v in self.configurator.parameter_values.items() - if not is_invariant(v) - ) - - @property - def invariant_parameters(self) -> List[str]: - """Names of invariant parameters, i.e. parameters that assume only one single value.""" - return sorted( - k for k, v in self.configurator.parameter_values.items() if is_invariant(v) - ) + def build_sampler(self): + return MultiplicativeConfiguratorChain(*self._configurators).build_sampler() def __str__(self): if self.name is not None: @@ -465,18 +444,18 @@ def __str__(self): def __repr__(self): # pragma: no cover return "Experiment(name={})".format(self.name) - def _trial_generator(self): + def _trial_generator(self, sampler: BaseConfigurationSampler): """Yields readily created trials.""" # Generate trial configurations - sampler = self.configurator.build_sampler() - skip_cache = _SkipCache(self) for trial_configuration in sampler: # Inject experiment data into trial_configuration # TODO: Insert runtime meta elsewhere - trial_configuration = self._setup_trial_configuration(trial_configuration) + trial_configuration = self._setup_trial_configuration( + trial_configuration, sampler + ) # Remove "unset" parameters parameters = trial_configuration["parameters"] = clean_unset( @@ -531,16 +510,18 @@ def run(self): print("Experiment", self) + sampler = self.build_sampler() + # Print varying parameters of this experiment print("Varying parameters:") - for k, v in sorted(self.configurator.parameter_values.items()): + for k, v in sorted(sampler.parameter_values.items()): if is_invariant(v): continue print("{}: {}".format(k, v)) print() - trials = self._trial_generator() + trials = self._trial_generator(sampler) if self.ctx.config["resume_failed"]: hostname = self.meta.get("hostname", object()) @@ -728,9 +709,17 @@ def run_trial(self, trial: Trial): return trial.result - def _setup_trial_configuration(self, trial_configuration): + def _setup_trial_configuration( + self, trial_configuration: Dict, sampler: BaseConfigurationSampler + ): trial_configuration.setdefault("parameters", {}) trial_configuration.setdefault("tags", []) + + independent_parameters = sorted(sampler.parameter_values.keys()) + varying_parameters = sorted( + k for k, v in sampler.parameter_values.items() if not is_invariant(v) + ) + return merge_dicts( trial_configuration, experiment={ @@ -739,8 +728,8 @@ def _setup_trial_configuration(self, trial_configuration): "func": callable_to_name(self.func), "meta": self.meta, # Names of parameters that where actually configured. - "independent_parameters": self.parameters, - "varying_parameters": self.varying_parameters, + "independent_parameters": independent_parameters, + "varying_parameters": varying_parameters, "minimize": self.minimize, "maximize": self.maximize, }, @@ -887,16 +876,10 @@ def get_matching_trials(self, exclude=None): This includes all trials with matching configuration, regardless whether they belong to this or to another experiment. """ - parameters = self.parameters - - if exclude is not None: - if isinstance(exclude, str): - exclude = [exclude] - - parameters = [k for k in parameters if k not in exclude] - - sampler = self.configurator.build_sampler() + sampler = self.build_sampler() return self.ctx.trials.match(func=self.func).filter( - lambda trial: sampler.contains_subset_of({"parameters": dict(trial)}) + lambda trial: sampler.contains_subset_of( + {"parameters": dict(trial)}, exclude + ) ) diff --git a/experitur/testing/configurators.py b/experitur/testing/configurators.py index 1344c03..c4eec1a 100644 --- a/experitur/testing/configurators.py +++ b/experitur/testing/configurators.py @@ -4,10 +4,9 @@ from experitur.core.configurators import ( BaseConfigurationSampler, GenerativeContainer, - ParameterSpace, ) -from experitur.util import unset +from experitur.util import clean_unset, unset def all_parameter_subsets(configuration: Mapping): @@ -59,6 +58,8 @@ def assert_sampler_contains_superset_of_all_samples( """ for conf in sampler: + conf = dict(conf, parameters=clean_unset(conf.get("parameters", {}))) + if with_subsets: for subset in all_parameter_subsets(conf): assert sampler.contains_superset_of( @@ -71,7 +72,7 @@ def assert_sampler_contains_superset_of_all_samples( def sampler_parameter_values(sampler: BaseConfigurationSampler): - parameter_values = ParameterSpace() + parameter_values = {} for conf in sampler: for k, v in conf.get("parameters", {}).items(): diff --git a/tests/configurators/test_conditions.py b/tests/configurators/test_conditions.py index 75e7d2c..8759b4f 100644 --- a/tests/configurators/test_conditions.py +++ b/tests/configurators/test_conditions.py @@ -1,7 +1,7 @@ from experitur import Experiment -from experitur import configurators from experitur.core.context import Context from experitur.configurators import Const, Grid, Conditions +from experitur.util import unset def test_Conditions(tmp_path): @@ -14,15 +14,16 @@ def exp(trial): configurator = Conditions( "x", {1: Const(y=1), 2: [Const(y=2), Grid({"z": [1, 2, 3]})]} ) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1, 2)), ("y", (1, 2)), - ("z", (1, 2, 3)), + ("z", (1, 2, 3, unset)), ] - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) # Assert exististence of all specified values assert samples == [ @@ -33,14 +34,15 @@ def exp(trial): ] configurator = Conditions("x", {1: Const(y=1), 2: Grid({"y": [2, 3]})}) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1, 2)), ("y", (1, 2, 3)), ] - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) # Assert exististence of all specified values assert samples == [ @@ -51,28 +53,30 @@ def exp(trial): # Condition name overwrites sub-config name configurator = Conditions("x", {1: Const(x=2)}) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1,)), ] # Assert exististence of all specified values - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) assert samples == [ (("x", 1),), ] # Test passing sub-configurations as simple dict (should get converted) configurator = Conditions("x", {1: {"y": [1]}}) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1,)), ("y", (1,)), ] - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) # Assert exististence of all specified values assert samples == [ @@ -82,15 +86,16 @@ def exp(trial): # Test passing list of sub-configurations (only invariant) configurator = Conditions("x", {1: [Const(y=1), Const(z=1)]}) print(str(configurator.conditions)) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1,)), ("y", (1,)), ("z", (1,)), ] - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) # Assert exististence of all specified values assert samples == [ @@ -100,15 +105,16 @@ def exp(trial): # Test passing list of sub-configurations (variant) configurator = Conditions("x", {1: [Const(y=1), {"z": [1, 2]}]}) print(str(configurator.conditions)) - assert sorted(configurator.parameter_values.items()) == [ + + sampler = configurator.build_sampler() + + assert sorted(sampler.parameter_values.items()) == [ ("x", (1,)), ("y", (1,)), ("z", (1, 2)), ] - samples = sorted( - tuple(sorted(d["parameters"].items())) for d in configurator.build_sampler() - ) + samples = sorted(tuple(sorted(d["parameters"].items())) for d in sampler) # Assert exististence of all specified values assert samples == [ diff --git a/tests/configurators/test_random.py b/tests/configurators/test_random.py index f5fbfda..c617261 100644 --- a/tests/configurators/test_random.py +++ b/tests/configurators/test_random.py @@ -56,7 +56,7 @@ def exp(trial): } # Assert correct behavior of independent_parameters - assert configurator.parameter_values == { + assert sampler.parameter_values == { "a": (1, 2), "b": (3, 4), "c": (0,), diff --git a/tests/configurators/test_skopt.py b/tests/configurators/test_skopt.py index 5abae3b..5ddc653 100644 --- a/tests/configurators/test_skopt.py +++ b/tests/configurators/test_skopt.py @@ -4,6 +4,8 @@ from experitur.core.context import Context from experitur.configurators import SKOpt +pytestmark = pytest.mark.skip("deprecated") + try: from experitur.configurators.skopt import SKOpt except ImportError as exc: diff --git a/tests/core/test_configurators.py b/tests/core/test_configurators.py index eab6d6d..4b7fd62 100644 --- a/tests/core/test_configurators.py +++ b/tests/core/test_configurators.py @@ -26,16 +26,16 @@ def test_Const(): # Test __str__ str(configurator) + sampler = configurator.build_sampler() + # Assert correct behavior of "parameter_values" - assert configurator.parameter_values == { + assert sampler.parameter_values == { "a": (1,), "b": (2,), "c": (3,), "d": (unset,), } - sampler = configurator.build_sampler() - assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4}) assert_sampler_contains_superset_of_all_samples(sampler) @@ -58,16 +58,16 @@ def test_Grid(cls): # Test __str__ str(configurator) + sampler = configurator.build_sampler() + # Assert correct behavior of "parameter_values" - assert configurator.parameter_values == { + assert sampler.parameter_values == { "a": (1, 2), "b": (3, 4), "c": (0,), "d": (0, unset), } - sampler = configurator.build_sampler() - # Test contains_subset_of and contains_superset_of assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4}) assert_sampler_contains_superset_of_all_samples(sampler) @@ -97,11 +97,11 @@ def test_MultiplicativeConfiguratorChain(): # Test __str__ str(configurator) - # Assert correct behavior of "parameter_values" - assert configurator.parameter_values == {"a": (4, 5), "b": (3, 4), "c": (0,)} - sampler = configurator.build_sampler() + # Assert correct behavior of "parameter_values" + assert sampler.parameter_values == {"a": (4, 5), "b": (3, 4), "c": (0,)} + # Test contains_subset_of and contains_superset_of assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4}) assert_sampler_contains_superset_of_all_samples(sampler) @@ -129,15 +129,15 @@ def test_AdditiveConfiguratorChain(): # Test __str__ str(configurator) + sampler = configurator.build_sampler() + # Assert correct behavior of "parameter_values" - assert configurator.parameter_values == { + assert sampler.parameter_values == { "a": (1, 2, 4, 5), "b": (3, 4, unset), "c": (10, 11, 0), } - sampler = configurator.build_sampler() - # Test contains_subset_of and contains_superset_of assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4}) assert_sampler_contains_superset_of_all_samples(sampler) @@ -168,8 +168,10 @@ def test_AdditiveConfiguratorChain(): def test_AdditiveConst(): configurator = Const(a=1) * (Const() + Const(a=2, b=1) + Const(a=3, b=1)) + sampler = configurator.build_sampler() + # Assert correct behavior of "parameter_values" - assert configurator.parameter_values == { + assert sampler.parameter_values == { "a": {1, 2, 3}, "b": {unset, 1}, } @@ -178,7 +180,7 @@ def test_AdditiveConst(): def test_Unset(): # FIXME: Unset is currently inherently broken. # The problem is that Unset("b") does not have access to the outer parameter b - # during Configurator.parameter_values + # during sampler.parameter_values # Solution: Shift parameter_values to sampler, so it has access to the parent's parameter_values. # This way, all the parameter_values building logic is moved from a global to the specific sampler. @@ -188,20 +190,18 @@ def test_Unset(): # Test __str__ str(configurator) - # # Assert correct behavior of "parameter_values" - # assert configurator.parameter_values == { - # "a": (1, 2), - # "b": (2, unset), - # "c": (3,), - # } - sampler = configurator.build_sampler() - parameter_values_expected = sampler_parameter_values(sampler) + parameter_values_expected = { + "a": (1, 2), + "b": (2, unset), + "c": (3,), + } - print("parameter_values_expected", parameter_values_expected) + # Assert correct behavior of "parameter_values" + assert sampler.parameter_values == parameter_values_expected - assert configurator.parameter_values == parameter_values_expected + assert sampler_parameter_values(sampler) == parameter_values_expected # Test contains_subset_of and contains_superset_of assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4})