-
-
Notifications
You must be signed in to change notification settings - Fork 379
GSK-1533 #1307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GSK-1533 #1307
Changes from 26 commits
7789fa9
e31798e
e9de16d
4f92418
1a59240
cc15925
5fbb14c
c1c075e
9dbf847
167c629
ff6ed6f
6366064
75a013f
4935cd8
8ad73b1
6bf77d7
d2c4739
c4ee3c1
8b11843
660fd66
99f82b8
0746ff3
72bd4d4
0a5928f
d80aae5
d98b943
feffd84
868c33f
a2e9833
ce40cf3
309a64d
428e919
1b2302f
e0c20bb
0083f30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| from enum import Enum | ||
| from typing import Any, Iterable | ||
| from dataclasses import dataclass | ||
|
|
||
| import wandb | ||
|
||
| import numpy as np | ||
| import pandas as pd | ||
| from shap import Explanation | ||
|
|
||
| from giskard.core.core import ModelType, SupportedModelTypes | ||
|
|
||
|
|
||
class PanelNames(str, Enum):
    """Prefixes used for the keys of the SHAP charts logged to wandb.

    Members are interpolated into chart keys of the form
    ``"<panel name>/<chart name>"``, so their string form must be the
    human-readable value, not ``"PanelNames.<MEMBER>"``.
    """

    CATEGORICAL = "Feature importance for categorical features"
    NUMERICAL = "Feature importance for numerical features"
    GENERAL = "Global feature importance"

    def __str__(self) -> str:
        # Since Python 3.11, format()/f-strings on a mixed-in Enum delegate to
        # Enum.__str__ and would render "PanelNames.GENERAL" instead of the
        # member value. Returning the value keeps formatted keys identical on
        # every Python version.
        return str(self.value)
|
|
||
|
|
||
def _wandb_bar_plot(shap_explanations: Explanation, feature_name: str) -> Any:
    """Build a wandb bar chart of the mean absolute SHAP value per feature value.

    Args:
        shap_explanations: SHAP explanations, indexable as ``[:, feature_name]``.
        feature_name: Name of the feature whose values are plotted.

    Returns:
        The chart object produced by ``wandb.plot.bar``.
    """
    value_col = "feature_values"
    magnitude_col = "shap_abs_values"

    # Take the slice for this feature once; it carries both the raw feature
    # values (``.data``) and the matching SHAP values (``.values``).
    feature_slice = shap_explanations[:, feature_name]

    # Only the magnitude of the contribution matters here, not its sign.
    per_sample = pd.DataFrame(
        data={value_col: feature_slice.data, magnitude_col: np.abs(feature_slice.values)}
    )

    # Average the magnitudes over every sample sharing the same feature value.
    mean_magnitudes = per_sample.groupby(value_col)[magnitude_col].mean().reset_index()

    return wandb.plot.bar(
        wandb.Table(dataframe=mean_magnitudes),
        label=value_col,
        value=magnitude_col,
        title=f"Mean(Abs(SHAP)) of '{feature_name}' feature values",
    )
|
|
||
|
|
||
def _wandb_scatter_plot(shap_explanations: Explanation, feature_name: str) -> Any:
    """Build a wandb scatter chart of one feature's values against its SHAP values.

    Args:
        shap_explanations: SHAP explanations, indexable as ``[:, feature_name]``.
        feature_name: Name of the feature whose values are plotted.

    Returns:
        The chart object produced by ``wandb.plot.scatter``.
    """
    value_col = "feature_values"
    contribution_col = "shap_values"

    # Take the slice for this feature once; it carries both the raw feature
    # values (``.data``) and the matching SHAP values (``.values``).
    feature_slice = shap_explanations[:, feature_name]

    points = pd.DataFrame(
        data={value_col: feature_slice.data, contribution_col: feature_slice.values}
    )

    # Feature values go on the y axis, signed SHAP values on the x axis.
    return wandb.plot.scatter(
        wandb.Table(dataframe=points),
        x=contribution_col,
        y=value_col,
        title=f"'{feature_name}' feature values vs SHAP values",
    )
|
|
||
|
|
||
def _wandb_general_bar_plot(shap_explanations: Explanation, feature_names: Iterable) -> Any:
    """Build a wandb bar chart of the mean absolute SHAP value of every feature.

    Args:
        shap_explanations: SHAP explanations, indexable as ``[:, feature_name]``.
        feature_names: Names of the features to include in the chart. Any
            iterable is accepted, including one-shot iterators.

    Returns:
        The chart object produced by ``wandb.plot.bar``.
    """
    feature_column = "feature"
    shap_column = "global_shap_mean"

    # The names are needed twice (to compute the means and as a table column),
    # so materialise them once; a generator would otherwise be exhausted by
    # the first pass and leave the column empty or mismatched.
    feature_names = list(feature_names)

    # Mean magnitude of the SHAP values of each feature across all samples.
    shap_general_means = [
        np.abs(shap_explanations[:, feature_name].values).mean() for feature_name in feature_names
    ]

    # Create bar plot.
    df = pd.DataFrame(data={feature_column: feature_names, shap_column: shap_general_means})
    table = wandb.Table(dataframe=df)
    plot = wandb.plot.bar(
        table, label=feature_column, value=shap_column, title="General Mean(Abs(SHAP)) across all features"
    )

    return plot
|
|
||
|
|
||
@dataclass
class ShapResult:
    """Container for SHAP explanations that can be logged to wandb as charts."""

    # SHAP explanations; indexed as ``explanations[:, feature_name]`` when plotting.
    explanations: Explanation = None
    # Maps feature name -> feature type. Only "category" and "numeric" are
    # supported by ``to_wandb``; any other type raises NotImplementedError.
    feature_types: dict = None
    # Feature names included in the global feature-importance chart.
    feature_names: list = None
    # Compared against ``SupportedModelTypes.CLASSIFICATION`` during validation.
    model_type: ModelType = None
    # Must be True for classification models when calling ``to_wandb``.
    only_highest_prob: bool = True

    def _validate_config(self):
        """Check that the current configuration can be logged to wandb.

        Raises:
            ValueError: If the model is a classification model and
                ``only_highest_prob`` is False.
        """
        if not self.only_highest_prob and self.model_type == SupportedModelTypes.CLASSIFICATION:
            raise ValueError(
                "We currently support 'ShapResult.to_wandb()' only with 'only_highest_prob == True' for "
                "classification models."
            )

    def to_wandb(self, **kwargs) -> None:
        """Create and log to the WandB run SHAP charts.

        Args:
            **kwargs: Forwarded unchanged to ``wandb_run``.

        Raises:
            ValueError: If the configuration is rejected by ``_validate_config``.
            NotImplementedError: If a feature type is neither "category" nor
                "numeric".
        """
        from giskard.integrations.wandb.wandb_utils import wandb_run

        self._validate_config()

        with wandb_run(**kwargs) as run:
            charts = dict()

            # Panel names are read through ``.value`` explicitly: formatting a
            # mixed-in Enum member inside an f-string renders
            # "PanelNames.GENERAL" instead of the value on Python >= 3.11.

            # Create general SHAP feature importance plot.
            general_bar_plot = _wandb_general_bar_plot(self.explanations, self.feature_names)
            charts[f"{PanelNames.GENERAL.value}/general_shap_bar_plot"] = general_bar_plot

            # Create per-feature SHAP plots.
            for feature_name, feature_type in self.feature_types.items():
                if feature_type == "category":
                    bar_plot = _wandb_bar_plot(self.explanations, feature_name)
                    charts[f"{PanelNames.CATEGORICAL.value}/{feature_name}_shap_bar_plot"] = bar_plot
                elif feature_type == "numeric":
                    scatter_plot = _wandb_scatter_plot(self.explanations, feature_name)
                    charts[f"{PanelNames.NUMERICAL.value}/{feature_name}_shap_scatter_plot"] = scatter_plot
                else:
                    raise NotImplementedError("We do not support the SHAP logging of text features yet.")

            # Log created plots.
            run.log(charts)
Uh oh!
There was an error while loading. Please reload this page.