Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Make function _get_surrogate_model_replication_measure() public #495

Merged
merged 2 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/interpret_community/mimic/mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,10 @@ def __setstate__(self, state):

def _get_surrogate_model_replication_measure(self, training_data):
"""Return the metric which tells how well the surrogate model replicates the teacher model.

For classification scenarios, this function will return accuracy. For regression scenarios,
this function will return r2_score.

:param training_data: The data for getting the replication metric.
:type training_data: numpy.array or pandas.DataFrame or scipy.sparse.csr_matrix
:return: Metric that tells how well the surrogate model replicates the behavior of teacher model.
Expand All @@ -769,3 +773,16 @@ def _get_surrogate_model_replication_measure(self, training_data):
else:
replication_measure = r2_score(teacher_model_predictions, surrogate_model_predictions)
return replication_measure

def get_surrogate_model_replication_measure(self, training_data):
"""Return the metric which tells how well the surrogate model replicates the teacher model.

For classification scenarios, this function will return accuracy. For regression scenarios,
this function will return r2_score.

:param training_data: The data for getting the replication metric.
:type training_data: numpy.array or pandas.DataFrame or scipy.sparse.csr_matrix
:return: Metric that tells how well the surrogate model replicates the behavior of teacher model.
:rtype: float
"""
return self._get_surrogate_model_replication_measure(training_data=training_data)
9 changes: 6 additions & 3 deletions tests/test_mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_explain_raw_feats_regression(self, mimic_explainer):
def _verify_predictions_and_replication_metric(self, mimic_explainer, data):
predictions_main_model = mimic_explainer._get_teacher_model_predictions(data)
predictions_surrogate_model = mimic_explainer._get_surrogate_model_predictions(data)
replication_score = mimic_explainer._get_surrogate_model_replication_measure(data)
replication_score = mimic_explainer.get_surrogate_model_replication_measure(data)

assert predictions_main_model is not None
assert predictions_surrogate_model is not None
Expand All @@ -422,8 +422,11 @@ def _verify_predictions_and_replication_metric(self, mimic_explainer, data):
assert replication_score is not None and isinstance(replication_score, float)

if mimic_explainer.classes is None:
with pytest.raises(ScenarioNotSupportedException):
mimic_explainer._get_surrogate_model_replication_measure(
with pytest.raises(
ScenarioNotSupportedException,
match="Replication measure for regression surrogate not supported "
"because of single instance in training data"):
mimic_explainer.get_surrogate_model_replication_measure(
data[0].reshape(1, len(data[0])))

def test_explain_model_string_classes(self, mimic_explainer):
Expand Down