diff --git a/experitur/core/experiment.py b/experitur/core/experiment.py index 526fc69..f0391dc 100644 --- a/experitur/core/experiment.py +++ b/experitur/core/experiment.py @@ -506,9 +506,7 @@ def _trial_generator(self): ).as_dict(), ) - trial = self.ctx.trials.create( - trial_configuration, record_used_parameters=True - ) + trial = self.ctx.trials.create(trial_configuration) os.makedirs(trial.wdir, exist_ok=True) yield trial @@ -653,7 +651,7 @@ def run_trial(self, trial: Trial): self._handle_event("on_pre_run", trial) try: - with self.ctx.set_current_trial(trial): + with self.ctx.set_current_trial(trial), trial.record_used_parameters(): result = self.func(trial, *args, **kwargs) # Merge returned result into existing result diff --git a/experitur/core/trial.py b/experitur/core/trial.py index ee848b4..fb75b2a 100644 --- a/experitur/core/trial.py +++ b/experitur/core/trial.py @@ -1,4 +1,5 @@ import collections.abc +import contextlib import copy import datetime import glob @@ -85,7 +86,11 @@ def __contains__(self, x: object) -> bool: return x in self.data() def __iter__(self): - return _filter_prefixed(self.trial._prefix, self.data()) + return iter(set(_filter_prefixed(self.trial._prefix, self.data()))) + + def __repr__(self): + items = ", ".join(repr(o) for o in self) + return f"<_PrefixedTrialDataSetView {{{items}}}>" def __len__(self): return sum(1 for _ in self) @@ -219,9 +224,12 @@ def __getitem__(self, name): return self.resolved_parameters[name] + def __eq__(self, o: object) -> bool: + return self is o + @property - def unused_parameters(self): - return sorted(set(self.resolved_parameters) - set(self.used_parameters)) + def unused_parameters(self) -> Set: + return set(self.resolved_parameters) - set(self.used_parameters) def __setitem__(self, name, value): """Set the value of a parameter.""" @@ -234,6 +242,17 @@ def __delitem__(self, name): def __iter__(self): return iter(self.resolved_parameters) + def items(self): + with self.record_used_parameters(False): + return super().items() + + @contextlib.contextmanager + def record_used_parameters(self, record=True): + old_value = self._record_used_parameters + self._record_used_parameters = record + yield + self._record_used_parameters = old_value + def todict(self, with_prefix=False): """ Convert trial data to dictionary. @@ -324,7 +343,8 @@ def get(self, key, default=None, setdefault=True): return super().get(key, default) def save(self): - # Compact used parameters + # Store and compact used and unused parameters + self._data["used_parameters"] = sorted(self.used_parameters) self._data["unused_parameters"] = sorted(self.unused_parameters) # Write to the store diff --git a/tests/core/test_trial.py b/tests/core/test_trial.py index 95771b1..6ec4d48 100644 --- a/tests/core/test_trial.py +++ b/tests/core/test_trial.py @@ -52,14 +52,18 @@ def test_trial_parameters(tmp_path): @Experiment(configurator={"a": [1], "b": [2], "c": ["{a}"]}) def experiment(trial: Trial): - assert ctx.current_trial == trial + assert ctx.current_trial is trial + + assert trial.used_parameters == set() + assert trial.unused_parameters == {"a", "b", "c"} assert trial["a"] == 1 assert trial["b"] == 2 assert trial["c"] == trial["a"] assert len(trial) == 3 - print(trial["a"], trial["c"]) + assert trial.used_parameters == {"a", "b", "c"} + assert trial.unused_parameters == set() for k, v in trial.items(): pass @@ -85,7 +89,7 @@ def experiment(trial: Trial): trial["prefix__" + k] = v trial["prefix1__" + k] = v - assert trial.prefixed("prefix__") == seed + assert trial.prefixed("prefix__").todict() == seed # test call def identity(a, b, c=4, d=5): @@ -104,7 +108,7 @@ def identity(a, b, c=4, d=5): # test record_defaults trial.prefixed("prefix2__").record_defaults(identity) - assert trial.prefixed("prefix2__") == {"c": 4, "d": 5} + assert trial.prefixed("prefix2__").todict() == {"c": 4, "d": 5} # test call: functools.partial @@ -125,7 +129,7 @@ def identity(a, b, c=4, d=5): # Keyword arguments will *not* be recorded and can *not* be overwritten identity_a8_kwd = functools.partial(identity, a=8) trial.prefixed("prefix4_").record_defaults(identity_a8_kwd) - assert trial.prefixed("prefix4_") == {"c": 4, "d": 5} + assert trial.prefixed("prefix4_").todict() == {"c": 4, "d": 5} with pytest.raises(TypeError): trial.prefixed("prefix5_").record_defaults(identity_a8_kwd, a=9) @@ -146,7 +150,7 @@ def identity(a, b, c=4, d=5): 4, 9, ) - assert trial.prefixed("prefix7__") == {"c": 4} + assert trial.prefixed("prefix7__").todict() == {"c": 4} with pytest.raises(TypeError): trial.prefixed("prefix7__").call(identity_d9_kwd, 1, 2, 5, 10) @@ -165,7 +169,7 @@ def identity(a, b, c=4, d=5): # setdefaults assert trial.prefixed("prefix8__").setdefaults( dict(a=1, b=2, c=3, d=4), e=10 - ) == dict(a=1, b=2, c=3, d=4, e=10) + ).todict() == dict(a=1, b=2, c=3, d=4, e=10) assert dict(trial.prefixed("prefix8__")) == dict(a=1, b=2, c=3, d=4, e=10) # Make sure that the default value is recorded when using .get