Skip to content
22 changes: 22 additions & 0 deletions python-client/giskard/testing/tests/metamorphic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional

import numpy as np
import pandas as pd

from giskard import test
Expand All @@ -14,6 +16,7 @@
from giskard.ml_worker.utils.logging import timer
from giskard.models.base import BaseModel
from giskard.models.utils import fix_seed
from giskard.scanner.llm.utils import LLMImportError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should move LLMImportError outside of scanner now that we are integrating it more with the rest of the codebase. But not super important.



def _predict_numeric_result(model: BaseModel, ds: Dataset, output_proba=True, classification_label=None):
Expand Down Expand Up @@ -61,6 +64,25 @@ def _compare_prediction(results_df, prediction_task, direction, output_sensitivi
if prediction_task == SupportedModelTypes.CLASSIFICATION:
passed_idx = results_df.loc[results_df["prediction"] == results_df["perturbed_prediction"]].index.values

elif prediction_task == SupportedModelTypes.TEXT_GENERATION:
try:
import evaluate
scorer = evaluate.load("bertscore")
except ImportError as err:
raise LLMImportError() from err
except FileNotFoundError as err:
raise LLMImportError("Your version of evaluate does not support 'bertscore'. "
"Please use 'pip install -U evaluate' to upgrade it") from err

score = scorer.compute(
predictions=results_df["perturbed_prediction"].values,
references=results_df["prediction"].values,
model_type="distilbert-base-multilingual-cased",
idf=True,
)
passed = np.array(score["f1"]) > 1 - output_sensitivity
passed_idx = results_df.loc[passed].index.values

elif prediction_task == SupportedModelTypes.REGRESSION:
results_df["predict_difference_ratio"] = results_df.apply(
lambda x: _prediction_ratio(x["prediction"], x["perturbed_prediction"]),
Expand Down
41 changes: 41 additions & 0 deletions python-client/tests/test_metamorphic_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import giskard.testing.tests.metamorphic as metamorphic
from giskard import Dataset, Model
from giskard.ml_worker.testing.registry.transformation_function import transformation_function
from giskard.ml_worker.testing.stat_utils import equivalence_t_test, paired_t_test
from giskard.ml_worker.testing.stat_utils import equivalence_wilcoxon, paired_wilcoxon
Expand Down Expand Up @@ -246,3 +247,43 @@ def perturbation(x: pd.Series) -> pd.Series:

assert results.actual_slices_size[0] == len(german_credit_test_data)
assert results.passed, f"metric = {results.metric}"


def test_metamorphic_invariance_llm():
    """Run the metamorphic invariance test against a text-generation model.

    A LangChain ``LLMChain`` driven by a ``FakeListLLM`` is wrapped as a
    giskard ``text_generation`` model, so no real LLM is called. Each product
    description in the dataset is perturbed by prefixing it with ``"some "``,
    and the test is expected to pass with ``threshold=0.1`` and
    ``output_sensitivity=0.2``.
    """
    # Imported lazily so that collecting this module does not require langchain.
    from langchain.chains import LLMChain
    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import PromptTemplate

    # Canned answers for the fake LLM: four entries for a two-row dataset,
    # presumably two for the original inputs and two for the perturbed ones
    # (NOTE(review): consumption order depends on FakeListLLM -- confirm).
    responses = [
        "\n\nHueFoots.",
        "\n\nEcoDrive Motors.",
        "\n\nRainbow Socks.",
        "\n\nNoOil Motors.",
    ]
    llm = FakeListLLM(responses=responses)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=llm, prompt=prompt)

    wrapped_model = Model(chain, model_type="text_generation")
    df = pd.DataFrame(["colorful socks", "electric car"], columns=["product"])

    # "product" is free text, hence no categorical columns.
    wrapped_dataset = Dataset(df, cat_columns=[])

    @transformation_function()
    def perturbation(x: pd.Series) -> pd.Series:
        x["product"] = f"some {x['product']}"
        return x

    results = metamorphic.test_metamorphic_invariance(
        model=wrapped_model,
        dataset=wrapped_dataset,
        transformation_function=perturbation,
        threshold=0.1,
        output_sensitivity=0.2,
    ).execute()

    # Derive the expected size from the input frame and report the metric on
    # failure, matching the other invariance tests in this module.
    assert results.actual_slices_size[0] == len(df)
    assert results.passed, f"metric = {results.metric}"