# ---------------------------------------------------------------------------
# experitur/core/logger.py
# ---------------------------------------------------------------------------
# NOTE(review): `aggregate` is a method of the logger class. The class header
# is not part of this excerpt, so the method is shown at module level here.
def aggregate(self, include=None, exclude=None) -> Dict:
    """
    Aggregate (min/max/mean/final) each field in the log.

    Args:
        include (Collection, optional): If not None, include only these fields.
        exclude (Collection, optional): If not None, exclude these fields.

    Returns:
        A dict with the keys ``max_<field>``, ``min_<field>``, ``mean_<field>``
        and ``final_<field>`` for every aggregated field.

    Raises:
        TypeError, ValueError: If a field that was explicitly requested via
            ``include`` holds values that cannot be aggregated.
    """

    include = set(include) if include is not None else None
    exclude = set(exclude) if exclude is not None else None

    # Collect the values of every selected field in log order.
    metrics = defaultdict(list)
    for entry in self.read():
        for k, v in entry.items():
            if include is not None and k not in include:
                continue

            if exclude is not None and k in exclude:
                continue

            metrics[k].append(v)

    result = {}
    for k, values in metrics.items():
        try:
            stats = {
                f"max_{k}": max(values),
                f"min_{k}": min(values),
                f"mean_{k}": sum(values) / len(values),
                f"final_{k}": values[-1],
            }
        except (TypeError, ValueError):
            # Without an explicit `include`, every logged field ends up here,
            # including ones that cannot be aggregated (e.g. strings). Those
            # are skipped. A field the caller asked for by name still raises.
            if include is not None:
                raise
            continue

        # Only merge complete statistics so that a field is either fully
        # present in the result or absent.
        result.update(stats)

    return result


# ---------------------------------------------------------------------------
# experitur/core/trial.py
# ---------------------------------------------------------------------------
def _filter_prefixed(prefix: str, values: Iterable[str]) -> Iterable[str]:
    """Yield the entries of *values* that start with *prefix*, with the prefix removed."""
    start = len(prefix)
    return (k[start:] for k in values if k.startswith(prefix))


class _PrefixedTrialDataSetView(collections.abc.MutableSet):
    """
    Set-like view on the list stored in ``trial._data[key]``.

    Elements are stored with the trial's current prefix prepended. The view
    adds the prefix when writing and strips it when reading; stored entries
    that do not start with the prefix are not visible through the view.
    """

    def __init__(self, trial: "Trial", key) -> None:
        self.trial = trial
        self.key = key

    def data(self):
        """Return the backing list, creating an empty one if the key is missing."""
        return self.trial._data.setdefault(self.key, [])

    def __contains__(self, x: object) -> bool:
        return f"{self.trial._prefix}{x}" in self.data()

    def __iter__(self):
        # The backing list may contain duplicates (nothing prevents them in
        # stored data). dict.fromkeys drops them while keeping first-seen
        # order, so iteration and len() behave like a set.
        return iter(dict.fromkeys(_filter_prefixed(self.trial._prefix, self.data())))

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def discard(self, value) -> None:
        """Remove every occurrence of *value*; do nothing if it is absent."""
        value = f"{self.trial._prefix}{value}"
        self.trial._data[self.key] = [o for o in self.data() if o != value]

    def add(self, value) -> None:
        """Add *value* unless it is already present."""
        value = f"{self.trial._prefix}{value}"
        data = self.data()
        # Trial.__getitem__ calls add() on every parameter access, so an
        # unconditional append would grow the list without bound.
        if value not in data:
            data.append(value)


class _PrefixedTrialDataDictView(collections.abc.MutableMapping):
    """
    Dict-like view on the mapping stored in ``trial._data[key]``.

    Keys are stored with the trial's current prefix prepended. The view adds
    the prefix on access and strips it when iterating; stored keys that do not
    start with the prefix are not visible through the view.
    """

    def __init__(self, trial: "Trial", key) -> None:
        self.trial = trial
        self.key = key

    def data(self):
        """Return the backing dict, creating an empty one if the key is missing."""
        return self.trial._data.setdefault(self.key, {})

    def __getitem__(self, key):
        return self.data()[f"{self.trial._prefix}{key}"]

    def __setitem__(self, key, value):
        self.data()[f"{self.trial._prefix}{key}"] = value

    def __delitem__(self, key):
        del self.data()[f"{self.trial._prefix}{key}"]

    def __iter__(self):
        return _filter_prefixed(self.trial._prefix, self.data())

    def __len__(self) -> int:
        return sum(1 for _ in self)


class Trial(collections.abc.MutableMapping):
    """
    Data related to a trial.

    Item access (``trial[name]``) reads and writes the resolved parameters,
    which are stored in ``_data["resolved_parameters"]`` under the key
    ``f"{prefix}{name}"``.

    NOTE(review): only the members that are completely visible in the reviewed
    change are listed here. ``__init__``, ``todict``, ``__repr__``, ``get``,
    ``update_result``, ``get_log``, ``find_files`` and any other members are
    defined outside this excerpt.
    """

    # Provided by _data. used_parameters and resolved_parameters are exposed
    # through the prefix-aware properties below instead.
    id: str
    wdir: str

    @property
    def used_parameters(self):
        """Set-like, prefix-aware view on ``_data["used_parameters"]``."""
        return _PrefixedTrialDataSetView(self, "used_parameters")

    @property
    def resolved_parameters(self):
        """Dict-like, prefix-aware view on ``_data["resolved_parameters"]``."""
        return _PrefixedTrialDataDictView(self, "resolved_parameters")

    # MutableMapping provides concrete generic implementations of all
    # methods except for __getitem__, __setitem__, __delitem__,
    # __iter__, and __len__.

    def update(self, *args, **kwargs):
        """Update the parameters from a mapping, an iterable of pairs and/or keywords."""
        # Allow a Trial to be updated with itself: take a snapshot first so
        # that the source is not modified while it is being read.
        values = dict(*args, **kwargs)
        super().update(values)

    def __getitem__(self, name):
        """Get the value of a parameter."""
        if self._record_used_parameters:
            # The access is recorded before the lookup, i.e. also when the
            # lookup below raises KeyError.
            self.used_parameters.add(name)

        return self.resolved_parameters[name]

    @property
    def unused_parameters(self):
        """Sorted names of the resolved parameters that were never recorded as used."""
        return sorted(set(self.resolved_parameters) - set(self.used_parameters))

    def __setitem__(self, name, value):
        """Set the value of a parameter."""
        self.resolved_parameters[name] = value

    def __delitem__(self, name):
        """Delete a parameter."""
        del self.resolved_parameters[name]

    def __iter__(self):
        return iter(self.resolved_parameters)

    def __len__(self):
        return len(self.resolved_parameters)

    def save(self):
        """Compact the bookkeeping lists and write the trial to the store."""
        # Compact used parameters. This works on the backing list, because
        # the prefix-aware view only exposes entries of the current prefix.
        self._data["used_parameters"] = sorted(set(self.used_parameters.data()))

        # Save unused parameters (the property already returns a sorted list)
        self._data["unused_parameters"] = self.unused_parameters

        # Write to the store
        self._root.update(self)

    def remove(self):
        """Remove this trial from the store."""
        self._root.remove(self)

    def get_result(self, name, default=None):
        """
        Return the entry *name* of the trial result.

        *default* is returned if there is no result yet or if the result has
        no entry *name*.
        """
        result = self._data["result"]

        if result is None:
            return default

        return result.get(name, default)

    def aggregate_log(self, include=None, exclude=None) -> Dict:
        """
        Aggregate (min/max/mean/final) each field in the log.

        Args:
            include (Collection, optional): If not None, include only these fields.
            exclude (Collection, optional): If not None, exclude these fields.
        """
        return self._logger.aggregate(include=include, exclude=exclude)