Skip to content

Commit

Permalink
Configurators: Fix Clear, improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
moi90 committed Jan 15, 2024
1 parent c2a7b7e commit 2535e34
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 114 deletions.
237 changes: 132 additions & 105 deletions experitur/core/configurators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import fnmatch
import itertools
import random
import textwrap
import warnings
from abc import ABCMeta, abstractmethod, abstractproperty
from functools import cached_property
from typing import (
Any,
Callable,
Expand All @@ -19,12 +21,12 @@
Tuple,
Type,
TypeVar,
Union,
)

from typing_extensions import final

from experitur.helpers.merge_dicts import merge_dicts
from experitur.util import unset
from typing_extensions import final


class BaseConfigurationSampler(metaclass=ABCMeta):
Expand All @@ -49,26 +51,77 @@ def sample(
yield {}

def contains_subset_of(
self, configuration: Mapping, exclude: Optional[Set] = None
) -> bool: # pragma: no cover
self,
configuration: Mapping,
exclude: Optional[Set] = None,
strict=False,
_explain=False,
) -> bool:
"""
Return True if there exists a sample that is a subset of `configuration`.
- `configuration` matches if it contains additional keys not produced by the sampler.
- `configuration` does not match if it lacks keys produced by the sampler.
- `configuration` does not match if values for existing keys are different.
"""
raise NotImplementedError(f"{type(self).__qualname__}.contains_subset_of")

def contains_superset_of(self, configuration: Mapping) -> bool: # pragma: no cover
if exclude is None:
exclude = set()

# The parameter values that must contain the supplied configuration, ignoring any parameter named in `exclude`.
# If the values contain <unset>, the respective parameter of a trial can assume any value.
parameter_values = {
k: v
for k, v in self.parameter_values.items()
if (k not in exclude) and (unset not in v)
}

conf_parameters = configuration.get("parameters", {})

# Check if all configured parameters are contained
if any(
(k not in conf_parameters) or (conf_parameters[k] not in v)
for k, v in parameter_values.items()
):
if _explain:
msg = []
for k, v in parameter_values.items():
if k not in conf_parameters:
msg.append(f"{k!r} of sampler not in conf_parameters")
elif conf_parameters[k] not in v:
msg.append(
f"conf_parameters[{k}]={conf_parameters[k]!r} not in parameter_values[{k}]={v}"
)

msg.append(f"parameter_values: {parameter_values}")
print(
f"{self.__class__.__qualname__}:\n"
+ textwrap.indent("\n".join(msg), " ")
)
return False

return True

def contains_superset_of(self, configuration: Mapping) -> bool:
"""
Return True if there exists a sample that is a superset of `configuration`.
- `configuration` does not match if it contains additional keys not produced by the sampler.
- `configuration` matches if it lacks keys produced by the sampler.
- `configuration` does not match if it contains additional keys not produced by the sampler (or its parents).
- `configuration` matches if it lacks keys produced by the sampler.
- `configuration` does not match if values for existing keys are different.
"""
raise NotImplementedError(f"{type(self).__qualname__}.contains_superset_of")

values = self.parameter_values

own_params, remaining_params = split_dict(
configuration.get("parameters", {}), values
)

# Check parameters configured here
if any(v not in values[k] for k, v in own_params.items()):
return False

return not remaining_params


class _RootSampler(BaseConfigurationSampler):
Expand All @@ -85,11 +138,16 @@ def parameter_values(self) -> Mapping[str, Container]:
return {}

def contains_subset_of(
self, configuration: Mapping, exclude: Optional[Set] = None
self,
configuration: Mapping,
exclude: Optional[Set] = None,
strict=False,
_explain=False,
) -> bool:
"""Return True if there exists a sample that is a subset of `configuration`, i.e. always."""

# print(f"{type(self).__qualname__}.contains_subset_of", configuration, exclude)
if _explain:
print(f"{type(self).__qualname__}.contains_subset_of: Always True")

return True

Expand Down Expand Up @@ -170,69 +228,13 @@ def __init__(
self.configurator = configurator
self.parent = parent

def contains_subset_of(
self, configuration: Mapping, exclude: Optional[Set] = None
) -> bool:
"""
Return True if there exists a sample that is a subset of `configuration`.
- `configuration` matches if it contains additional keys not produced by the sampler.
- `configuration` does not match if it lacks keys produced by the sampler.
- `configuration` does not match if values for existing keys are different.
"""

if exclude is None:
exclude = set()

# The parameter values that should contain the supplied configuration, without `exclude`ed
# If the values contain <unset>, the respective parameter of a trial can assume any value.
parameter_values = {
k: v
for k, v in self.parameter_values.items()
if (k not in exclude) and (unset not in v)
}

conf_parameters = configuration.get("parameters", {})

# Check if all configured parameters are contained
if any(
(k not in conf_parameters) or (conf_parameters[k] not in v)
for k, v in parameter_values.items()
):
return False

parent_params = {
k: v for k, v in conf_parameters.items() if k not in parameter_values
}

# Let parents check the rest of the configuration
exclude = exclude.union(self.parameter_values.keys())
return self.parent.contains_subset_of(
dict(configuration, parameters=parent_params), exclude=exclude
)

def contains_superset_of(self, configuration: Mapping) -> bool:
"""
Return True if there exists a sample that is a superset of `configuration`.
- `configuration` does not match if it contains additional keys not produced by the sampler (or its parents).
- `configuration` matches if it lacks keys produced by the sampler.
- `configuration` does not match if values for existing keys are different.
"""

values = self.parameter_values

own_params, parent_params = split_dict(
configuration.get("parameters", {}), values
)

# Check parameters configured here
if any(v not in values[k] for k, v in own_params.items()):
return False
def get_own_parameter_values(self) -> Mapping[str, Container]:
return {}

# Let parents check the rest of the configuration
return self.parent.contains_superset_of(
dict(configuration, parameters=parent_params)
@cached_property
def parameter_values(self):
return merge_parameter_values(
self.parent.parameter_values, self.get_own_parameter_values()
)

@staticmethod
Expand Down Expand Up @@ -505,34 +507,54 @@ def sample(self, exclude=None) -> Iterator[Mapping]:
for child in self.children:
yield from child.sample(exclude) # pylint: disable=protected-access

@property
@cached_property
def parameter_values(self) -> Mapping[str, Container]:
result: Dict[str, GenerativeContainer] = {}

# Merge the parameter_values of each child
for child in self.children:
for k, v in child.parameter_values.items():
children_parameter_values = [
child.parameter_values for child in self.children
]
for child_parameter_values in children_parameter_values:
for k, v in child_parameter_values.items():
v = GenerativeContainer(v)
if k not in result:
result[k] = v
else:
result[k].update(v)

# Add unset for all parameters that are missing in one of the child configurators
for child in self.children:
for missing_key in set(result.keys()) - (child.parameter_values.keys()):
for child_parameter_values in children_parameter_values:
for missing_key in set(result.keys()) - (child_parameter_values.keys()):
result[missing_key].add(unset)

return result

def contains_subset_of(
self, configuration: Mapping, exclude: Optional[Set] = None
self,
configuration: Mapping,
exclude: Optional[Set] = None,
strict=False,
_explain=False,
) -> bool:
"""Return True if there exists a sample that is a subset of `configuration`, i.e. if there is one in a child."""
"""
Return True if there exists a sample that is a subset of `configuration`, i.e. if there is one in a child.
return any(
s.contains_subset_of(configuration, exclude) for s in self.children
)
Args:
configuration (Mapping): The configuration to test.
exclude (set, optional): Exclude these parameter names. Default: none.
strict (bool, optional): Be strict, i.e. there has to be a child that provides exactly this configuration.
Otherwise, a configuration matches if it is contained in `self.parameter_values`. Default: False.
"""

if strict:
return any(
s.contains_subset_of(configuration, exclude, strict, _explain)
for s in self.children
)

# not strict:
return super().contains_subset_of(configuration, exclude, strict, _explain)

def contains_superset_of(self, configuration: Mapping) -> bool:
"""Return True if there exists a sample that is a superset of `configuration`, i.e. if there is one in a child."""
Expand Down Expand Up @@ -592,7 +614,7 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:
# print(self.configurator, ":", values)
yield merge_dicts(parent_configuration, parameters=values)

@property
@cached_property
def parameter_values(self) -> Mapping[str, Container]:
return merge_parameter_values(
self.parent.parameter_values,
Expand Down Expand Up @@ -698,7 +720,7 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:
for values in grid_product:
yield merge_dicts(parent_configuration, parameters=values)

@property
@cached_property
def parameter_values(self) -> Mapping[str, Container]:
return merge_parameter_values(
self.parent.parameter_values,
Expand Down Expand Up @@ -811,10 +833,24 @@ class Clear(Configurator):
*names (str): Parameter names to remove from the configuration.
"""

__str_attrs__ = ("patterns",)
__str_attrs__ = ("include", "exclude")

def __init__(self, *patterns: str):
self.patterns = patterns
def __init__(
self,
include: Union[str, Sequence[str]],
*,
exclude: Union[str, Sequence[str], None] = None,
):
if isinstance(include, str):
include = [include]

if exclude is None:
exclude = []
elif isinstance(exclude, str):
exclude = [exclude]

self.include = include
self.exclude = exclude

class _Sampler(ConfigurationSampler):
configurator: "Clear"
Expand All @@ -832,7 +868,7 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:

yield parent_configuration

@property
@cached_property
def parameter_values(self):
return {
k: v
Expand All @@ -842,11 +878,15 @@ def parameter_values(self):

def _matches(self, name):
return any(
fnmatch.fnmatchcase(name, pat) for pat in self.configurator.patterns
fnmatch.fnmatchcase(name, pat) for pat in self.configurator.include
)

def contains_subset_of(
self, configuration: Mapping, exclude: Optional[Set] = None
self,
configuration: Mapping,
exclude: Optional[Set] = None,
strict=False,
_explain=False,
) -> bool:
"""
Return True if there exists a sample that is a subset of `configuration`.
Expand All @@ -866,17 +906,4 @@ def contains_subset_of(
if self._matches(k):
exclude.add(k)

# Let parents check the rest of the configuration
return self.parent.contains_subset_of(configuration, exclude=exclude)

def contains_superset_of(self, configuration: Mapping) -> bool:
"""
Return True if there exists a sample that is a superset of `configuration`.
- `configuration` does not match if it contains additional keys not produced by the parents.
- `configuration` matches if it lacks keys produced by the sampler.
- `configuration` does not match if values for existing keys are different.
"""

# Let parents check the rest of the configuration
return self.parent.contains_superset_of(configuration)
return super().contains_subset_of(configuration, exclude, strict, _explain)
13 changes: 4 additions & 9 deletions tests/core/test_configurators.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def test_AdditiveConfiguratorChain():
assert_sampler_contains_subset_of_all_samples(sampler, include_parameters={"d": 4})
assert_sampler_contains_superset_of_all_samples(sampler)

# TODO: Test that every combination of parameter_values is accepted as a subset in non-strict mode

samples = sorted(
tuple(sorted(configuration["parameters"].items())) for configuration in sampler
)
Expand Down Expand Up @@ -177,15 +179,8 @@ def test_AdditiveConst():
}


def test_Unset():
# FIXME: Unset is currently inherently broken.
# The problem is that Unset("b") does not have access to the outer parameter b
# during sampler.parameter_values
# Solution: Shift parameter_values to sampler, so it has access to the parent's parameter_values.
# This way, all the parameter_values building logic is moved from a global to the specific sampler.

configurator = Const(a=1, b=2, c=3) * (Const() + (Const(a=2) * Clear("b")))
# configurator = Const(a=1, b=2) * Unset("b")
def test_Clear():
configurator = Const(a=1, b=2, c=3) * (Const() + (Const(a=2) * Clear("b*")))

# Test __str__
str(configurator)
Expand Down

0 comments on commit 2535e34

Please sign in to comment.