Skip to content

Commit

Permalink
maintain LookupTable value col order
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Oct 29, 2024
1 parent a42e397 commit e8d49c1
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 41 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
**3.0.15 - TBD/TBD/TBD**
**3.0.16 - 10/30/24**

- Bugfix to prevent a LookupTable from changing order of the value columns

**3.0.15 - 10/25/24**

- Fix mypy errors in vivarium/framework/event.py

Expand Down
35 changes: 21 additions & 14 deletions src/vivarium/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC
from importlib import import_module
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

import pandas as pd
from layered_config_tree import ConfigurationError, LayeredConfigTree
Expand Down Expand Up @@ -577,7 +577,7 @@ def build_lookup_table(
builder: "Builder",
# todo: replace with LookupTableData
data_source: Union[str, float, int, list, pd.DataFrame],
value_columns: Optional[Iterable[str]] = None,
value_columns: Optional[Sequence[str]] = None,
) -> LookupTable:
"""Builds a LookupTable from a data source.
Expand Down Expand Up @@ -610,31 +610,38 @@ def build_lookup_table(
if isinstance(data, list):
return builder.lookup.build_table(data, value_columns=list(value_columns))
if isinstance(data, pd.DataFrame):
all_columns = set(data.columns)
duplicated_columns = set(data.columns[data.columns.duplicated()])
if duplicated_columns:
raise ConfigurationError(
f"Dataframe contains duplicate columns {duplicated_columns}."
)
all_columns = list(data.columns)
if value_columns is None:
value_columns = set(self.get_value_columns(data))
else:
value_columns = set(value_columns)
value_columns = self.get_value_columns(data)

potential_parameter_columns = [
str(col).removesuffix("_start")
for col in all_columns
if str(col).endswith("_start")
]
parameter_columns = set()
bin_edge_columns = set()
parameter_columns = []
bin_edge_columns = []
for column in potential_parameter_columns:
if f"{column}_end" in all_columns:
parameter_columns.add(column)
bin_edge_columns.update([f"{column}_start", f"{column}_end"])
parameter_columns.append(column)
bin_edge_columns += [f"{column}_start", f"{column}_end"]

key_columns = all_columns - value_columns - bin_edge_columns
key_columns = [
col
for col in all_columns
if col not in value_columns and col not in bin_edge_columns
]

return builder.lookup.build_table(
data=data,
key_columns=list(key_columns),
parameter_columns=list(parameter_columns),
value_columns=list(value_columns),
key_columns=key_columns,
parameter_columns=parameter_columns,
value_columns=value_columns,
)

return builder.lookup.build_table(data)
Expand Down
26 changes: 13 additions & 13 deletions src/vivarium/framework/lookup/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from datetime import datetime, timedelta
from numbers import Number
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, List, Sequence, Tuple, Union

import pandas as pd

Expand Down Expand Up @@ -64,9 +64,9 @@ def setup(self, builder: "Builder") -> None:
def build_table(
self,
data: LookupTableData,
key_columns: Union[List[str], Tuple[str, ...]],
parameter_columns: Union[List[str], Tuple[str, ...]],
value_columns: Union[List[str], Tuple[str, ...]],
key_columns: Sequence[str],
parameter_columns: Sequence[str],
value_columns: Sequence[str],
) -> LookupTable:
"""Construct a lookup table from input data."""
table = self._build_table(data, key_columns, parameter_columns, value_columns)
Expand All @@ -78,9 +78,9 @@ def build_table(
def _build_table(
self,
data: LookupTableData,
key_columns: Union[List[str], Tuple[str, ...]],
parameter_columns: Union[List[str], Tuple[str, ...]],
value_columns: Union[List[str], Tuple[str, ...]],
key_columns: Sequence[str],
parameter_columns: Sequence[str],
value_columns: Sequence[str],
) -> LookupTable:
# We don't want to require explicit names for tables, but giving them
# generic names is useful for introspection.
Expand Down Expand Up @@ -135,9 +135,9 @@ def __init__(self, manager: LookupTableManager):
def build_table(
self,
data: LookupTableData,
key_columns: Union[List[str], Tuple[str, ...]] = (),
parameter_columns: Union[List[str], Tuple[str, ...]] = (),
value_columns: Union[List[str], Tuple[str, ...]] = (),
key_columns: Sequence[str] = (),
parameter_columns: Sequence[str] = (),
value_columns: Sequence[str] = (),
) -> LookupTable:
"""Construct a LookupTable from input data.
Expand Down Expand Up @@ -180,9 +180,9 @@ def build_table(

def validate_build_table_parameters(
data: LookupTableData,
key_columns: Union[List[str], Tuple[str, ...]],
parameter_columns: Union[List[str], Tuple[str, ...]],
value_columns: Union[List[str], Tuple[str, ...]],
key_columns: Sequence[str],
parameter_columns: Sequence[str],
value_columns: Sequence[str],
) -> None:
"""Makes sure the data format agrees with the provided column layout."""
if (
Expand Down
18 changes: 9 additions & 9 deletions src/vivarium/framework/lookup/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import dataclasses
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,12 +46,12 @@ class LookupTable(ABC):
"""The data from which to build the interpolation."""
population_view_builder: Callable = None
"""Callable to get a population view to be used by the lookup table."""
key_columns: Union[List[str], Tuple[str]] = ()
key_columns: Sequence[str] = ()
"""Column names to be used as categorical parameters in Interpolation
to select between interpolation functions."""
parameter_columns: Union[List[str], Tuple] = ()
parameter_columns: Sequence[str] = ()
"""Column names to be used as continuous parameters in Interpolation."""
value_columns: Union[List[str], Tuple[str]] = ()
value_columns: Sequence[str] = ()
"""Names of value columns to be interpolated over."""
interpolation_order: int = 0
"""Order of interpolation. Used to decide interpolation strategy."""
Expand Down Expand Up @@ -107,9 +107,9 @@ def __init__(
table_number: int,
data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]],
population_view_builder: Callable,
key_columns: Union[List[str], Tuple[str]],
parameter_columns: Union[List[str], Tuple],
value_columns: Union[List[str], Tuple[str]],
key_columns: Sequence[str],
parameter_columns: Sequence[str],
value_columns: Sequence[str],
interpolation_order: int,
clock: Callable,
extrapolate: bool,
Expand Down Expand Up @@ -202,8 +202,8 @@ def __init__(
table_number: int,
data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]],
population_view_builder: Callable,
key_columns: Union[List[str], Tuple[str]],
value_columns: Union[List[str], Tuple[str]],
key_columns: Sequence[str],
value_columns: Sequence[str],
**kwargs,
):
super().__init__(
Expand Down
29 changes: 27 additions & 2 deletions tests/framework/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
FilteredPopulationView,
LookupCreator,
NoPopulationView,
OrderedColumnsLookupCreator,
Parameterized,
ParameterizedByComponent,
SingleLookupCreator,
)
from vivarium import Artifact, InteractiveContext
from vivarium.framework.engine import Builder
from vivarium.framework.lookup.table import ScalarTable
from vivarium.framework.lookup.table import CategoricalTable, InterpolatedTable, ScalarTable
from vivarium.framework.population import PopulationError


Expand Down Expand Up @@ -244,7 +245,7 @@ def test_listeners_are_registered_at_custom_priorities():
assert component.on_simulation_end in set(simulation_end_methods.get(1, []))


def test_component_configuration_gets_set(base_config):
def test_component_configuration_gets_set():
without_config = ColumnCreator()
with_config = ColumnRequirer()

Expand Down Expand Up @@ -383,3 +384,27 @@ def test_failing_component_lookup_table_configurations(
sim.configuration.update(override_config)
with pytest.raises(error_type, match=match):
sim.setup()


@pytest.mark.parametrize(
"table", ["ordered_columns_categorical", "ordered_columns_interpolated"]
)
def test_value_column_order_is_maintained(table):
"""Tests that the order of value columns is maintained when creating a LookupTable.
Notes
-----
This test is a bit of a hack. We found an issue where the order of value columns
was changing due to casting the value columns as a set on the back end (which
does not guarantee order). The problem is that we can't actually guarantee
that casting as a set will change the order either. However, with a large
enough number of value columns, it seems likely that the order will change.
"""
component = OrderedColumnsLookupCreator()
sim = InteractiveContext(components=[component])
lookup_table = component.lookup_tables[table]
assert isinstance(
lookup_table, CategoricalTable if "categorical" in table else InterpolatedTable
)
data = lookup_table(sim.get_population().index)
assert list(data.columns) == ["one", "two", "three", "four", "five", "six", "seven"]
41 changes: 39 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,46 @@ def load_baking_time(_builder: Builder) -> float:


class SingleLookupCreator(ColumnCreator):
pass


class OrderedColumnsLookupCreator(Component):
@property
def standard_lookup_tables(self) -> List[str]:
return ["favorite_color"]
def columns_created(self) -> List[str]:
return ["foo", "bar"]

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
initialization_data = pd.DataFrame(
{
"foo": "key1",
"bar": 15,
},
index=pop_data.index,
)
self.population_view.update(initialization_data)

def build_all_lookup_tables(self, builder: "Builder") -> None:
value_columns = ["one", "two", "three", "four", "five", "six", "seven"]
ordered_columns = pd.DataFrame(
[[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]],
columns=value_columns,
)
ordered_columns_categorical = ordered_columns.copy()
ordered_columns_categorical["foo"] = ["key1", "key2"]
ordered_columns_interpolated = ordered_columns.copy()
ordered_columns_interpolated["foo"] = ["key1", "key1"]
ordered_columns_interpolated["bar_start"] = [10, 20]
ordered_columns_interpolated["bar_end"] = [20, 30]
self.lookup_tables["ordered_columns_categorical"] = self.build_lookup_table(
builder,
ordered_columns_categorical,
value_columns,
)
self.lookup_tables["ordered_columns_interpolated"] = self.build_lookup_table(
builder,
ordered_columns_interpolated,
value_columns,
)


class ColumnRequirer(Component):
Expand Down

0 comments on commit e8d49c1

Please # to comment.