Skip to content

Commit

Permalink
test: test extracted code
Browse files Browse the repository at this point in the history
  • Loading branch information
baggiponte committed Jun 9, 2024
1 parent 27de89f commit c45a965
Showing 1 changed file with 71 additions and 56 deletions.
127 changes: 71 additions & 56 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,77 @@
import polars as pl
import pytest

from functime import plotting
from functime._utils_plotting import (
get_chosen_entities,
get_num_points,
get_num_rows,
get_num_series,
get_subplot_grid_position,
)


@pytest.fixture
def mock_dataframe():
data = {
"entity": ["A", "A", "B", "B", "C", "C"],
"time": [1, 2, 1, 2, 1, 2],
"value": [10, 20, 30, 40, 50, 60],
}
return pl.LazyFrame(data)


@pytest.fixture
def mock_entity_counts(mock_dataframe):
return mock_dataframe.select(pl.col("entity").value_counts())


def test_set_subplot_default_kwargs_no_existing_kwargs():
kwargs = {}
updated_kwargs = plotting._set_subplot_default_kwargs(kwargs, 2, 3)
def test_get_chosen_entities_not_random(mock_dataframe):
actual = get_chosen_entities(y=mock_dataframe, n_series=3, seed=None).to_list()
expected = ["A", "B", "C"]
assert actual == expected

assert updated_kwargs["width"] == 250 * 3 + 100 # default width * cols + space
assert updated_kwargs["height"] == 200 * 2 + 100 # default height * rows + space
assert updated_kwargs["template"] == "plotly_white"

@pytest.mark.parametrize("n_series, seed", [(3, 42), (2, 42)])
def test_get_chosen_entities_random(mock_dataframe, n_series, seed):
expected = (
mock_dataframe.select(pl.col("entity").unique(maintain_order=True))
.collect()
.sample(n_series, seed=seed)
.to_series()
.to_list()
)

actual = get_chosen_entities(
y=mock_dataframe,
n_series=n_series,
seed=seed,
).to_list()

assert actual == expected


@pytest.mark.parametrize("num_points, expected", [(None, 2), (1, 1), (2, 2)])
def test_get_num_points(mock_entity_counts, num_points, expected):
actual = get_num_points(entities_counts=mock_entity_counts, num_points=num_points)
assert actual == expected

def test_set_subplot_default_kwargs_with_one_defined_kwarg():
kwargs = {"width": 800, "some_other_kwarg": "value"}
updated_kwargs = plotting._set_subplot_default_kwargs(kwargs, 2, 3)

assert updated_kwargs["width"] == 800 # Should remain unchanged
assert updated_kwargs["height"] == 200 * 2 + 100 # default height * rows + space
assert updated_kwargs["some_other_kwarg"] == "value"
@pytest.mark.parametrize("num_points", [0, 3, -1])
def test_get_num_points_raises_value_error(mock_entity_counts, num_points):
with pytest.raises(ValueError):
get_num_points(entities_counts=mock_entity_counts, num_points=num_points)


@pytest.mark.parametrize("num_series, expected", [(None, 3), (1, 1)])
def test_get_num_series(mock_entity_counts, num_series, expected):
actual = get_num_series(entities_counts=mock_entity_counts, num_series=num_series)
assert actual == expected


@pytest.mark.parametrize("num_series", [0, -1, 4])
def test_get_num_series_raises_value_error(mock_entity_counts, num_series):
with pytest.raises(ValueError):
get_num_series(entities_counts=mock_entity_counts, num_series=num_series)


@pytest.mark.parametrize(
Expand All @@ -34,54 +86,17 @@ def test_set_subplot_default_kwargs_with_one_defined_kwarg():
(10, 15, 1), # More columns than series
],
)
def test_calculate_subplot_n_rows(n_series, n_cols, expected_rows):
assert plotting._calculate_subplot_n_rows(n_series, n_cols) == expected_rows
def test_get_num_rows(n_series, n_cols, expected_rows):
assert get_num_rows(num_series=n_series, num_cols=n_cols) == expected_rows


@pytest.mark.parametrize(
"n_series, n_cols",
[
(0, 2), # No series
(10, 0), # Zero columns
(-1, 2), # Negative series
(10, -2), # Negative columns
],
[(0, 1), (1, 0), (0, 0)],
)
def test_calculate_subplot_n_rows_errors(n_series, n_cols):
def test_get_num_rows_raises_value_error(n_series, n_cols):
with pytest.raises(ValueError):
plotting._calculate_subplot_n_rows(n_series, n_cols)


def create_mock_dataframe():
# Create a mock DataFrame for testing
data = {
"entity": ["A", "A", "B", "B", "C", "C"],
"time": [1, 2, 1, 2, 1, 2],
"value": [10, 20, 30, 40, 50, 60],
}
return pl.DataFrame(data)


@pytest.mark.parametrize(
"n_series, last_n, expected_entities",
[
(2, 1, {"A", "B"}), # Test with 2 series, last 1 record
(3, 2, {"A", "B", "C"}), # Test with all series, last 2 records
(4, 2, {"A", "B", "C"}), # More series than available
],
)
def test_prepare_data_for_subplots(n_series, last_n, expected_entities):
df = create_mock_dataframe()
entities_sample, _, y_filtered = plotting._prepare_data_for_subplots(
df, n_series, last_n, seed=1
)

# Check if the correct entities are sampled
assert set(entities_sample) == expected_entities

# Check if the data is correctly filtered
for entity in entities_sample:
assert y_filtered.filter(pl.col("entity") == entity).height <= last_n
get_num_rows(num_series=n_series, num_cols=n_cols)


@pytest.mark.parametrize(
Expand All @@ -98,4 +113,4 @@ def test_prepare_data_for_subplots(n_series, last_n, expected_entities):
],
)
def test_get_subplot_grid_position(i, n_cols, expected_row_col):
assert plotting._get_subplot_grid_position(i, n_cols) == expected_row_col
assert get_subplot_grid_position(element=i, num_cols=n_cols) == expected_row_col

0 comments on commit c45a965

Please # to comment.