From 939e7e276391b86e426156a718e6683b76ef3553 Mon Sep 17 00:00:00 2001 From: Olivier Boulant Date: Thu, 28 Apr 2022 10:27:52 +0200 Subject: [PATCH] fix: Allow Binseg to hit minsize bounds for segments (#249) * storing work to fix precommit and docs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test: compare to the explicit results rather than length * style: typos in comment Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charles T --- src/ruptures/detection/binseg.py | 2 +- tests/test_detection.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/ruptures/detection/binseg.py b/src/ruptures/detection/binseg.py index 80fae87d..a4d80a7e 100644 --- a/src/ruptures/detection/binseg.py +++ b/src/ruptures/detection/binseg.py @@ -90,7 +90,7 @@ def single_bkp(self, start, end): return None, 0 gain_list = list() for bkp in range(start, end, self.jump): - if bkp - start > self.min_size and end - bkp > self.min_size: + if bkp - start >= self.min_size and end - bkp >= self.min_size: gain = ( segment_cost - self.cost.error(start, bkp) diff --git a/tests/test_detection.py b/tests/test_detection.py index d8a18bbb..48e207c6 100644 --- a/tests/test_detection.py +++ b/tests/test_detection.py @@ -357,6 +357,17 @@ def test_model_small_signal(signal_bkps_5D_n10, algo, model): ) +@pytest.mark.parametrize( + "model", + ["l1", "l2", "ar", "normal", "rbf", "rank", "mahalanobis"], +) +def test_binseg_min_size(signal_bkps_5D_n10, model): + signal, _ = signal_bkps_5D_n10 + + c_bkps = Binseg(model=model, min_size=5, jump=1).fit_predict(signal, n_bkps=1) + assert all([a == b for a, b in zip(c_bkps, [5, 10])]) + + @pytest.mark.parametrize( "model", ["l1", "l2", "ar", "normal", "rbf", "rank", "mahalanobis"] ) @@ -368,6 +379,9 @@ def test_model_small_signal_dynp(signal_bkps_5D_n10, model): Dynp(model=model, min_size=9, jump=2).fit_predict(signal, 
2) with pytest.raises(BadSegmentationParameters): Dynp(model=model, min_size=11, jump=2).fit_predict(signal, 2) + # Test if it can find the single eligible break point compatible with min_size + c_bkps = Dynp(model=model, min_size=5, jump=1).fit_predict(signal, 1) + assert all([a == b for a, b in zip(c_bkps, [5, 10])]) @pytest.mark.parametrize(