Skip to content

Commit

Permalink
stop writing seed/draw to simulation results (#436)
Browse files Browse the repository at this point in the history
* stop writing seed/draw to simulation results

* pin numpy<2.0.0
  • Loading branch information
stevebachmeier authored Jun 20, 2024
1 parent 923cb57 commit 8fb657c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 30 deletions.
10 changes: 0 additions & 10 deletions src/vivarium/framework/results/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class Observer(Component, ABC):
def __init__(self) -> None:
super().__init__()
self.results_dir = None
self.input_draw = None
self.random_seed = None

@abstractmethod
def register_observations(self, builder: Builder) -> None:
Expand All @@ -36,14 +34,6 @@ def get_formatter_attributes(self, builder: Builder) -> None:
.get("output_data", {})
.get("results_directory", None)
)
self.input_draw = (
builder.configuration.to_dict()
.get("input_data", {})
.get("input_draw_number", None)
)
self.random_seed = (
builder.configuration.to_dict().get("randomness", {}).get("random_seed", None)
)


class StratifiedObserver(Observer):
Expand Down
13 changes: 4 additions & 9 deletions tests/framework/results/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
from functools import partial
from typing import List

import numpy as np
Expand Down Expand Up @@ -122,7 +121,7 @@ def register_observations(self, builder: Builder) -> None:
requires_columns=[
"house_points",
],
results_formatter=partial(formatter, self.random_seed, self.input_draw),
results_formatter=results_formatter,
)


Expand Down Expand Up @@ -154,7 +153,7 @@ def register_observations(self, builder: Builder) -> None:
requires_columns=[
"quidditch_wins",
],
results_formatter=partial(formatter, self.random_seed, self.input_draw),
results_formatter=results_formatter,
)


Expand All @@ -170,7 +169,7 @@ def register_observations(self, builder: Builder) -> None:
requires_columns=[
"quidditch_wins",
],
results_formatter=partial(formatter, self.random_seed, self.input_draw),
results_formatter=results_formatter,
)


Expand Down Expand Up @@ -266,17 +265,13 @@ def setup(self, builder: Builder) -> None:
##################


def formatter(
random_seed: str,
input_draw: str,
def results_formatter(
measure: str,
results: pd.DataFrame,
) -> pd.DataFrame:
"""An test use case of an observer's report method that writes a DataFrame to a CSV file."""
# Add extra cols
results["measure"] = measure
results["random_seed"] = random_seed
results["input_draw"] = input_draw
# Sort the columns such that the stratifications (index) are first
# and VALUE_COLUMN is last and sort the rows by the stratifications.
other_cols = [c for c in results.columns if c != VALUE_COLUMN]
Expand Down
12 changes: 4 additions & 8 deletions tests/framework/results/test_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,27 @@ def test_observer_instantiation():


@pytest.mark.parametrize(
"is_interactive, results_dir, draw, seed",
"is_interactive, results_dir",
[
(False, "/some/results/dir", 111, 222),
(True, None, None, None),
(False, "/some/results/dir"),
(True, None),
],
)
def test_get_formatter_attributes(is_interactive, results_dir, draw, seed, mocker):
def test_get_formatter_attributes(is_interactive, results_dir, mocker):
builder = mocker.Mock()
if is_interactive:
builder.configuration = LayeredConfigTree()
else:
builder.configuration = LayeredConfigTree(
{
"output_data": {"results_directory": results_dir},
"input_data": {"input_draw_number": draw},
"randomness": {"random_seed": seed},
}
)

observer = TestObserver()
observer.get_formatter_attributes(builder)

assert observer.results_dir == results_dir
assert observer.input_draw == draw
assert observer.random_seed == seed


@pytest.mark.parametrize(
Expand Down
3 changes: 0 additions & 3 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,6 @@ def test_get_results_formatting(SimulationContext, base_config):
df = eval(measure)
# Check that metrics col matches name of dataset
assert (df["measure"] == measure).all()
# Check for other cols
assert "random_seed" in df.columns
assert "input_draw" in df.columns
# We do enforce a col order, but most importantly ensure VALUE_COLUMN is at the end
assert df.columns[-1] == VALUE_COLUMN
# Check values
Expand Down

0 comments on commit 8fb657c

Please # to comment.