Skip to content

Commit

Permalink
fix: node_sample_weight array dtype with xgboost explainer (#52)
Browse files Browse the repository at this point in the history
* fix: simplify test

* style: whitespaces and EOL

* docs: accepted option name is log_loss now

* fix: node_sample_weights needs to be double

The dtype here needs to match with the dtype declared in the _cext.cc file, which is
NPY_DOUBLE.

* fix: use adult dataset instead

because we need a binary label. Iris dataset is not binary
  • Loading branch information
thatlittleboy authored May 22, 2023
1 parent a4e2e9c commit f72ff6e
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 106 deletions.
24 changes: 12 additions & 12 deletions shap/cext/_cext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args)
PyObject *children_right_obj;
PyObject *node_sample_weight_obj;
PyObject *values_obj;

/* Parse the input tuple */
if (!PyArg_ParseTuple(
args, "OOOO", &children_left_obj, &children_right_obj, &node_sample_weight_obj, &values_obj
Expand Down Expand Up @@ -131,7 +131,7 @@ static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
int model_output;
PyObject *base_offset_obj;
bool interactions;

/* Parse the input tuple */
if (!PyArg_ParseTuple(
args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
Expand Down Expand Up @@ -222,7 +222,7 @@ static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
// retrieve return value before python cleanup of objects
tfloat ret_value = (double)values[0];

// clean up the created python objects
// clean up the created python objects
Py_XDECREF(children_left_array);
Py_XDECREF(children_right_array);
Py_XDECREF(children_default_array);
Expand Down Expand Up @@ -261,7 +261,7 @@ static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args)
PyObject *X_missing_obj;
PyObject *y_obj;
PyObject *out_pred_obj;

/* Parse the input tuple */
if (!PyArg_ParseTuple(
args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
Expand Down Expand Up @@ -339,7 +339,7 @@ static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args)

dense_tree_predict(out_pred, trees, data, model_output);

// clean up the created python objects
// clean up the created python objects
Py_XDECREF(children_left_array);
Py_XDECREF(children_right_array);
Py_XDECREF(children_default_array);
Expand Down Expand Up @@ -371,7 +371,7 @@ static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args)
PyObject *node_sample_weight_obj;
PyObject *X_obj;
PyObject *X_missing_obj;

/* Parse the input tuple */
if (!PyArg_ParseTuple(
args, "OOOOOOiOOO", &children_left_obj, &children_right_obj, &children_default_obj,
Expand Down Expand Up @@ -433,14 +433,14 @@ static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args)

dense_tree_update_weights(trees, data);

// clean up the created python objects
// clean up the created python objects
Py_XDECREF(children_left_array);
Py_XDECREF(children_right_array);
Py_XDECREF(children_default_array);
Py_XDECREF(features_array);
Py_XDECREF(thresholds_array);
Py_XDECREF(values_array);
//PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
// PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
Py_XDECREF(node_sample_weight_array);
Py_XDECREF(X_array);
Py_XDECREF(X_missing_array);
Expand All @@ -467,8 +467,8 @@ static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)
PyObject *X_missing_obj;
PyObject *y_obj;
PyObject *out_pred_obj;


/* Parse the input tuple */
if (!PyArg_ParseTuple(
args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
Expand Down Expand Up @@ -540,7 +540,7 @@ static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)

dense_tree_saabas(out_pred, trees, data);

// clean up the created python objects
// clean up the created python objects
Py_XDECREF(children_left_array);
Py_XDECREF(children_right_array);
Py_XDECREF(children_default_array);
Expand All @@ -557,4 +557,4 @@ static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)
/* Build the output tuple */
PyObject *ret = Py_BuildValue("d", (double)values[0]);
return ret;
}
}
Loading

0 comments on commit f72ff6e

Please # to comment.