Skip to content

Commit

Permalink
raise when stratification has duplicate categories (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 805d6df commit af85e8b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def add_stratification(
raise ValueError(
f"Stratification name '{name}' is already used: {str(already_used[0])}."
)
unique_categories = set(categories)
if len(categories) != len(unique_categories):
for category in unique_categories:
categories.remove(category)
raise ValueError(
f"Found duplicate categories in stratification '{name}': {categories}."
)
stratification = Stratification(name, sources, categories, mapper, is_vectorized)
self.stratifications.append(stratification)

Expand Down
27 changes: 24 additions & 3 deletions tests/framework/results/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import math
import re
from datetime import timedelta

import pandas as pd
Expand All @@ -24,6 +25,11 @@
)


def _aggregate_state_person_time(x: pd.DataFrame) -> float:
"""Helper aggregator function for observation testing"""
return len(x) * (28 / 365.25)


@pytest.mark.parametrize(
"name, sources, categories, mapper, is_vectorized",
[
Expand Down Expand Up @@ -88,9 +94,24 @@ def test_add_stratifcation_duplicate_name_raises():
ctx.add_stratification(NAME, [], [], None, False)


def _aggregate_state_person_time(x: pd.DataFrame) -> float:
"""Helper aggregator function for observation testing"""
return len(x) * (28 / 365.25)
@pytest.mark.parametrize(
"duplicates",
[
["slytherin"],
["gryffindor", "slytherin"],
],
)
def test_add_stratification_duplicate_category_raises(duplicates):
ctx = ResultsContext()
with pytest.raises(
ValueError,
match=re.escape(
f"Found duplicate categories in stratification '{NAME}': {duplicates}"
),
):
ctx.add_stratification(
NAME, SOURCES, CATEGORIES + duplicates, sorting_hat_vector, True
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit af85e8b

Please # to comment.