Scripts for easily comparing different aspects of the imodels package. Contains code to reproduce FIGS + Hierarchical shrinkage + G-FIGS + MDI+.
Follow these steps to benchmark a new (supervised) model.
- Write the sklearn-compliant model (`init`, `fit`, `predict`, and `predict_proba` for classifiers) and add it somewhere in a local folder or in `imodels` (a minimal sketch appears after this list)
- Update configs: create a new folder mimicking an existing folder (e.g. `config.interactions`)
  - Select which datasets you want by modifying `datasets.py` (datasets will be downloaded locally automatically later)
  - Select which models you want by editing a list similar to `models.py`
- Run `01_fit_models.py`, passing the appropriate command-line args (e.g. model, dataset, config)
  - Example command: `python 01_fit_models.py --config interactions --classification_or_regression classification --split_seed 0`
  - Another example: `python 01_fit_models.py --config interactions --classification_or_regression classification --model randomforest --split_seed 0`
  - Running everything: loop over `split_seed` + `classification_or_regression` (a Python sketch of this sweep appears after the note below)
  - Alternatively, to parallelize over a slurm cluster, run `01_submit_fitting.py` with the appropriate loops
- Run `02_aggregate_results.py` (which just combines the output of `01_run_comparisons.py` into a `combined.pkl` file across datasets)
- For plotting, put scripts/notebooks into a subdirectory of the `notebooks` folder (e.g. `notebooks/interactions`)
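As referenced in step 1, here is a minimal sketch of a sklearn-compliant classifier; the class and its logic are illustrative placeholders, not part of `imodels`:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class MajorityClassifier(BaseEstimator, ClassifierMixin):
    """Toy model that always predicts the majority class of the training labels."""

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.classes_, counts = np.unique(y, return_counts=True)
        self.majority_ = self.classes_[np.argmax(counts)]
        return self

    def predict_proba(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # put all probability mass on the majority class
        proba = np.zeros((X.shape[0], len(self.classes_)))
        proba[:, np.argmax(self.classes_ == self.majority_)] = 1.0
        return proba

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
```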
Note: If you want to benchmark feature importances, go to `feature_importance/`. For benchmarking other tasks such as unsupervised learning, you will have to make more substantial changes (mostly in `01_fit_models.py`).
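As referenced in step 3, a minimal Python sketch of the full sweep (the seed range and config name here are just examples):

```python
import itertools
import subprocess

# sweep every split seed and both task types; 3 seeds is an arbitrary example
for seed, task in itertools.product(range(3), ['classification', 'regression']):
    subprocess.run(['python', '01_fit_models.py',
                    '--config', 'interactions',
                    '--classification_or_regression', task,
                    '--split_seed', str(seed)], check=True)
```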
- When running multiple seeds, we want to aggregate over all keys that are not the `split_seed`
  - If a hyperparameter is not passed in `ModelConfig` (e.g. because we are using `partial`), it cannot be aggregated over seeds later on
    - The `extra_aggregate_keys={'max_leaf_nodes': n}` argument is a workaround for this (see configs with `partial` to understand how it works; a hedged sketch follows below)
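A hedged sketch of the workaround; the exact `ModelConfig` signature lives in this repo's `util.py`, so the arguments below are assumptions to illustrate the pattern:

```python
from functools import partial

from sklearn.tree import DecisionTreeClassifier

from util import ModelConfig  # assumed import path; see util.py for the exact signature

# max_leaf_nodes is frozen inside partial, so ModelConfig never sees it as a
# hyperparameter; extra_aggregate_keys records it explicitly so results can
# still be grouped on it when aggregating over split_seed
ESTIMATORS_CLASSIFICATION = [
    ModelConfig('CART', partial(DecisionTreeClassifier, max_leaf_nodes=n),
                extra_aggregate_keys={'max_leaf_nodes': n})
    for n in [2, 4, 8, 16, 32]
]
```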
Tests are run via `pytest`.
- Stable rules - finding a stable set of rules across different models
Fast Interpretable Greedy-Tree Sums (FIGS) is an algorithm for fitting concise rule-based models. Specifically, FIGS generalizes CART to simultaneously grow a flexible number of trees in a summation. The total number of splits across all the trees can be restricted by a pre-specified threshold, keeping the model interpretable. Experiments across a wide array of real-world datasets show that FIGS achieves state-of-the-art prediction performance when restricted to just a few splits (e.g. fewer than 20).
Example FIGS model. FIGS learns a flexible number of trees; to make its prediction, it sums the result from each tree.
📄 Paper (ICML 2022), 🔗 Post, 📌 Citation
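A quick usage example with the `imodels` API (the dataset is just an example):

```python
from imodels import FIGSClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = FIGSClassifier(max_rules=10)  # cap the total number of splits across all trees
model.fit(X_train, y_train)
print(model)                      # displays the learned trees
print(model.predict(X_test)[:5])
```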
Hierarchical shrinkage is an extremely fast post-hoc regularization method which works on any decision tree (or tree-based ensemble, such as Random Forest). It does not modify the tree structure, and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors (using a single regularization parameter). Experiments over a wide variety of datasets show that hierarchical shrinkage substantially increases the predictive performance of individual decision trees and decision-tree ensembles.
HS Example. HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent.
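A quick usage example with the `imodels` API, wrapping a plain CART tree (the CV variant picks the regularization parameter by cross-validation):

```python
from imodels import HSTreeClassifierCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# wrap any decision tree (or tree ensemble); the tree structure is unchanged,
# only the node-level predictions are shrunk
model = HSTreeClassifierCV(estimator_=DecisionTreeClassifier(max_leaf_nodes=20))
model.fit(X_train, y_train)
print(model.predict_proba(X_test)[:3])
```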
Machine learning in high-stakes domains, such as healthcare, faces two critical challenges: (1) generalizing to diverse data distributions given limited training data while (2) maintaining interpretability. To address these challenges, G-FIGS effectively pools data across diverse groups to output a concise, rule-based model. Given distinct groups of instances in a dataset (e.g., medical patients grouped by age or treatment site), G-FIGS first estimates group membership probabilities for each instance. Then, it uses these estimates as instance weights in FIGS (Tan et al. 2022) to grow a set of decision trees whose values sum to the final prediction. G-FIGS achieves state-of-the-art prediction performance on important clinical datasets; e.g., holding the level of sensitivity fixed at 92%, G-FIGS increases specificity for identifying cervical spine injury by up to 10% over CART and up to 3% over FIGS alone, with larger gains at higher sensitivity levels. By keeping the total number of rules below 16 in FIGS, the final models remain interpretable, and we find that their rules match medical domain expertise. All code, data, and models are released on GitHub.
G-FIGS 2-step process explained.
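A hedged sketch of the two-step idea, assuming `FIGSClassifier.fit` accepts a `sample_weight` argument, per the sklearn convention (the data and group column are placeholders; see the G-FIGS code in this repo for the real pipeline):

```python
import numpy as np
from imodels import FIGSClassifier
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))              # placeholder features
y = (X[:, 0] + X[:, 1] > 0).astype(int)    # placeholder labels
group = (X[:, 2] > 0).astype(int)          # placeholder group labels (e.g. site)

# step 1: estimate each instance's probability of belonging to the target group
membership = LogisticRegression().fit(X, group).predict_proba(X)[:, 1]

# step 2: use those probabilities as instance weights when fitting FIGS
# (assumes fit takes sample_weight, which is not confirmed here)
model = FIGSClassifier(max_rules=16)
model.fit(X, y, sample_weight=membership)
```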
MDI+ is a novel feature importance framework, which generalizes the popular mean decrease in impurity (MDI) importance score for random forests. At its core, MDI+ expands upon a recently discovered connection between linear regression and decision trees. In doing so, MDI+ enables practitioners to (1) tailor the feature importance computation to the data/problem structure and (2) incorporate additional features or knowledge to mitigate known biases of decision trees. In both real data case studies and extensive real-data-inspired simulations, MDI+ outperforms commonly used feature importance measures (e.g., MDI, permutation-based scores, and TreeSHAP) by substantial margins.
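For orientation, a minimal example of the standard MDI scores that MDI+ generalizes (plain scikit-learn; the MDI+ implementation itself lives under `feature_importance/`):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=10,
                           n_informative=3, random_state=0)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
print(rf.feature_importances_)  # mean decrease in impurity (MDI) per feature
```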