Skip to content

Commit

Permalink
Improve trial.Trial
Browse files Browse the repository at this point in the history
- Prefixed used_parameters and resolved_parameters
- aggregate_log: exclude fields
  • Loading branch information
moi90 committed Jun 25, 2023
1 parent df4d486 commit 9a909e4
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 29 deletions.
24 changes: 19 additions & 5 deletions experitur/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,35 @@ def to_pandas(self):

return pd.json_normalize(self.read())

def aggregate(self, include=None, exclude=None) -> Dict:
    """
    Aggregate (min/max/mean/final) each field in the log.

    All entries returned by ``self.read()`` are scanned and the values of
    every selected field are collected in order of appearance.

    Args:
        include (Collection, optional): If not None, include only these fields.
        exclude (Collection, optional): If not None, exclude these fields.
            A field listed in both ``include`` and ``exclude`` is excluded.

    Returns:
        Dict: For each selected field ``k`` the keys ``max_k``, ``min_k``,
        ``mean_k`` and ``final_k`` (the last logged value). Empty if the log
        has no (selected) fields.
    """

    # Convert once to sets so the per-key membership tests below are cheap.
    include = set(include) if include is not None else None
    exclude = set(exclude) if exclude is not None else None

    metrics = defaultdict(list)
    for entry in self.read():
        for k, v in entry.items():
            if include is not None and k not in include:
                continue

            if exclude is not None and k in exclude:
                continue

            metrics[k].append(v)

    result = {}
    for k, values in metrics.items():
        # `values` is never empty: a key only exists after an append.
        result[f"max_{k}"] = max(values)
        result[f"min_{k}"] = min(values)
        result[f"mean_{k}"] = sum(values) / len(values)
        result[f"final_{k}"] = values[-1]

    return result

Expand Down
129 changes: 105 additions & 24 deletions experitur/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,65 @@ def _to_str(obj):
return repr(obj)


def _filter_prefixed(prefix, values: Iterable):
    """Lazily yield the entries of *values* that start with *prefix*, with the prefix cut off."""
    offset = len(prefix)
    matching = (name for name in values if name.startswith(prefix))
    return (name[offset:] for name in matching)


class _PrefixedTrialDataSetView(collections.abc.MutableSet):
    """
    Set-like view onto the list stored at ``trial._data[key]``.

    Members are stored with ``trial._prefix`` prepended; lookups, additions
    and removals apply the prefix transparently. Iteration goes through
    ``_filter_prefixed``, so only entries carrying the prefix are visible
    (with the prefix removed).
    """

    def __init__(self, trial: "Trial", key) -> None:
        self.trial = trial
        self.key = key

    def data(self):
        """Return the backing list, creating an empty one if it is missing."""
        return self.trial._data.setdefault(self.key, [])

    def _prefixed(self, value) -> str:
        """Return *value* with the trial prefix prepended."""
        return f"{self.trial._prefix}{value}"

    def __contains__(self, x: object) -> bool:
        return self._prefixed(x) in self.data()

    def __iter__(self):
        return _filter_prefixed(self.trial._prefix, self.data())

    def __len__(self):
        return sum(1 for _ in self)

    def discard(self, value) -> None:
        """Remove *value* if present; do nothing otherwise."""
        value = self._prefixed(value)
        # Rebuild the list so that every occurrence is dropped.
        self.trial._data[self.key] = [o for o in self.data() if o != value]

    def add(self, value):
        """Add *value* unless it is already a member."""
        value = self._prefixed(value)
        entries = self.data()
        # The backing store is a list: guard against duplicates so that
        # repeated additions keep set semantics (and __len__ stays correct).
        if value not in entries:
            entries.append(value)


class _PrefixedTrialDataDictView(collections.abc.MutableMapping):
    """
    Mapping view onto the dict stored at ``trial._data[key]``.

    Keys are stored with ``trial._prefix`` prepended; item access applies
    the prefix, iteration yields the prefixed keys with the prefix removed.
    """

    def __init__(self, trial: "Trial", key) -> None:
        self.key = key
        self.trial = trial

    def data(self):
        """Return the backing dict, creating an empty one if it is missing."""
        store = self.trial._data
        return store.setdefault(self.key, {})

    def _full_key(self, key):
        """Return *key* with the trial prefix prepended."""
        return f"{self.trial._prefix}{key}"

    def __getitem__(self, key):
        full_key = self._full_key(key)
        return self.data()[full_key]

    def __setitem__(self, key, value):
        full_key = self._full_key(key)
        self.data()[full_key] = value

    def __delitem__(self, key):
        full_key = self._full_key(key)
        del self.data()[full_key]

    def __iter__(self):
        return _filter_prefixed(self.trial._prefix, self.data())

    def __len__(self):
        count = 0
        for _ in self:
            count += 1
        return count


class Trial(collections.abc.MutableMapping):
"""
Data related to a trial.
Expand Down Expand Up @@ -102,10 +161,10 @@ def func(a=1, b=2):
"""

# Provided by _data:
used_parameters: list
# used_parameters: list
id: str
wdir: str
resolved_parameters: Dict
# resolved_parameters: Dict

def __init__(
self,
Expand All @@ -119,7 +178,8 @@ def __init__(
self._prefix = prefix
self._record_used_parameters = record_used_parameters

self._data.setdefault("used_parameters", [])
# Initialize
self.used_parameters.data()

self._valid = True

Expand All @@ -134,37 +194,45 @@ def _validate_data(self):

return self._valid

@property
def used_parameters(self):
    """Set-like, prefix-aware view on the ``"used_parameters"`` entry of the trial data."""
    view = _PrefixedTrialDataSetView(self, "used_parameters")
    return view

@property
def resolved_parameters(self):
    """Dict-like, prefix-aware view on the ``"resolved_parameters"`` entry of the trial data."""
    view = _PrefixedTrialDataDictView(self, "resolved_parameters")
    return view

# MutableMapping provides concrete generic implementations of all
# methods except for __getitem__, __setitem__, __delitem__,
# __iter__, and __len__.

def update(self, *args, **kwargs):
    """Update the parameters from a mapping/iterable of pairs and/or keyword arguments."""
    # Allow a Trial to be updated with itself: the source is copied into a
    # plain dict first, then handed to the generic implementation.
    snapshot = dict(*args, **kwargs)
    super().update(snapshot)

def __getitem__(self, name):
    """
    Get the value of a parameter.

    If ``self._record_used_parameters`` is set, *name* is added to
    ``self.used_parameters`` before the lookup (i.e. even if the lookup
    then fails).

    Raises:
        KeyError: If the parameter is not in ``self.resolved_parameters``.
    """
    if self._record_used_parameters:
        self.used_parameters.add(name)

    return self.resolved_parameters[name]

@property
def unused_parameters(self):
    """Sorted list of the resolved parameters that are not in ``used_parameters``."""
    return sorted(set(self.resolved_parameters) - set(self.used_parameters))

def __setitem__(self, name, value):
    """Set the value of a parameter."""
    # The view returned by `resolved_parameters` takes care of the prefix.
    self.resolved_parameters[name] = value

def __delitem__(self, name):
    """
    Delete a parameter.

    Raises:
        KeyError: If the parameter is not in ``self.resolved_parameters``.
    """
    # Delete exactly once, through the prefix-aware view.
    del self.resolved_parameters[name]

def __iter__(self):
    """Iterate over the parameter names exposed by ``self.resolved_parameters``."""
    return iter(self.resolved_parameters)

def todict(self, with_prefix=False):
"""
Expand All @@ -180,7 +248,7 @@ def todict(self, with_prefix=False):
return data

def __len__(self):
    """Return the number of parameters exposed by ``self.resolved_parameters``."""
    return len(self.resolved_parameters)

def __repr__(self):
    """Return ``<Trial({...})>`` showing the mapping contents of this object."""
    return f"<Trial({dict(self)})>"
Expand Down Expand Up @@ -237,6 +305,14 @@ def _drop_prefix(k: str):

return descr

@overload
def get(self, key) -> Any:
...

@overload
def get(self, key, default: T) -> T:
...

def get(self, key, default=None, setdefault=True):
"""Get a parameter value.
Expand All @@ -249,9 +325,7 @@ def get(self, key, default=None, setdefault=True):

def save(self):
# Compact used parameters
self.used_parameters = sorted(set(self.used_parameters))
# Save unused parameters
self._data["unused_parameters"] = self.unused_parameters
self._data["unused_parameters"] = sorted(self.unused_parameters)

# Write to the store
self._root.update(self)
Expand Down Expand Up @@ -369,12 +443,12 @@ def remove(self):
"""Remove this trial from the store."""
self._root.remove(self)

def get_result(self, name, default=None):
    """
    Look up a single entry of the trial result.

    Args:
        name: Key to look up in ``self._data["result"]``.
        default: Value returned if there is no result yet (the stored result
            is None) or the result has no entry *name*.

    Returns:
        The stored value, or *default*.
    """
    result = self._data["result"]
    if result is None:
        return default

    return result.get(name, default)

def update_result(self, values: Optional[Mapping] = None, **kwargs):
if self._data["result"] is None:
Expand Down Expand Up @@ -679,8 +753,15 @@ def get_log(self, aggregate=True):
# yield final entry
yield acc

def aggregate_log(self, include=None, exclude=None) -> Dict:
    """
    Aggregate (min/max/mean/final) each field in the log.

    Delegates to ``self._logger.aggregate``.

    Args:
        include (Collection, optional): If not None, include only these fields.
        exclude (Collection, optional): If not None, exclude these fields.
    """
    return self._logger.aggregate(include=include, exclude=exclude)

def find_files(self, pattern, recursive=False) -> List:
"""
Expand Down

0 comments on commit 9a909e4

Please sign in to comment.