Skip to content
33 changes: 21 additions & 12 deletions python-client/giskard/push/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import pandas as pd
import sys

from giskard.core.core import SupportedModelTypes
from giskard.datasets.base import Dataset
Expand All @@ -35,7 +36,7 @@
)
from ..push import PerturbationPush

text_transfo_list = [
text_transformation_list = [
TextLowercase,
TextUppercase,
TextTitleCase,
Expand All @@ -45,9 +46,7 @@
]


def create_perturbation_push(
model: BaseModel, ds: Dataset, df: pd.DataFrame
) -> PerturbationPush:
def create_perturbation_push(model: BaseModel, ds: Dataset, df: pd.DataFrame) -> PerturbationPush:
"""Create a perturbation notification by applying transformations.

Applies supported perturbations to each feature in the dataset
Expand All @@ -66,6 +65,7 @@ def create_perturbation_push(
for feat, coltype in ds.column_types.items():
coltype = coltype_to_supported_perturbation_type(coltype)
transformation_info = _apply_perturbation(model, ds, df, feat, coltype)
# df contains only one row, which is the sample being looked at in the debugger
value = df.iloc[0][feat]
if transformation_info is not None:
return PerturbationPush(
Expand Down Expand Up @@ -108,6 +108,7 @@ def _apply_perturbation(
transformation_function = list()
value_perturbed = list()
transformation_functions_params = list()

passed = False
# Create a slice of the dataset with only the row to perturb
ds_slice = Dataset(
Expand Down Expand Up @@ -167,9 +168,20 @@ def _text(
):
passed = False
# Iterate over the possible text transformations
for text_transformation in text_transfo_list:
for text_transformation in text_transformation_list:
# Create the transformation
t = text_transformation(column=feature)
_is_typo_transformation = issubclass(text_transformation, TextTypoTransformation)
kwargs = {}
if _is_typo_transformation:
# TextTypoTransformation generates a random typo for text features. In order to have the same typo per
# sample with the push feature in the debugger, we need to generate a unique seed per sample (hashed_seed)
# to guarantee the same perturbation per sample.
hashed_seed = hash(f"{', '.join(map(lambda x: repr(x), ds_slice_copy.df.values))}".encode("utf-8"))
# hash could give negative ints, and np.random.seed accepts only positive ints
positive_hashed_seed = hashed_seed % ((sys.maxsize + 1) * 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rabah-khalek could you add a comment explaining why we're doing this? In a week we won't remember it by heart.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

kwargs = {"rng_seed": positive_hashed_seed}

t = text_transformation(column=feature, **kwargs)

# Transform the slice
transformed = ds_slice_copy.transform(t)
Expand Down Expand Up @@ -206,6 +218,7 @@ def _numeric(
np.unique(np.linspace(-2 * mad, 0, num=10).round().astype(int)),
np.unique(np.linspace(2 * mad, 0, num=10).round().astype(int)),
]
# df contains only one row, which is the sample being looked at in the debugger
value_to_perturb = ds.df[feature].iloc[0]
for values_added in values_added_list:
for value in values_added:
Expand All @@ -222,17 +235,13 @@ def _numeric(
if perturbed:
value_perturbed.append(transformed.df[feature].values.item(0))
transformation_function.append(t)
transformation_functions_params.append(
dict(column_name=feature, value_added=float(value))
)
transformation_functions_params.append(dict(column_name=feature, value_added=float(value)))
else:
break
return len(transformation_function) > 0


def _check_after_perturbation(
model: BaseModel, ref_row: Dataset, row_perturbed: Dataset
) -> bool:
def _check_after_perturbation(model: BaseModel, ref_row: Dataset, row_perturbed: Dataset) -> bool:
"""
Check if perturbation changed the model's prediction.

Expand Down