From e8d49c19b06e5890ef978b20fc22e328a8671f4b Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Tue, 29 Oct 2024 16:46:17 -0700 Subject: [PATCH] maintain LookupTable value col order --- CHANGELOG.rst | 6 ++- src/vivarium/component.py | 35 ++++++++++------- src/vivarium/framework/lookup/manager.py | 26 ++++++------- src/vivarium/framework/lookup/table.py | 18 ++++----- tests/framework/components/test_component.py | 29 +++++++++++++- tests/helpers.py | 41 +++++++++++++++++++- 6 files changed, 114 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4b92d4514..c412b1428 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,4 +1,8 @@ -**3.0.15 - TBD/TBD/TBD** +**3.0.16 - 10/30/24** + + - Bugfix to prevent a LookupTable from changing order of the value columns + +**3.0.15 - 10/25/24** - Fix mypy errors in vivarium/framework/event.py diff --git a/src/vivarium/component.py b/src/vivarium/component.py index bbe40d4f3..70e0a0fb4 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -13,7 +13,7 @@ from abc import ABC from importlib import import_module from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union import pandas as pd from layered_config_tree import ConfigurationError, LayeredConfigTree @@ -577,7 +577,7 @@ def build_lookup_table( builder: "Builder", # todo: replace with LookupTableData data_source: Union[str, float, int, list, pd.DataFrame], - value_columns: Optional[Iterable[str]] = None, + value_columns: Optional[Sequence[str]] = None, ) -> LookupTable: """Builds a LookupTable from a data source. @@ -610,31 +610,38 @@ def build_lookup_table( if isinstance(data, list): return builder.lookup.build_table(data, value_columns=list(value_columns)) if isinstance(data, pd.DataFrame): - all_columns = set(data.columns) + duplicated_columns = set(data.columns[data.columns.duplicated()]) + if duplicated_columns: + raise ConfigurationError( + f"Dataframe contains duplicate columns {duplicated_columns}." + ) + all_columns = list(data.columns) if value_columns is None: - value_columns = set(self.get_value_columns(data)) - else: - value_columns = set(value_columns) + value_columns = self.get_value_columns(data) potential_parameter_columns = [ str(col).removesuffix("_start") for col in all_columns if str(col).endswith("_start") ] - parameter_columns = set() - bin_edge_columns = set() + parameter_columns = [] + bin_edge_columns = [] for column in potential_parameter_columns: if f"{column}_end" in all_columns: - parameter_columns.add(column) - bin_edge_columns.update([f"{column}_start", f"{column}_end"]) + parameter_columns.append(column) + bin_edge_columns += [f"{column}_start", f"{column}_end"] - key_columns = all_columns - value_columns - bin_edge_columns + key_columns = [ + col + for col in all_columns + if col not in value_columns and col not in bin_edge_columns + ] return builder.lookup.build_table( data=data, - key_columns=list(key_columns), - parameter_columns=list(parameter_columns), - value_columns=list(value_columns), + key_columns=key_columns, + parameter_columns=parameter_columns, + value_columns=value_columns, ) return builder.lookup.build_table(data) diff --git a/src/vivarium/framework/lookup/manager.py b/src/vivarium/framework/lookup/manager.py index cd0a3b3bb..1b4635639 100644 --- a/src/vivarium/framework/lookup/manager.py +++ b/src/vivarium/framework/lookup/manager.py @@ -15,7 +15,7 @@ from datetime import datetime, timedelta from numbers import Number -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, List, Sequence, Tuple, Union import pandas as pd @@ -64,9 +64,9 @@ def setup(self, builder: "Builder") -> None: def build_table( self, data: LookupTableData, - key_columns: Union[List[str], Tuple[str, ...]], - parameter_columns: Union[List[str], Tuple[str, ...]], - value_columns: Union[List[str], Tuple[str, ...]], + key_columns: Sequence[str], + parameter_columns: Sequence[str], + value_columns: Sequence[str], ) -> LookupTable: """Construct a lookup table from input data.""" table = self._build_table(data, key_columns, parameter_columns, value_columns) @@ -78,9 +78,9 @@ def build_table( def _build_table( self, data: LookupTableData, - key_columns: Union[List[str], Tuple[str, ...]], - parameter_columns: Union[List[str], Tuple[str, ...]], - value_columns: Union[List[str], Tuple[str, ...]], + key_columns: Sequence[str], + parameter_columns: Sequence[str], + value_columns: Sequence[str], ) -> LookupTable: # We don't want to require explicit names for tables, but giving them # generic names is useful for introspection. @@ -135,9 +135,9 @@ def __init__(self, manager: LookupTableManager): def build_table( self, data: LookupTableData, - key_columns: Union[List[str], Tuple[str, ...]] = (), - parameter_columns: Union[List[str], Tuple[str, ...]] = (), - value_columns: Union[List[str], Tuple[str, ...]] = (), + key_columns: Sequence[str] = (), + parameter_columns: Sequence[str] = (), + value_columns: Sequence[str] = (), ) -> LookupTable: """Construct a LookupTable from input data. @@ -180,9 +180,9 @@ def build_table( def validate_build_table_parameters( data: LookupTableData, - key_columns: Union[List[str], Tuple[str, ...]], - parameter_columns: Union[List[str], Tuple[str, ...]], - value_columns: Union[List[str], Tuple[str, ...]], + key_columns: Sequence[str], + parameter_columns: Sequence[str], + value_columns: Sequence[str], ) -> None: """Makes sure the data format agrees with the provided column layout.""" if ( diff --git a/src/vivarium/framework/lookup/table.py b/src/vivarium/framework/lookup/table.py index 7ee048622..555e641d6 100644 --- a/src/vivarium/framework/lookup/table.py +++ b/src/vivarium/framework/lookup/table.py @@ -14,7 +14,7 @@ import dataclasses from abc import ABC, abstractmethod -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Sequence, Tuple, Union import numpy as np import pandas as pd @@ -46,12 +46,12 @@ class LookupTable(ABC): """The data from which to build the interpolation.""" population_view_builder: Callable = None """Callable to get a population view to be used by the lookup table.""" - key_columns: Union[List[str], Tuple[str]] = () + key_columns: Sequence[str] = () """Column names to be used as categorical parameters in Interpolation to select between interpolation functions.""" - parameter_columns: Union[List[str], Tuple] = () + parameter_columns: Sequence[str] = () """Column names to be used as continuous parameters in Interpolation.""" - value_columns: Union[List[str], Tuple[str]] = () + value_columns: Sequence[str] = () """Names of value columns to be interpolated over.""" interpolation_order: int = 0 """Order of interpolation. Used to decide interpolation strategy.""" @@ -107,9 +107,9 @@ def __init__( table_number: int, data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]], population_view_builder: Callable, - key_columns: Union[List[str], Tuple[str]], - parameter_columns: Union[List[str], Tuple], - value_columns: Union[List[str], Tuple[str]], + key_columns: Sequence[str], + parameter_columns: Sequence[str], + value_columns: Sequence[str], interpolation_order: int, clock: Callable, extrapolate: bool, @@ -202,8 +202,8 @@ def __init__( table_number: int, data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]], population_view_builder: Callable, - key_columns: Union[List[str], Tuple[str]], - value_columns: Union[List[str], Tuple[str]], + key_columns: Sequence[str], + value_columns: Sequence[str], **kwargs, ): super().__init__( diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 5e265130a..5c79d83b0 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -12,13 +12,14 @@ FilteredPopulationView, LookupCreator, NoPopulationView, + OrderedColumnsLookupCreator, Parameterized, ParameterizedByComponent, SingleLookupCreator, ) from vivarium import Artifact, InteractiveContext from vivarium.framework.engine import Builder -from vivarium.framework.lookup.table import ScalarTable +from vivarium.framework.lookup.table import CategoricalTable, InterpolatedTable, ScalarTable from vivarium.framework.population import PopulationError @@ -244,7 +245,7 @@ def test_listeners_are_registered_at_custom_priorities(): assert component.on_simulation_end in set(simulation_end_methods.get(1, [])) -def test_component_configuration_gets_set(base_config): +def test_component_configuration_gets_set(): without_config = ColumnCreator() with_config = ColumnRequirer() @@ -383,3 +384,27 @@ def test_failing_component_lookup_table_configurations( sim.configuration.update(override_config) with pytest.raises(error_type, match=match): sim.setup() + + +@pytest.mark.parametrize( + "table", ["ordered_columns_categorical", "ordered_columns_interpolated"] +) +def test_value_column_order_is_maintained(table): + """Tests that the order of value columns is maintained when creating a LookupTable. + + Notes + ----- + This test is a bit of a hack. We found an issue where the order of value columns + was changing due to casting the value columns as a set on the back end (which + does not guarantee order). The problem is that we can't actually guarantee + that casting as a set will change the order either. However, with a large + enough number of value columns, it seems likely that the order will change. + """ + component = OrderedColumnsLookupCreator() + sim = InteractiveContext(components=[component]) + lookup_table = component.lookup_tables[table] + assert isinstance( + lookup_table, CategoricalTable if "categorical" in table else InterpolatedTable + ) + data = lookup_table(sim.get_population().index) + assert list(data.columns) == ["one", "two", "three", "four", "five", "six", "seven"] diff --git a/tests/helpers.py b/tests/helpers.py index a342e0803..d2cfc0d6a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -183,9 +183,46 @@ def load_baking_time(_builder: Builder) -> float: class SingleLookupCreator(ColumnCreator): + pass + + +class OrderedColumnsLookupCreator(Component): @property - def standard_lookup_tables(self) -> List[str]: - return ["favorite_color"] + def columns_created(self) -> List[str]: + return ["foo", "bar"] + + def on_initialize_simulants(self, pop_data: SimulantData) -> None: + initialization_data = pd.DataFrame( + { + "foo": "key1", + "bar": 15, + }, + index=pop_data.index, + ) + self.population_view.update(initialization_data) + + def build_all_lookup_tables(self, builder: "Builder") -> None: + value_columns = ["one", "two", "three", "four", "five", "six", "seven"] + ordered_columns = pd.DataFrame( + [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]], + columns=value_columns, + ) + ordered_columns_categorical = ordered_columns.copy() + ordered_columns_categorical["foo"] = ["key1", "key2"] + ordered_columns_interpolated = ordered_columns.copy() + ordered_columns_interpolated["foo"] = ["key1", "key1"] + ordered_columns_interpolated["bar_start"] = [10, 20] + ordered_columns_interpolated["bar_end"] = [20, 30] + self.lookup_tables["ordered_columns_categorical"] = self.build_lookup_table( + builder, + ordered_columns_categorical, + value_columns, + ) + self.lookup_tables["ordered_columns_interpolated"] = self.build_lookup_table( + builder, + ordered_columns_interpolated, + value_columns, + ) class ColumnRequirer(Component):