79 changes: 26 additions & 53 deletions python-client/giskard/push/contribution.py
@@ -55,9 +55,7 @@ def _get_model_predictions(model: BaseModel, sliced_ds: Dataset):
     return raw_prediction, correct_prediction
 
 
-def _create_non_text_contribution_push(
-    shap_feature, sliced_ds, ds, model, correct_prediction
-):
+def _create_non_text_contribution_push(shap_feature, sliced_ds, ds, model, correct_prediction):
     """
     Create a ContributionPush object for non-text features with outlier contributions.
 
@@ -72,9 +70,7 @@ def _create_non_text_contribution_push(
         ContributionPush: An object representing the contribution push for non-text features.
     """
     # Calculate bounds for the feature
-    bounds = slice_bounds_relative(
-        feature=shap_feature, value=sliced_ds.df[shap_feature].values[0], ds=ds
-    )
+    bounds = slice_bounds_relative(feature=shap_feature, value=sliced_ds.df[shap_feature].values[0], ds=ds)
 
     # Create the ContributionPush object
     return ContributionPush(
@@ -86,9 +82,7 @@ def _create_non_text_contribution_push(
     )
 
 
-def _create_text_contribution_push(
-    shap_feature, sliced_ds, model, raw_prediction, correct_prediction
-):
+def _create_text_contribution_push(shap_feature, sliced_ds, model, raw_prediction, correct_prediction):
     """
     Create a ContributionPush object for text features.
 
@@ -103,36 +97,36 @@ def _create_text_contribution_push(
         ContributionPush: An object representing the contribution push for text features.
     """
     # Explain the text feature
+    input_df = model.prepare_dataframe(sliced_ds.df, column_dtypes=sliced_ds.column_dtypes, target=sliced_ds.target)
+
     text_explanation = explain_text(
         model=model,
-        input_df=sliced_ds.df,
+        input_df=input_df,
         text_column=shap_feature,
         text_document=sliced_ds.df[shap_feature].iloc[0],
     )
 
     # Create a dictionary mapping words to their importance scores
     if model.meta.model_type == SupportedModelTypes.CLASSIFICATION:
-        text_explanation_map = dict(
-            zip(text_explanation[0], text_explanation[1][raw_prediction])
-        )
+        text_explanation_map = dict(zip(text_explanation[0], text_explanation[1][raw_prediction]))
     else:
         text_explanation_map = dict(zip(text_explanation[0], text_explanation[1]))
 
     # Detect the most important word based on the explanation
     most_important_word = _detect_text_shap_outlier(text_explanation_map)
 
-    return ContributionPush(
-        feature=shap_feature,
-        feature_type="text",
-        value=most_important_word,
-        model_type=model.meta.model_type,
-        correct_prediction=correct_prediction,
-    )
+    # If for any reason, the most_important_word is not found, don't create a push notification
+    if most_important_word is not None:
+        return ContributionPush(
+            feature=shap_feature,
+            feature_type="text",
+            value=most_important_word,
+            model_type=model.meta.model_type,
+            correct_prediction=correct_prediction,
+        )
 
 
-def create_contribution_push(
-    model: BaseModel, ds: Dataset, df: pd.DataFrame
-) -> ContributionPush:
+def create_contribution_push(model: BaseModel, ds: Dataset, df: pd.DataFrame) -> ContributionPush:
     """
     Create contribution notification from SHAP values.
 
@@ -150,45 +144,30 @@ def create_contribution_push(
         ContributionPush if outlier contribution found, else None
     """
    # Check if there is only one feature type in the dataset, and if it's "text"
-    _text_is_the_only_feature = (
-        len(ds.column_types.values()) == 1
-        and list(ds.column_types.values())[0] == "text"
-    )
+    _text_is_the_only_feature = len(ds.column_types.values()) == 1 and list(ds.column_types.values())[0] == "text"
 
     # Get global SHAP values for the model's predictions on the input DataFrame
     global_feature_shap = _get_shap_values(model, ds, df)
 
     # Detect the SHAP feature with outlier contribution, if SHAP values exist
-    shap_feature = (
-        _detect_shap_outlier(global_feature_shap) if _existing_shap_values(ds) else None
-    )
+    shap_feature = _detect_shap_outlier(global_feature_shap) if _existing_shap_values(ds) else None
 
     # Check if a SHAP feature with outlier contribution was detected
     _shap_outlier_detected = shap_feature is not None
 
     # Determine if model predictions are needed based on the presence of outlier SHAP features and target variable
-    _model_predictions_needed = (
-        _shap_outlier_detected or _text_is_the_only_feature
-    ) and ds.target is not None
+    _model_predictions_needed = (_shap_outlier_detected or _text_is_the_only_feature) and ds.target is not None
 
     # If model predictions are needed, continue
     if _model_predictions_needed:
         # Choose the SHAP feature for analysis, considering the case when there's only one text feature
-        shap_feature = (
-            shap_feature
-            if not _text_is_the_only_feature
-            else list(ds.column_types.keys())[0]
-        )
+        shap_feature = shap_feature if not _text_is_the_only_feature else list(ds.column_types.keys())[0]
 
         # Check if the SHAP feature is not a text feature
-        _shap_feature_is_not_text = (
-            _shap_outlier_detected and ds.column_types[shap_feature] != "text"
-        )
+        _shap_feature_is_not_text = _shap_outlier_detected and ds.column_types[shap_feature] != "text"
 
         # Check if the SHAP feature is a text feature or it's the only feature in the dataset
-        _shap_feature_is_text = (
-            _shap_outlier_detected and ds.column_types[shap_feature] == "text"
-        )
+        _shap_feature_is_text = _shap_outlier_detected and ds.column_types[shap_feature] == "text"
 
         # Create a new dataset for analysis, copying the column types and excluding validation
         sliced_ds = Dataset(
@@ -203,15 +182,11 @@
 
         # If the SHAP feature is not a text feature, create a non-text ContributionPush
         if _shap_feature_is_not_text:
-            return _create_non_text_contribution_push(
-                shap_feature, sliced_ds, ds, model, correct_prediction
-            )
+            return _create_non_text_contribution_push(shap_feature, sliced_ds, ds, model, correct_prediction)
 
         # If the SHAP feature is a text feature or it's the only feature, create a text ContributionPush
         elif _shap_feature_is_text or _text_is_the_only_feature:
-            return _create_text_contribution_push(
-                shap_feature, sliced_ds, model, raw_prediction, correct_prediction
-            )
+            return _create_text_contribution_push(shap_feature, sliced_ds, model, raw_prediction, correct_prediction)
 
 
 def _detect_shap_outlier(global_feature_shap):
@@ -280,9 +255,7 @@ def _get_shap_values(model: BaseModel, ds: Dataset, df: pd.DataFrame):
         Dictionary of SHAP values per feature
     """
     if model.meta.model_type == SupportedModelTypes.CLASSIFICATION:
-        return explain(model, ds, df.iloc[0])["explanations"][
-            model.meta.classification_labels[0]
-        ]
+        return explain(model, ds, df.iloc[0])["explanations"][model.meta.classification_labels[0]]
     elif model.meta.model_type == SupportedModelTypes.REGRESSION:
         return explain(model, ds, df.iloc[0])["explanations"]["default"]
 
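Caller-side note on the change above: since _create_text_contribution_push now returns a ContributionPush only when an outlier word is detected, create_contribution_push can return None for text features. A minimal sketch of the guard this implies, assuming a wrapped giskard model and dataset; my_model and my_dataset are placeholders, not fixtures from this repository:

from giskard.push.contribution import create_contribution_push

# my_model and my_dataset are hypothetical stand-ins for a giskard BaseModel and Dataset
push = create_contribution_push(my_model, my_dataset, my_dataset.df.iloc[[0]])

# Guard against the None case introduced by this change before using the push
if push is not None:
    print(push.value)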
9 changes: 9 additions & 0 deletions python-client/tests/test_push.py
@@ -211,3 +211,12 @@ def test_coltype_to_supported_perturbation_type():
 
     perturbation_type = coltype_to_supported_perturbation_type("text")
     assert perturbation_type == SupportedPerturbationType.TEXT
+
+
+def test_text_explain_in_push(medical_transcript_model, medical_transcript_data):
+
+    problematic_df_entry = medical_transcript_data.df.iloc[[3]]
+    output = create_contribution_push(medical_transcript_model, medical_transcript_data, problematic_df_entry)
+
+    assert output is not None
+    assert output.value is not None