diff --git a/alibi/explainers/pd_variance.py b/alibi/explainers/pd_variance.py index 28b963a18..ef1405ca6 100644 --- a/alibi/explainers/pd_variance.py +++ b/alibi/explainers/pd_variance.py @@ -1,20 +1,21 @@ import copy import logging -import sys import math import numbers -import numpy as np -import matplotlib.pyplot as plt -from typing import Callable, List, Optional, Dict, Union, Tuple, Any -from itertools import combinations +import sys from enum import Enum -from alibi.api.defaults import DEFAULT_META_PDVARIANCE, DEFAULT_DATA_PDVARIANCE +from itertools import combinations +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from alibi.api.defaults import DEFAULT_DATA_PDVARIANCE, DEFAULT_META_PDVARIANCE from alibi.api.interfaces import Explainer, Explanation -from alibi.explainers.partial_dependence import Kind, PartialDependence, TreePartialDependence +from alibi.explainers.partial_dependence import (Kind, PartialDependence, + TreePartialDependence) from alibi.explainers.similarity.grad import get_options_string from sklearn.base import BaseEstimator - logger = logging.getLogger(__name__) if sys.version_info >= (3, 8): @@ -360,7 +361,7 @@ def _build_explanation(self, buffers: dict) -> Explanation: feature_values=buffers['feature_values'], feature_names=buffers['feature_names']) - if self.meta['params']['method'] == 'importance': + if self.meta['params']['method'] == Method.IMPORTANCE: data.update(feature_importance=buffers['feature_importance']) else: data.update(feature_interaction=buffers['feature_interaction'])