Skip to content

Commit

Permalink
Test and fix Trial
Browse files Browse the repository at this point in the history
  • Loading branch information
moi90 committed Jun 25, 2023
1 parent 9a909e4 commit 1e85064
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
6 changes: 2 additions & 4 deletions experitur/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,7 @@ def _trial_generator(self):
).as_dict(),
)

trial = self.ctx.trials.create(
trial_configuration, record_used_parameters=True
)
trial = self.ctx.trials.create(trial_configuration)
os.makedirs(trial.wdir, exist_ok=True)

yield trial
Expand Down Expand Up @@ -653,7 +651,7 @@ def run_trial(self, trial: Trial):
self._handle_event("on_pre_run", trial)

try:
with self.ctx.set_current_trial(trial):
with self.ctx.set_current_trial(trial), trial.record_used_parameters():
result = self.func(trial, *args, **kwargs)

# Merge returned result into existing result
Expand Down
28 changes: 24 additions & 4 deletions experitur/core/trial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections.abc
import contextlib
import copy
import datetime
import glob
Expand Down Expand Up @@ -85,7 +86,11 @@ def __contains__(self, x: object) -> bool:
return x in self.data()

def __iter__(self):
return _filter_prefixed(self.trial._prefix, self.data())
return iter(set(_filter_prefixed(self.trial._prefix, self.data())))

def __repr__(self):
items = ", ".join(repr(o) for o in self)
return f"<_PrefixedTrialDataSetView {{{items}}}>"

def __len__(self):
return sum(1 for _ in self)
Expand Down Expand Up @@ -219,9 +224,12 @@ def __getitem__(self, name):

return self.resolved_parameters[name]

def __eq__(self, o: object) -> bool:
return self is o

@property
def unused_parameters(self):
return sorted(set(self.resolved_parameters) - set(self.used_parameters))
def unused_parameters(self) -> Set:
return set(self.resolved_parameters) - set(self.used_parameters)

def __setitem__(self, name, value):
"""Set the value of a parameter."""
Expand All @@ -234,6 +242,17 @@ def __delitem__(self, name):
def __iter__(self):
return iter(self.resolved_parameters)

def items(self):
with self.record_used_parameters(False):
return super().items()

@contextlib.contextmanager
def record_used_parameters(self, record=True):
old_value = self._record_used_parameters
self._record_used_parameters = record
yield
self._record_used_parameters = old_value

def todict(self, with_prefix=False):
"""
Convert trial data to dictionary.
Expand Down Expand Up @@ -324,7 +343,8 @@ def get(self, key, default=None, setdefault=True):
return super().get(key, default)

def save(self):
# Compact used parameters
# Store and compact used and unused parameters
self._data["used_parameters"] = sorted(self.used_parameters)
self._data["unused_parameters"] = sorted(self.unused_parameters)

# Write to the store
Expand Down
18 changes: 11 additions & 7 deletions tests/core/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ def test_trial_parameters(tmp_path):

@Experiment(configurator={"a": [1], "b": [2], "c": ["{a}"]})
def experiment(trial: Trial):
assert ctx.current_trial == trial
assert ctx.current_trial is trial

assert trial.used_parameters == set()
assert trial.unused_parameters == {"a", "b", "c"}

assert trial["a"] == 1
assert trial["b"] == 2
assert trial["c"] == trial["a"]
assert len(trial) == 3

print(trial["a"], trial["c"])
assert trial.used_parameters == {"a", "b", "c"}
assert trial.unused_parameters == set()

for k, v in trial.items():
pass
Expand All @@ -85,7 +89,7 @@ def experiment(trial: Trial):
trial["prefix__" + k] = v
trial["prefix1__" + k] = v

assert trial.prefixed("prefix__") == seed
assert trial.prefixed("prefix__").todict() == seed

# test call
def identity(a, b, c=4, d=5):
Expand All @@ -104,7 +108,7 @@ def identity(a, b, c=4, d=5):

# test record_defaults
trial.prefixed("prefix2__").record_defaults(identity)
assert trial.prefixed("prefix2__") == {"c": 4, "d": 5}
assert trial.prefixed("prefix2__").todict() == {"c": 4, "d": 5}

# test call: functools.partial

Expand All @@ -125,7 +129,7 @@ def identity(a, b, c=4, d=5):
# Keyword arguments will *not* be recorded and can *not* be overwritten
identity_a8_kwd = functools.partial(identity, a=8)
trial.prefixed("prefix4_").record_defaults(identity_a8_kwd)
assert trial.prefixed("prefix4_") == {"c": 4, "d": 5}
assert trial.prefixed("prefix4_").todict() == {"c": 4, "d": 5}

with pytest.raises(TypeError):
trial.prefixed("prefix5_").record_defaults(identity_a8_kwd, a=9)
Expand All @@ -146,7 +150,7 @@ def identity(a, b, c=4, d=5):
4,
9,
)
assert trial.prefixed("prefix7__") == {"c": 4}
assert trial.prefixed("prefix7__").todict() == {"c": 4}

with pytest.raises(TypeError):
trial.prefixed("prefix7__").call(identity_d9_kwd, 1, 2, 5, 10)
Expand All @@ -165,7 +169,7 @@ def identity(a, b, c=4, d=5):
# setdefaults
assert trial.prefixed("prefix8__").setdefaults(
dict(a=1, b=2, c=3, d=4), e=10
) == dict(a=1, b=2, c=3, d=4, e=10)
).todict() == dict(a=1, b=2, c=3, d=4, e=10)
assert dict(trial.prefixed("prefix8__")) == dict(a=1, b=2, c=3, d=4, e=10)

# Make sure that the default value is recorded when using .get
Expand Down

0 comments on commit 1e85064

Please # to comment.