diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 2fff1980f..1d3b3977f 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -92,6 +92,13 @@ def add_stratification( raise ValueError( f"Stratification name '{name}' is already used: {str(already_used[0])}." ) + unique_categories = set(categories) + if len(categories) != len(unique_categories): + for category in unique_categories: + categories.remove(category) + raise ValueError( + f"Found duplicate categories in stratification '{name}': {categories}." + ) stratification = Stratification(name, sources, categories, mapper, is_vectorized) self.stratifications.append(stratification) diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index dbbcae2cd..376c05f70 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -1,5 +1,6 @@ import itertools import math +import re from datetime import timedelta import pandas as pd @@ -24,6 +25,11 @@ ) +def _aggregate_state_person_time(x: pd.DataFrame) -> float: + """Helper aggregator function for observation testing""" + return len(x) * (28 / 365.25) + + @pytest.mark.parametrize( "name, sources, categories, mapper, is_vectorized", [ @@ -88,9 +94,24 @@ def test_add_stratifcation_duplicate_name_raises(): ctx.add_stratification(NAME, [], [], None, False) -def _aggregate_state_person_time(x: pd.DataFrame) -> float: - """Helper aggregator function for observation testing""" - return len(x) * (28 / 365.25) +@pytest.mark.parametrize( + "duplicates", + [ + ["slytherin"], + ["gryffindor", "slytherin"], + ], +) +def test_add_stratification_duplicate_category_raises(duplicates): + ctx = ResultsContext() + with pytest.raises( + ValueError, + match=re.escape( + f"Found duplicate categories in stratification '{NAME}': {duplicates}" + ), + ): + ctx.add_stratification( + NAME, SOURCES, CATEGORIES + duplicates, sorting_hat_vector, True + ) @pytest.mark.parametrize(