diff --git a/python-client/giskard/push/contribution.py b/python-client/giskard/push/contribution.py index 0b6ea17b5d..6dac636a4f 100644 --- a/python-client/giskard/push/contribution.py +++ b/python-client/giskard/push/contribution.py @@ -55,9 +55,7 @@ def _get_model_predictions(model: BaseModel, sliced_ds: Dataset): return raw_prediction, correct_prediction -def _create_non_text_contribution_push( - shap_feature, sliced_ds, ds, model, correct_prediction -): +def _create_non_text_contribution_push(shap_feature, sliced_ds, ds, model, correct_prediction): """ Create a ContributionPush object for non-text features with outlier contributions. @@ -72,9 +70,7 @@ def _create_non_text_contribution_push( ContributionPush: An object representing the contribution push for non-text features. """ # Calculate bounds for the feature - bounds = slice_bounds_relative( - feature=shap_feature, value=sliced_ds.df[shap_feature].values[0], ds=ds - ) + bounds = slice_bounds_relative(feature=shap_feature, value=sliced_ds.df[shap_feature].values[0], ds=ds) # Create the ContributionPush object return ContributionPush( @@ -86,9 +82,7 @@ def _create_non_text_contribution_push( ) -def _create_text_contribution_push( - shap_feature, sliced_ds, model, raw_prediction, correct_prediction -): +def _create_text_contribution_push(shap_feature, sliced_ds, model, raw_prediction, correct_prediction): """ Create a ContributionPush object for text features. @@ -103,36 +97,36 @@ def _create_text_contribution_push( ContributionPush: An object representing the contribution push for text features. """ # Explain the text feature + input_df = model.prepare_dataframe(sliced_ds.df, column_dtypes=sliced_ds.column_dtypes, target=sliced_ds.target) + text_explanation = explain_text( model=model, - input_df=sliced_ds.df, + input_df=input_df, text_column=shap_feature, text_document=sliced_ds.df[shap_feature].iloc[0], ) # Create a dictionary mapping words to their importance scores if model.meta.model_type == SupportedModelTypes.CLASSIFICATION: - text_explanation_map = dict( - zip(text_explanation[0], text_explanation[1][raw_prediction]) - ) + text_explanation_map = dict(zip(text_explanation[0], text_explanation[1][raw_prediction])) else: text_explanation_map = dict(zip(text_explanation[0], text_explanation[1])) # Detect the most important word based on the explanation most_important_word = _detect_text_shap_outlier(text_explanation_map) - return ContributionPush( - feature=shap_feature, - feature_type="text", - value=most_important_word, - model_type=model.meta.model_type, - correct_prediction=correct_prediction, - ) + # If for any reason, the most_important_word is not found, don't create a push notification + if most_important_word is not None: + return ContributionPush( + feature=shap_feature, + feature_type="text", + value=most_important_word, + model_type=model.meta.model_type, + correct_prediction=correct_prediction, + ) -def create_contribution_push( - model: BaseModel, ds: Dataset, df: pd.DataFrame -) -> ContributionPush: +def create_contribution_push(model: BaseModel, ds: Dataset, df: pd.DataFrame) -> ContributionPush: """ Create contribution notification from SHAP values. @@ -150,45 +144,30 @@ def create_contribution_push( ContributionPush if outlier contribution found, else None """ # Check if there is only one feature type in the dataset, and if it's "text" - _text_is_the_only_feature = ( - len(ds.column_types.values()) == 1 - and list(ds.column_types.values())[0] == "text" - ) + _text_is_the_only_feature = len(ds.column_types.values()) == 1 and list(ds.column_types.values())[0] == "text" # Get global SHAP values for the model's predictions on the input DataFrame global_feature_shap = _get_shap_values(model, ds, df) # Detect the SHAP feature with outlier contribution, if SHAP values exist - shap_feature = ( - _detect_shap_outlier(global_feature_shap) if _existing_shap_values(ds) else None - ) + shap_feature = _detect_shap_outlier(global_feature_shap) if _existing_shap_values(ds) else None # Check if a SHAP feature with outlier contribution was detected _shap_outlier_detected = shap_feature is not None # Determine if model predictions are needed based on the presence of outlier SHAP features and target variable - _model_predictions_needed = ( - _shap_outlier_detected or _text_is_the_only_feature - ) and ds.target is not None + _model_predictions_needed = (_shap_outlier_detected or _text_is_the_only_feature) and ds.target is not None # If model predictions are needed, continue if _model_predictions_needed: # Choose the SHAP feature for analysis, considering the case when there's only one text feature - shap_feature = ( - shap_feature - if not _text_is_the_only_feature - else list(ds.column_types.keys())[0] - ) + shap_feature = shap_feature if not _text_is_the_only_feature else list(ds.column_types.keys())[0] # Check if the SHAP feature is not a text feature - _shap_feature_is_not_text = ( - _shap_outlier_detected and ds.column_types[shap_feature] != "text" - ) + _shap_feature_is_not_text = _shap_outlier_detected and ds.column_types[shap_feature] != "text" # Check if the SHAP feature is a text feature or it's the only feature in the dataset - _shap_feature_is_text = ( - _shap_outlier_detected and ds.column_types[shap_feature] == "text" - ) + _shap_feature_is_text = _shap_outlier_detected and ds.column_types[shap_feature] == "text" # Create a new dataset for analysis, copying the column types and excluding validation sliced_ds = Dataset( @@ -203,15 +182,11 @@ def create_contribution_push( # If the SHAP feature is not a text feature, create a non-text ContributionPush if _shap_feature_is_not_text: - return _create_non_text_contribution_push( - shap_feature, sliced_ds, ds, model, correct_prediction - ) + return _create_non_text_contribution_push(shap_feature, sliced_ds, ds, model, correct_prediction) # If the SHAP feature is a text feature or it's the only feature, create a text ContributionPush elif _shap_feature_is_text or _text_is_the_only_feature: - return _create_text_contribution_push( - shap_feature, sliced_ds, model, raw_prediction, correct_prediction - ) + return _create_text_contribution_push(shap_feature, sliced_ds, model, raw_prediction, correct_prediction) def _detect_shap_outlier(global_feature_shap): @@ -280,9 +255,7 @@ def _get_shap_values(model: BaseModel, ds: Dataset, df: pd.DataFrame): Dictionary of SHAP values per feature """ if model.meta.model_type == SupportedModelTypes.CLASSIFICATION: - return explain(model, ds, df.iloc[0])["explanations"][ - model.meta.classification_labels[0] - ] + return explain(model, ds, df.iloc[0])["explanations"][model.meta.classification_labels[0]] elif model.meta.model_type == SupportedModelTypes.REGRESSION: return explain(model, ds, df.iloc[0])["explanations"]["default"] diff --git a/python-client/tests/test_push.py b/python-client/tests/test_push.py index e3f4da6676..03bceca4dd 100644 --- a/python-client/tests/test_push.py +++ b/python-client/tests/test_push.py @@ -211,3 +211,12 @@ def test_coltype_to_supported_perturbation_type(): perturbation_type = coltype_to_supported_perturbation_type("text") assert perturbation_type == SupportedPerturbationType.TEXT + + +def test_text_explain_in_push(medical_transcript_model, medical_transcript_data): + + problematic_df_entry = medical_transcript_data.df.iloc[[3]] + output = create_contribution_push(medical_transcript_model, medical_transcript_data, problematic_df_entry) + + assert output is not None + assert output.value is not None