Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7789fa9
Initial commit with the implementation of the SHAP explanation graphs…
AbSsEnT Aug 8, 2023
e31798e
Changed logic of obtaining feature names and types.
AbSsEnT Aug 8, 2023
e9de16d
Removed redundant 'model.prepare_dataframe'. Small refactoring.
AbSsEnT Aug 8, 2023
4f92418
Added sorting of logged dataset, test suite result and scan result to…
AbSsEnT Aug 9, 2023
1a59240
Moved 'explain' function below shap-related functions.
AbSsEnT Aug 9, 2023
cc15925
Code refactoring.
AbSsEnT Aug 9, 2023
5fbb14c
Changed naming for variables inside functions.
AbSsEnT Aug 9, 2023
c1c075e
Removed explainer return, as it is not needed.
AbSsEnT Aug 9, 2023
9dbf847
Moved 'prepare_df' to the separate utils.py file to avoid code duplic…
AbSsEnT Aug 9, 2023
167c629
Added docstring to the '_get_cls_prediction_explanation'
AbSsEnT Aug 9, 2023
ff6ed6f
Created dataclass ShapResult to store shap explanations there and enc…
AbSsEnT Aug 10, 2023
6366064
Refactoring of the 'background_example' function.
AbSsEnT Aug 10, 2023
75a013f
Refactoring.
AbSsEnT Aug 10, 2023
4935cd8
Refactoring.
AbSsEnT Aug 10, 2023
8ad73b1
Changed enum class declaration.
AbSsEnT Aug 10, 2023
6bf77d7
Refactored model_explanation.py to be able to perform testing of expl…
AbSsEnT Aug 11, 2023
d2c4739
Small fix in comments.
AbSsEnT Aug 11, 2023
c4ee3c1
Uncommented fixture.
AbSsEnT Aug 16, 2023
8b11843
Refactored "_get_highest_prob_shap" function. Made it more compact an…
AbSsEnT Aug 16, 2023
660fd66
Removed #noqa options from the shap imports. Optimized imports.
AbSsEnT Aug 16, 2023
99f82b8
Refactored _prepare_for_explanation function. Changed naming of the f…
AbSsEnT Aug 16, 2023
0746ff3
Renamed explain_full(one) to "_calculate_dataset(sample)_shap_values"
AbSsEnT Aug 16, 2023
72bd4d4
Refactored _get_background_example function.
AbSsEnT Aug 16, 2023
0a5928f
Refactored 'explain_with_shap' function and 'ShapResult' dataclass fo…
AbSsEnT Aug 16, 2023
d80aae5
Fixed bugs with unit-tests for wandb.
AbSsEnT Aug 16, 2023
d98b943
Transferred '_compare_explain_functions' to the 'test_model_explanati…
AbSsEnT Aug 16, 2023
feffd84
Refactoring. Renaming and functions replacement.
AbSsEnT Aug 16, 2023
868c33f
Renaming.
AbSsEnT Aug 16, 2023
a2e9833
Merge branch 'GSK-1505-wandb' into GSK-1533-wandb-shap
rabah-khalek Aug 16, 2023
ce40cf3
Merge branch 'GSK-1505-wandb' into GSK-1533-wandb-shap
rabah-khalek Aug 16, 2023
309a64d
Transferred plotting functions from the shap_result.py to the wandb_u…
AbSsEnT Aug 16, 2023
428e919
Merge remote-tracking branch 'origin/GSK-1533-wandb-shap' into GSK-15…
AbSsEnT Aug 16, 2023
1b2302f
small update to error msg
rabah-khalek Aug 16, 2023
e0c20bb
updated unit test
rabah-khalek Aug 16, 2023
0083f30
small update
rabah-khalek Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python-client/giskard/core/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def to_wandb(self, **kwargs) -> None:
# Log just a test description and a metric.
columns = ["Metric name", "Data slice", "Metric value", "Passed"]
data = [[*_parse_test_name(result[0]), result[1].metric, result[1].passed] for result in self.results]
run.log({"Test-Suite Results": wandb.Table(columns=columns, data=data)})
run.log({"Test suite results/Test-Suite Results": wandb.Table(columns=columns, data=data)})


class SuiteInput:
Expand Down
2 changes: 1 addition & 1 deletion python-client/giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def to_wandb(self, **kwargs) -> None:
with wandb_run(**kwargs) as run:
import wandb # noqa library import already checked in wandb_run

run.log({"dataset": wandb.Table(dataframe=self.df)})
run.log({"Dataset/dataset": wandb.Table(dataframe=self.df)})


def _cast_to_list_like(object):
Expand Down
128 changes: 85 additions & 43 deletions python-client/giskard/models/model_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,103 @@

import numpy as np
import pandas as pd
from shap.maskers import Text
from shap import KernelExplainer, Explanation, Explainer

from giskard.datasets.base import Dataset
from giskard.ml_worker.utils.logging import timer
from giskard.models.base import BaseModel
from giskard.models.shap_result import ShapResult
from giskard.ml_worker.utils.logging import timer

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
import shap # noqa

logger = logging.getLogger(__name__)


@timer()
def explain(model: BaseModel, dataset: Dataset, input_data: Dict):
def prepare_df(df):
df = model.prepare_dataframe(df, column_dtypes=dataset.column_dtypes, target=dataset.target)
if dataset.target in df.columns:
prepared_ds = Dataset(df=df, target=dataset.target, column_types=dataset.column_types)
else:
prepared_ds = Dataset(df=df, column_types=dataset.column_types)
prepared_df = model.prepare_dataframe(
prepared_ds.df, column_dtypes=prepared_ds.column_dtypes, target=prepared_ds.target
)
columns_in_original_order = (
model.meta.feature_names
if model.meta.feature_names
else [c for c in dataset.df.columns if c in prepared_df.columns]
)
# Make sure column order is the same as in df
return prepared_df[columns_in_original_order]
def _get_highest_prob_shap(shap_values: list, model: BaseModel, dataset: Dataset) -> list:
    """Keep, for every sample, only the SHAP explanation of its predicted class.

    ``shap_values`` is indexed first by class and then by sample position; the
    class picked for each sample is the one reported in the model's
    ``raw_prediction`` for the given dataset.
    """
    predicted_classes = model.predict(dataset).raw_prediction

    selected_explanations = []
    for row_position, class_index in enumerate(predicted_classes):
        selected_explanations.append(shap_values[class_index][row_position])
    return selected_explanations


def _prepare_for_explanation(input_df: pd.DataFrame, model: BaseModel, dataset: Dataset) -> pd.DataFrame:
    """Return ``input_df`` prepared by the model, with columns in the expected order."""
    model_ready_df = model.prepare_dataframe(input_df, column_dtypes=dataset.column_dtypes, target=dataset.target)

    # The target is only declared when the prepared frame still carries it.
    if dataset.target in model_ready_df.columns:
        target = dataset.target
    else:
        target = None
    wrapped = Dataset(model_ready_df, column_types=dataset.column_types, target=target)

    # Prefer the feature order declared by the model; otherwise fall back to
    # the column order of the reference dataset.
    if model.meta.feature_names:
        ordered_columns = model.meta.feature_names
    else:
        ordered_columns = [column for column in dataset.df.columns if column in wrapped.df.columns]

    return wrapped.df[ordered_columns]


def _get_background_example(df: pd.DataFrame, feature_types: Dict[str, str]) -> pd.DataFrame:
"""Create background example for the SHAP explainer as a mode/median of features."""
median = df.median(numeric_only=True)
background_sample = df.mode(dropna=False).head(1)

# Use median of the numerical features.
numerical_features = [feature for feature in list(df.columns) if feature_types.get(feature) == "numeric"]
for feature in numerical_features:
background_sample[feature] = median[feature]

background_sample = background_sample.astype(df.dtypes)
return background_sample

df = model.prepare_dataframe(dataset.df, column_dtypes=dataset.column_dtypes, target=dataset.target)
feature_names = list(df.columns)

input_df = prepare_df(pd.DataFrame([input_data]))
def _calculate_dataset_shap_values(model: BaseModel, dataset: Dataset) -> np.ndarray:
    """Compute KernelSHAP values for every sample of ``dataset``."""
    # Background (reference) sample: mode/median of the model-prepared dataset.
    reference_sample = _get_background_example(
        model.prepare_dataframe(dataset.df, dataset.column_dtypes, dataset.target),
        dataset.column_types,
    )

    # Samples to be explained, prepared the same way as for inference.
    samples = _prepare_for_explanation(dataset.df, model=model, dataset=dataset)

    kernel_explainer = KernelExplainer(model.predict_df, reference_sample, samples.columns, keep_index=True)
    return kernel_explainer.shap_values(samples, silent=True)


def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_prob: bool = True) -> ShapResult:
    """Compute SHAP explanations of ``model`` on ``dataset`` and wrap them in a ``ShapResult``.

    When ``only_highest_prob`` is true and the model is a classifier, only the
    explanations of each sample's predicted class are kept.
    """
    shap_values = _calculate_dataset_shap_values(model, dataset)

    keep_predicted_class_only = only_highest_prob and model.is_classification
    if keep_predicted_class_only:
        shap_values = _get_highest_prob_shap(shap_values, model, dataset)

    # Feature names come from the model metadata when available, otherwise
    # from the dataset columns without the target.
    feature_names = model.meta.feature_names
    if not feature_names:
        feature_names = list(dataset.df.columns.drop(dataset.target, errors="ignore"))

    # Wrap the raw SHAP values into an Explanation object for convenient slicing.
    shap_explanations = Explanation(shap_values, data=dataset.df[feature_names], feature_names=feature_names)

    feature_types = {}
    for name in feature_names:
        feature_types[name] = dataset.column_types[name]

    return ShapResult(
        explanations=shap_explanations,
        feature_types=feature_types,
        feature_names=feature_names,
        model_type=model.meta.model_type,
        only_highest_prob=only_highest_prob,
    )


def _calculate_sample_shap_values(model: BaseModel, dataset: Dataset, input_data: Dict) -> np.ndarray:
df = model.prepare_dataframe(dataset.df, column_dtypes=dataset.column_dtypes, target=dataset.target)
data_to_explain = _prepare_for_explanation(pd.DataFrame([input_data]), model=model, dataset=dataset)

def predict_array(array):
arr_df = pd.DataFrame(array, columns=list(df.columns))
return model.predict_df(prepare_df(arr_df))
return model.predict_df(_prepare_for_explanation(arr_df, model=model, dataset=dataset))

example = background_example(df, dataset.column_types)
kernel = shap.KernelExplainer(predict_array, example)
shap_values = kernel.shap_values(input_df, silent=True)
example = _get_background_example(df, dataset.column_types)
kernel = KernelExplainer(predict_array, example)
shap_values = kernel.shap_values(data_to_explain, silent=True)
return shap_values


@timer()
def explain(model: BaseModel, dataset: Dataset, input_data: Dict):
shap_values = _calculate_sample_shap_values(model, dataset, input_data)
feature_names = model.meta.feature_names or list(dataset.df.columns.drop(dataset.target, errors="ignore"))

if model.is_regression:
explanation_chart_data = summary_shap_regression(shap_values=shap_values, feature_names=feature_names)
Expand All @@ -63,11 +118,8 @@ def predict_array(array):
@timer()
def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, text_document: str):
try:
text_explainer = shap.Explainer(
text_explanation_prediction_wrapper(model.predict_df, input_df, text_column),
shap.maskers.Text(tokenizer=r"\W+"),
)

masker = Text(tokenizer=r"\W+")
text_explainer = Explainer(text_explanation_prediction_wrapper(model.predict_df, input_df, text_column), masker)
shap_values = text_explainer(pd.Series([text_document]))

return (
Expand All @@ -80,16 +132,6 @@ def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, tex
raise Exception("Failed to create text explanation") from e


def background_example(df: pd.DataFrame, input_types: Dict[str, str]) -> pd.DataFrame:
    """Build a one-row background example as the mode/median of the columns.

    Args:
        df: Data the example is derived from.
        input_types: Mapping of column name to its declared type; columns
            declared as ``"numeric"`` use the median instead of the mode.

    Returns:
        A single-row dataframe cast to the dtypes of ``df``.
    """
    # If a column has several modes, keep the first one.
    example = df.mode(dropna=False).head(1)

    # Restrict the median to numeric columns: without ``numeric_only=True``
    # recent pandas versions raise on string/categorical columns.
    median = df.median(numeric_only=True)

    for column in df.columns:
        if input_types.get(column) == "numeric":
            example[column] = median[column]

    return example.astype(df.dtypes)


def summary_shap_classification(
shap_values: List[np.ndarray],
feature_names: List[str],
Expand Down
124 changes: 124 additions & 0 deletions python-client/giskard/models/shap_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from enum import Enum
from typing import Any, Iterable
from dataclasses import dataclass

# wandb is an optional dependency. Importing it unconditionally would make this
# module (and everything that imports ShapResult) unusable when wandb is not
# installed, so fall back to None here. The plotting helpers are only reached
# through ShapResult.to_wandb(), which goes through wandb_run first.
try:
    import wandb
except ImportError:
    wandb = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`import wandb` shouldn't be here; as we said, it would break. Please embed it only where we need it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is that I am only going to use the "shap_result.py" module when I need to import the ShapResult class, which requires wandb. I do not need anything else from this module that does not use wandb, i.e. I am going to use this module only if wandb is installed. Otherwise, I would need to put the wandb import into all the private functions and ShapResult.to_wandb(), which looks like overhead. WDYT?

Copy link
Contributor

@rabah-khalek rabah-khalek Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's because we decided to make explain_with_shap a global function that returns ShapResult, so that users can use it without necessarily wanting to log it with wandb (maybe later we want to log it to mlflow or other mlops).

several imports of the same libraries is not an issue in python btw, see: https://stackoverflow.com/questions/37067414/python-import-multiple-times.

If you don't like the idea of import wandb in every plotting function, you can refactor them into wandb_utils.py, I think that's a cleaner solution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if I put the plotting functions into wandb_utils.py, I will still get an error, just one defined by us. Is that the behaviour you expect, and are we OK with it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise ImportError("The 'wandb' python package is not installed. To get it, run 'pip install wandb'.") from e

import numpy as np
import pandas as pd
from shap import Explanation

from giskard.core.core import ModelType, SupportedModelTypes


class PanelNames(str, Enum):
    """Section titles used as key prefixes for the SHAP charts logged to WandB."""

    CATEGORICAL = "Feature importance for categorical features"
    NUMERICAL = "Feature importance for numerical features"
    GENERAL = "Global feature importance"

    def __str__(self) -> str:
        # Since Python 3.11, format()/f-strings of a (str, Enum) member follow
        # __str__, which by default renders "PanelNames.GENERAL" instead of the
        # title. Returning the value keeps formatted keys identical on every
        # supported Python version.
        return self.value


def _wandb_bar_plot(shap_explanations: Explanation, feature_name: str) -> Any:
    """Build a wandb bar plot of mean |SHAP| per value of a categorical feature."""
    value_key = "feature_values"
    shap_key = "shap_abs_values"

    # Slice out the explanations of the requested feature.
    feature_slice = shap_explanations[:, feature_name]

    # Only the magnitude of the contribution matters here.
    per_sample = pd.DataFrame(data={value_key: feature_slice.data, shap_key: np.abs(feature_slice.values)})

    # Average the magnitudes within each distinct feature value.
    mean_per_value = per_sample.groupby(value_key)[shap_key].mean().reset_index()

    return wandb.plot.bar(
        wandb.Table(dataframe=mean_per_value),
        label=value_key,
        value=shap_key,
        title=f"Mean(Abs(SHAP)) of '{feature_name}' feature values",
    )


def _wandb_scatter_plot(shap_explanations: Explanation, feature_name: str) -> Any:
    """Build a wandb scatter plot of a numerical feature's values against its SHAP values."""
    value_key = "feature_values"
    shap_key = "shap_values"

    # Slice out the explanations of the requested feature.
    feature_slice = shap_explanations[:, feature_name]

    points = pd.DataFrame(data={value_key: feature_slice.data, shap_key: feature_slice.values})

    return wandb.plot.scatter(
        wandb.Table(dataframe=points),
        y=value_key,
        x=shap_key,
        title=f"'{feature_name}' feature values vs SHAP values",
    )


def _wandb_general_bar_plot(shap_explanations: Explanation, feature_names: Iterable) -> Any:
    """Build a wandb bar plot of the mean |SHAP| value of every feature.

    Args:
        shap_explanations: SHAP explanations, sliceable by feature name.
        feature_names: Names of the features to include in the plot.

    Returns:
        The chart object produced by ``wandb.plot.bar``.
    """
    feature_column = "feature"
    shap_column = "global_shap_mean"

    # The names are needed twice (to compute the means and as a table column),
    # so materialise them once; a one-shot iterator would otherwise be
    # exhausted before the table is built.
    feature_names = list(feature_names)

    # Mean magnitude of the SHAP values of each feature.
    shap_general_means = [np.abs(shap_explanations[:, name].values).mean() for name in feature_names]

    # Create bar plot.
    df = pd.DataFrame(data={feature_column: feature_names, shap_column: shap_general_means})
    table = wandb.Table(dataframe=df)
    return wandb.plot.bar(
        table, label=feature_column, value=shap_column, title="General Mean(Abs(SHAP)) across all features"
    )


@dataclass
class ShapResult:
    """SHAP explanations bundled with the metadata needed to chart them."""

    # SHAP values wrapped into a shap ``Explanation`` object.
    explanations: Explanation = None
    # Mapping of feature name to its type; ``to_wandb`` charts "category" and "numeric".
    feature_types: dict = None
    # Names of the explained features.
    feature_names: list = None
    # Type of the explained model, compared against ``SupportedModelTypes.CLASSIFICATION``.
    model_type: ModelType = None
    # Flag checked by ``_validate_config`` before logging a classification model.
    only_highest_prob: bool = True

    def _validate_config(self):
        """Raise ``ValueError`` when the configuration cannot be logged to WandB."""
        if not self.only_highest_prob and self.model_type == SupportedModelTypes.CLASSIFICATION:
            raise ValueError(
                "We currently support 'ShapResult.to_wandb()' only with 'only_highest_proba == True' for "
                "classification models."
            )

    def to_wandb(self, **kwargs) -> None:
        """Create the SHAP charts and log them to a WandB run.

        Args:
            **kwargs: Forwarded to ``wandb_run``.

        Raises:
            ValueError: If the configuration is rejected by ``_validate_config``.
            NotImplementedError: If a feature is neither "category" nor "numeric".
        """
        from giskard.integrations.wandb.wandb_utils import wandb_run

        self._validate_config()

        with wandb_run(**kwargs) as run:
            charts = dict()

            # Chart keys use the enum *value* explicitly: f-string formatting of
            # a (str, Enum) member renders "PanelNames.X" on Python 3.11+.

            # Create general SHAP feature importance plot.
            general_bar_plot = _wandb_general_bar_plot(self.explanations, self.feature_names)
            charts[f"{PanelNames.GENERAL.value}/general_shap_bar_plot"] = general_bar_plot

            # Create per-feature SHAP plots.
            for feature_name, feature_type in self.feature_types.items():
                if feature_type == "category":
                    bar_plot = _wandb_bar_plot(self.explanations, feature_name)
                    charts[f"{PanelNames.CATEGORICAL.value}/{feature_name}_shap_bar_plot"] = bar_plot
                elif feature_type == "numeric":
                    scatter_plot = _wandb_scatter_plot(self.explanations, feature_name)
                    charts[f"{PanelNames.NUMERICAL.value}/{feature_name}_shap_scatter_plot"] = scatter_plot
                else:
                    raise NotImplementedError("We do not support the SHAP logging of text features yet.")

            # Log created plots.
            run.log(charts)
2 changes: 1 addition & 1 deletion python-client/giskard/scanner/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,5 @@ def to_wandb(self, **kwargs):

with tempfile.NamedTemporaryFile(prefix="giskard-scan-results-", suffix=".html") as f:
self.to_html(filename=f.name)
wandb_artifact_name = f.name.split("/")[-1].split(".html")[0]
wandb_artifact_name = "Vulnerability scan results/" + f.name.split("/")[-1].split(".html")[0]
run.log({wandb_artifact_name: wandb.Html(open(f.name), inject=False)})
28 changes: 25 additions & 3 deletions python-client/tests/integrations/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,34 @@
import re

from giskard import scan
from giskard.models.model_explanation import explain_with_shap

wandb.setup(wandb.Settings(mode="disabled", program=__name__, program_relpath=__name__, disable_code=True))


@pytest.mark.parametrize(
"dataset_name,model_name",
[
("hotel_text_data", "hotel_text_model"),
("german_credit_data", "german_credit_model"),
("breast_cancer_data", "breast_cancer_model"),
("drug_classification_data", "drug_classification_model"),
("diabetes_dataset_with_target", "linear_regression_diabetes"),
("hotel_text_data", "hotel_text_model"),
],
)
def test_fast(dataset_name, model_name, request):
# Expect the 'NotImplementedError' when dataset contains textual features.
exception_fixtures = ("hotel_text_data",)

dataset = request.getfixturevalue(dataset_name)
model = request.getfixturevalue(model_name)
_to_wandb(model, dataset)

if dataset_name in exception_fixtures:
with pytest.raises(NotImplementedError) as e:
_to_wandb(model, dataset)
assert e.match(r"We do not support the SHAP logging of text*")
else:
_to_wandb(model, dataset)


@pytest.mark.parametrize(
Expand All @@ -34,9 +44,17 @@ def test_fast(dataset_name, model_name, request):
)
@pytest.mark.slow
def test_slow(dataset_name, model_name, request):
exception_fixtures = ("enron_data_full", "medical_transcript_data", "amazon_review_data")

dataset = request.getfixturevalue(dataset_name)
model = request.getfixturevalue(model_name)
_to_wandb(model, dataset)

if dataset_name in exception_fixtures:
with pytest.raises(NotImplementedError) as e:
_to_wandb(model, dataset)
assert e.match(r"We do not support the SHAP logging of text*")
else:
_to_wandb(model, dataset)


def _to_wandb(model, dataset):
Expand All @@ -51,4 +69,8 @@ def _to_wandb(model, dataset):
test_suite_results = scan_results.generate_test_suite().run()
test_suite_results.to_wandb()

# Verify that the logging of the SHAP explanation charts works.
explanation_results = explain_with_shap(model, dataset)
explanation_results.to_wandb()

assert re.match("^[0-9a-z]{8}$", str(wandb.run.id))
Loading