diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2517c881d..d550adba7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,7 @@ -**3.0.1- 08/19/24** +**3.0.1- 08/20/24** - Create script to find matching dependency branches + - Add results category exclusion tests **3.0.0 - 08/13/24** diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 34cb26cf3..794d957a6 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -6,6 +6,8 @@ import numpy as np import pandas as pd import pytest +from layered_config_tree import LayeredConfigTree +from loguru import logger from pandas.core.groupby import DataFrameGroupBy from tests.framework.results.helpers import ( @@ -43,7 +45,7 @@ def mocked_event(mocker) -> Event: ], ids=["vectorized_mapper", "non-vectorized_mapper"], ) -def test_add_stratification(mapper, is_vectorized, mocker): +def test_add_stratification_mappers(mapper, is_vectorized, mocker): ctx = ResultsContext() mocker.patch.object(ctx, "excluded_categories", {}) assert not verify_stratification_added( @@ -62,6 +64,57 @@ def test_add_stratification(mapper, is_vectorized, mocker): ) +@pytest.mark.parametrize( + "excluded_categories", + [ + [], + HOUSE_CATEGORIES[:1], + HOUSE_CATEGORIES[:2], + HOUSE_CATEGORIES[:3], + ], + ids=[ + "no_excluded_categories", + "one_excluded_category", + "two_excluded_categories", + "all_but_one_excluded_categories", + ], +) +def test_add_stratification_excluded_categories(excluded_categories, mocker): + ctx = ResultsContext() + builder = mocker.Mock() + builder.configuration.stratification = LayeredConfigTree( + {"default": [], "excluded_categories": {NAME: excluded_categories}} + ) + builder.logging.get_logger.return_value = logger + ctx.setup(builder) + assert not verify_stratification_added( + ctx.stratifications, + NAME, + NAME_COLUMNS, + HOUSE_CATEGORIES, + [], + sorting_hat_vectorized, + True, + ) + ctx.add_stratification( + name=NAME, + sources=NAME_COLUMNS, + categories=HOUSE_CATEGORIES, + excluded_categories=excluded_categories, + mapper=sorting_hat_vectorized, + is_vectorized=True, + ) + assert verify_stratification_added( + ctx.stratifications, + NAME, + NAME_COLUMNS, + HOUSE_CATEGORIES, + excluded_categories, + sorting_hat_vectorized, + True, + ) + + @pytest.mark.parametrize( "name, categories, excluded_categories, msg_match", [