Skip to content

Commit

Permalink
Feature/sbachmei/mic 5127 update standard observers (#441)
Browse files Browse the repository at this point in the history
* reset index in default results_formatter

* handle some .astype() warnings

* more observers configuration_defaults stratification name into getter method
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 98dbcd4 commit c274906
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def register_adding_observation(
requires_values: List[str] = [],
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
] = lambda measure, results: results.reset_index(),
additional_stratifications: List[str] = [],
excluded_stratifications: List[str] = [],
aggregator_sources: Optional[List[str]] = None,
Expand Down
8 changes: 5 additions & 3 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,13 @@ def initialize_results(
df = pd.DataFrame(
list(itertools.product(*stratification_values.values())),
columns=stratification_names,
).astype(CategoricalDtype)
).astype(CategoricalDtype())
else:
# We are aggregating the entire population so create a single-row index
stratification_names = ["stratification"]
df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype)
df = pd.DataFrame(["all"], columns=stratification_names).astype(
CategoricalDtype()
)

# Initialize a zeros dataframe
df[VALUE_COLUMN] = 0.0
Expand Down Expand Up @@ -170,7 +172,7 @@ def _aggregate(
pop_groups[aggregator_sources].apply(aggregator).fillna(0.0)
if aggregator_sources
else pop_groups.apply(aggregator)
)
).astype(float)
return aggregates

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion src/vivarium/framework/results/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ def __init__(self) -> None:
def configuration_defaults(self) -> Dict[str, Any]:
return {
"stratification": {
self.name.split("_observer")[0]: {
self.get_configuration_name(): {
"exclude": [],
"include": [],
},
},
}

def get_configuration_name(self) -> str:
return self.name.split("_observer")[0]

@abstractmethod
def register_observations(self, builder: Builder) -> None:
"""(Required). Register observations with within each observer."""
Expand Down

0 comments on commit c274906

Please # to comment.