fix: node_sample_weight array dtype with xgboost explainer #52
Resolves #48.
Changes
Please see individual commits in the PR for the changelog.
- Change the `node_sample_weight` numpy array to double type. This makes it compatible with the type declared in the cext (`_cext.cc`). Previously, the array was declared as float32 in Python and as double / float64 in the C extension (the `_cext_dense_tree_update_weights` C function). When the dtypes don't match, the C extension cannot properly pass the updated weights array back to Python.
- Update the `test_provided_background_tree_path_dependent` test to use the `adult()` dataset instead of `iris()`, and under-specify the model by drastically lowering the xgboost depth and `num_boost_rounds`. Otherwise we end up with too many leaves/terminal nodes in the model, and some nodes will have 0 weight.

Some history
This problem first showed up as the `test_provided_background_tree_path_dependent` test failing. slundberg tried to resolve it in PR slundberg#2781, but that didn't work.
The actual root cause is what is presented in #48: no matter what background dataset was provided, the `node_sample_weight` array would always end up as an array of 0's, thus failing the test.
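
For reviewers who want to see the mechanism in isolation, here is a minimal NumPy-only sketch of how a dtype mismatch can leave the caller's array untouched. It does not call the real extension: `update_weights_in_place` is a hypothetical stand-in, and it assumes the C side converts its input to a float64 buffer in roughly the way `np.asarray(..., dtype=np.float64)` does, which copies when the dtype differs.

```python
import numpy as np


def update_weights_in_place(weights):
    """Hypothetical stand-in for a C routine that expects a float64 buffer.

    Returns True if the caller's array is the buffer that was written to.
    """
    # np.asarray returns the same object when the dtype already matches,
    # and a converted copy when it does not.
    buf = np.asarray(weights, dtype=np.float64)
    buf += 1.0  # write the "updated" weights into the buffer
    return buf is weights


w32 = np.zeros(4, dtype=np.float32)  # dtype declared before this PR
w64 = np.zeros(4, dtype=np.float64)  # dtype declared after this PR

shared32 = update_weights_in_place(w32)
shared64 = update_weights_in_place(w64)

print(shared32, w32)  # the float32 array is not shared and stays all zeros
print(shared64, w64)  # the float64 array is shared and receives the update
```

With the float32 array the writes land in a temporary copy, so the caller still sees zeros, which matches the symptom reported in #48. With float64 the buffer is shared and the update is visible in Python.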