diff --git a/python-client/giskard/core/errors.py b/python-client/giskard/core/errors.py
index 44c526f1c4..8782c4a7d2 100644
--- a/python-client/giskard/core/errors.py
+++ b/python-client/giskard/core/errors.py
@@ -19,3 +19,22 @@ def __init__(self, flavor: str = None, functionality: str = None, msg: str = Non
                 f"with {self.functionality or self.flavor} support "
                 f"with `pip install giskard[{self.flavor}]`."
             )
+
+
+class GiskardImportError(ImportError):
+    """Raised when an optional dependency needed by a Giskard feature is not installed.
+
+    Parameters
+    ----------
+    missing_package : str
+        Name of the package that failed to import; it is echoed in the install hint.
+    """
+
+    def __init__(self, missing_package: str) -> None:
+        # Go through ImportError.__init__ so that `args`, `msg` and `name` are all populated;
+        # assigning `self.msg` alone leaves `args` empty and `repr()` uninformative.
+        super().__init__(
+            f"The '{missing_package}' Python package is not installed; "
+            f"please execute 'pip install {missing_package}' to obtain it.",
+            name=missing_package,
+        )
diff --git a/python-client/giskard/integrations/wandb/wandb_utils.py b/python-client/giskard/integrations/wandb/wandb_utils.py
index 8e1fe773db..c72f381833 100644
--- a/python-client/giskard/integrations/wandb/wandb_utils.py
+++ b/python-client/giskard/integrations/wandb/wandb_utils.py
@@ -5,11 +5,13 @@
 import pandas as pd
 from shap import Explanation
 
+from giskard.core.errors import GiskardImportError
+
 try:
     import wandb  # noqa
     from wandb.wandb_run import Run
 except ImportError as e:
-    raise ImportError("The 'wandb' python package is not installed. To get it, run 'pip install wandb'.") from e
+    raise GiskardImportError("wandb") from e
 
 
 @contextlib.contextmanager
diff --git a/python-client/giskard/models/model_explanation.py b/python-client/giskard/models/model_explanation.py
index 59232c67f3..427836fb93 100644
--- a/python-client/giskard/models/model_explanation.py
+++ b/python-client/giskard/models/model_explanation.py
@@ -4,13 +4,11 @@
 
 import numpy as np
 import pandas as pd
-from shap.maskers import Text
-from shap import KernelExplainer, Explanation, Explainer
 
 from giskard.datasets.base import Dataset
 from giskard.models.base import BaseModel
-from giskard.models.shap_result import ShapResult
 from giskard.ml_worker.utils.logging import timer
+from giskard.core.errors import GiskardImportError
 
 warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 logger = logging.getLogger(__name__)
@@ -123,6 +121,11 @@ def _calculate_dataset_shap_values(model: BaseModel, dataset: Dataset) -> np.nda
     shap_values : np.ndarray
         The model's SHAP values as a numpy array
     """
+    try:
+        from shap import KernelExplainer
+    except ImportError as e:
+        raise GiskardImportError("shap") from e
+
     # Prepare background sample to be used in the KernelSHAP.
     background_df = model.prepare_dataframe(dataset.df, dataset.column_dtypes, dataset.target)
     background_sample = _get_background_example(background_df, dataset.column_types)
@@ -163,7 +166,7 @@ def _get_highest_proba_shap(shap_values: np.ndarray, model: BaseModel, dataset:
     return [shap_values[predicted_class][sample_idx] for sample_idx, predicted_class in enumerate(predictions)]
 
 
-def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bool = True) -> ShapResult:
+def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bool = True):
     """Explain the model with SHAP and return the results as a `ShapResult` object.
 
     Parameters
@@ -183,6 +186,12 @@ def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bo
     ShapResult
         The model's SHAP values.
     """
+
+    try:
+        from shap import Explanation
+        from giskard.models.shap_result import ShapResult
+    except ImportError as e:
+        raise GiskardImportError("shap") from e
     shap_values = _calculate_dataset_shap_values(model, dataset)
     if only_highest_proba and model.is_classification:
         shap_values = _get_highest_proba_shap(shap_values, model, dataset)
@@ -196,6 +205,11 @@ def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bo
 
 
 def _calculate_sample_shap_values(model: BaseModel, dataset: Dataset, input_data: Dict) -> np.ndarray:
+    try:
+        from shap import KernelExplainer
+    except ImportError as e:
+        raise GiskardImportError("shap") from e
+
     df = model.prepare_dataframe(dataset.df, column_dtypes=dataset.column_dtypes, target=dataset.target)
     data_to_explain = _prepare_for_explanation(pd.DataFrame([input_data]), model=model, dataset=dataset)
 
@@ -229,6 +243,11 @@ def explain(model: BaseModel, dataset: Dataset, input_data: Dict):
 
 @timer()
 def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, text_document: str):
+    try:
+        from shap.maskers import Text
+        from shap import Explainer
+    except ImportError as e:
+        raise GiskardImportError("shap") from e
     try:
         masker = Text(tokenizer=r"\W+")
         text_explainer = Explainer(text_explanation_prediction_wrapper(model.predict_df, input_df, text_column), masker)
diff --git a/python-client/giskard/models/shap_result.py b/python-client/giskard/models/shap_result.py
index 3737915116..60b502ee49 100644
--- a/python-client/giskard/models/shap_result.py
+++ b/python-client/giskard/models/shap_result.py
@@ -2,11 +2,16 @@
 from dataclasses import dataclass
 from typing import Optional
 
-from shap import Explanation
-
 from giskard.core.core import ModelType, SupportedModelTypes
 from ..utils.analytics_collector import analytics
 from giskard.client.python_utils import warning
+from giskard.core.errors import GiskardImportError
+
+# shap is an optional dependency: fail with an actionable install hint when it is missing.
+try:
+    from shap import Explanation
+except ImportError as e:
+    raise GiskardImportError("shap") from e
 
 
 class PanelNames(str, Enum):