Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python-client/giskard/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ def __init__(self, flavor: str = None, functionality: str = None, msg: str = Non
f"with {self.functionality or self.flavor} support "
f"with `pip install giskard[{self.flavor}]`."
)


class GiskardImportError(ImportError):
def __init__(self, missing_package: str) -> None:
self.msg = f"The '{missing_package}' Python package is not installed; please execute 'pip install {missing_package}' to obtain it."
4 changes: 3 additions & 1 deletion python-client/giskard/integrations/wandb/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import pandas as pd
from shap import Explanation

from giskard.core.errors import GiskardImportError

try:
import wandb # noqa
from wandb.wandb_run import Run
except ImportError as e:
raise ImportError("The 'wandb' python package is not installed. To get it, run 'pip install wandb'.") from e
raise GiskardImportError("wandb") from e


@contextlib.contextmanager
Expand Down
27 changes: 23 additions & 4 deletions python-client/giskard/models/model_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import numpy as np
import pandas as pd
from shap.maskers import Text
from shap import KernelExplainer, Explanation, Explainer

from giskard.datasets.base import Dataset
from giskard.models.base import BaseModel
from giskard.models.shap_result import ShapResult
from giskard.ml_worker.utils.logging import timer
from giskard.core.errors import GiskardImportError

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -123,6 +121,11 @@ def _calculate_dataset_shap_values(model: BaseModel, dataset: Dataset) -> np.nda
shap_values : np.ndarray
The model's SHAP values as a numpy array
"""
try:
from shap import KernelExplainer
except ImportError as e:
raise GiskardImportError("shap") from e

# Prepare background sample to be used in the KernelSHAP.
background_df = model.prepare_dataframe(dataset.df, dataset.column_dtypes, dataset.target)
background_sample = _get_background_example(background_df, dataset.column_types)
Expand Down Expand Up @@ -163,7 +166,7 @@ def _get_highest_proba_shap(shap_values: np.ndarray, model: BaseModel, dataset:
return [shap_values[predicted_class][sample_idx] for sample_idx, predicted_class in enumerate(predictions)]


def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bool = True) -> ShapResult:
def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bool = True):
"""Explain the model with SHAP and return the results as a `ShapResult` object.

Parameters
Expand All @@ -183,6 +186,12 @@ def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bo
ShapResult
The model's SHAP values.
"""

try:
from shap import Explanation
from giskard.models.shap_result import ShapResult
except ImportError as e:
raise GiskardImportError("shap") from e
shap_values = _calculate_dataset_shap_values(model, dataset)
if only_highest_proba and model.is_classification:
shap_values = _get_highest_proba_shap(shap_values, model, dataset)
Expand All @@ -196,6 +205,11 @@ def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bo


def _calculate_sample_shap_values(model: BaseModel, dataset: Dataset, input_data: Dict) -> np.ndarray:
try:
from shap import KernelExplainer
except ImportError as e:
raise GiskardImportError("shap") from e

df = model.prepare_dataframe(dataset.df, column_dtypes=dataset.column_dtypes, target=dataset.target)
data_to_explain = _prepare_for_explanation(pd.DataFrame([input_data]), model=model, dataset=dataset)

Expand Down Expand Up @@ -229,6 +243,11 @@ def explain(model: BaseModel, dataset: Dataset, input_data: Dict):

@timer()
def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, text_document: str):
try:
from shap.maskers import Text
from shap import Explainer
except ImportError as e:
raise GiskardImportError("shap") from e
try:
masker = Text(tokenizer=r"\W+")
text_explainer = Explainer(text_explanation_prediction_wrapper(model.predict_df, input_df, text_column), masker)
Expand Down
8 changes: 6 additions & 2 deletions python-client/giskard/models/shap_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from dataclasses import dataclass
from typing import Optional

from shap import Explanation

from giskard.core.core import ModelType, SupportedModelTypes
from ..utils.analytics_collector import analytics
from giskard.client.python_utils import warning
from giskard.core.errors import GiskardImportError


class PanelNames(str, Enum):
Expand Down Expand Up @@ -39,6 +38,11 @@ class ShapResult:
A flag indicating whether to provide SHAP explanations only for the predictions with the highest probability or not.
"""

try:
from shap import Explanation
except ImportError as e:
raise GiskardImportError("shap") from e

explanations: Explanation
feature_types: Optional[dict] = None
model_type: Optional[ModelType] = None
Expand Down