Skip to content

Commit a61c596

Browse files
authored
Merge pull request #37 from adaa-polsl/correct-max-rule-count-in-survival
correct max rule count in survival problems
2 parents 6c2b25f + 421f56d commit a61c596

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

rulekit/survival.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class _SurvivalModelsParams(BaseModel):
4242
complementary_conditions: Optional[bool] = DEFAULT_PARAMS_VALUE[
4343
"complementary_conditions"
4444
]
45+
max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"]
4546

4647

4748
class _SurvivalExpertModelParams(_SurvivalModelsParams, ExpertModelParams):

tests/test_survival.py

+17
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,23 @@ def test_getting_training_dataset_kaplan_meier_estimator(self):
185185
"Estimator should contain probabilities for each unique time from the dataset",
186186
)
187187

188+
def test_max_rule_count(self):
189+
MAX_RULE_COUNT = 3
190+
df: pd.DataFrame = read_arff(
191+
os.path.join(dir_path, "resources", "data", "bmt-train-0.arff")
192+
)
193+
X, y = df.drop("survival_status", axis=1), df["survival_status"]
194+
clf = survival.SurvivalRules(
195+
survival_time_attr="survival_time",
196+
max_rule_count=MAX_RULE_COUNT,
197+
)
198+
clf.fit(X, y)
199+
self.assertLessEqual(
200+
len(clf.model.rules),
201+
MAX_RULE_COUNT,
202+
f"Ruleset should contain no more than {MAX_RULE_COUNT} rules according to max_rule_count parameter",
203+
)
204+
188205

189206
class TestExpertSurvivalRules(unittest.TestCase):
190207

0 commit comments

Comments
 (0)