From aba6ac4363a8cd8f761b235accd94b2e5dc8e82b Mon Sep 17 00:00:00 2001 From: Rabah Abdul Khalek Date: Tue, 31 Oct 2023 19:03:14 +0100 Subject: [PATCH 1/4] moved target validation and silenced it for LLMs --- giskard/core/model_validation.py | 3 +++ giskard/datasets/base/__init__.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/giskard/core/model_validation.py b/giskard/core/model_validation.py index ac5c18c632..8a607a5056 100644 --- a/giskard/core/model_validation.py +++ b/giskard/core/model_validation.py @@ -14,11 +14,14 @@ from giskard.models.base import BaseModel, WrapperModel from ..utils import fullname from ..utils.analytics_collector import analytics, get_dataset_properties, get_model_properties +from .dataset_validation import validate_target @configured_validate_arguments def validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None, print_validation_message: bool = True): try: + if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION: + validate_target(validate_ds) _do_validate_model(model, validate_ds) except (ValueError, TypeError) as err: _track_validation_error(err, model, validate_ds) diff --git a/giskard/datasets/base/__init__.py b/giskard/datasets/base/__init__.py index d403f925c2..433c45567c 100644 --- a/giskard/datasets/base/__init__.py +++ b/giskard/datasets/base/__init__.py @@ -186,10 +186,9 @@ def __init__( self.target = target if validation: - from giskard.core.dataset_validation import validate_dtypes, validate_target + from giskard.core.dataset_validation import validate_dtypes validate_dtypes(self) - validate_target(self) self.column_dtypes = self.extract_column_dtypes(self.df) From 04a431b2fe09a6dda847d8e36d3e4ef9fe7ae180 Mon Sep 17 00:00:00 2001 From: Rabah Abdul Khalek Date: Thu, 2 Nov 2023 17:47:08 +0100 Subject: [PATCH 2/4] fixing tests --- giskard/core/dataset_validation.py | 16 +++++++++------- giskard/core/model_validation.py | 6 +++--- giskard/datasets/base/__init__.py | 3 ++- tests/test_dataset.py | 2 ++ 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/giskard/core/dataset_validation.py b/giskard/core/dataset_validation.py index 057806b992..2bff0768f1 100644 --- a/giskard/core/dataset_validation.py +++ b/giskard/core/dataset_validation.py @@ -8,18 +8,20 @@ from giskard.datasets.base import Dataset -def validate_target(ds: Dataset): +def validate_optional_target(ds: Dataset): if not ds.target: warning( "You did not provide the optional argument 'target'. " "'target' is the column name in df corresponding to the actual target variable (ground truth)." ) - else: - if ds.target not in ds.columns: - raise ValueError( - "Invalid target parameter:" - f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}" - ) + + +def validate_target_exists(ds: Dataset): + if ds.target and ds.target not in ds.columns: + raise ValueError( + "Invalid target parameter:" + f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}" + ) def validate_dtypes(ds: Dataset): diff --git a/giskard/core/model_validation.py b/giskard/core/model_validation.py index 8a607a5056..93e4d39bfc 100644 --- a/giskard/core/model_validation.py +++ b/giskard/core/model_validation.py @@ -14,14 +14,14 @@ from giskard.models.base import BaseModel, WrapperModel from ..utils import fullname from ..utils.analytics_collector import analytics, get_dataset_properties, get_model_properties -from .dataset_validation import validate_target +from .dataset_validation import validate_optional_target @configured_validate_arguments def validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None, print_validation_message: bool = True): try: - if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION: - validate_target(validate_ds) + if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION and validate_ds is not None: + validate_optional_target(validate_ds) _do_validate_model(model, validate_ds) except (ValueError, TypeError) as err: _track_validation_error(err, model, validate_ds) diff --git a/giskard/datasets/base/__init__.py b/giskard/datasets/base/__init__.py index 433c45567c..3d6b8f502b 100644 --- a/giskard/datasets/base/__init__.py +++ b/giskard/datasets/base/__init__.py @@ -186,9 +186,10 @@ def __init__( self.target = target if validation: - from giskard.core.dataset_validation import validate_dtypes + from giskard.core.dataset_validation import validate_dtypes, validate_target_exists validate_dtypes(self) + validate_target_exists(self) self.column_dtypes = self.extract_column_dtypes(self.df) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a745573a40..a6dae4f8d1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,6 +2,7 @@ import numpy as np import pytest from giskard.datasets.base import Dataset +from giskard.core.dataset_validation import validate_optional_target valid_df = pd.DataFrame( { @@ -38,6 +39,7 @@ def test_valid_df_column_types(): r"in df corresponding to the actual target variable \(ground truth\)\.", ): my_dataset = Dataset(valid_df) + validate_optional_target(my_dataset) assert my_dataset.column_types == { "categorical_column": "category", "text_column": "text", From 75dc072d1f40f09d78091a42e48738e77b10e621 Mon Sep 17 00:00:00 2001 From: Rabah Abdul Khalek Date: Sun, 5 Nov 2023 13:56:07 +0100 Subject: [PATCH 3/4] explicit check for ds.target --- giskard/core/dataset_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/giskard/core/dataset_validation.py b/giskard/core/dataset_validation.py index 2bff0768f1..4e3c690990 100644 --- a/giskard/core/dataset_validation.py +++ b/giskard/core/dataset_validation.py @@ -9,7 +9,7 @@ def validate_optional_target(ds: Dataset): - if not ds.target: + if ds.target is None: warning( "You did not provide the optional argument 'target'. " "'target' is the column name in df corresponding to the actual target variable (ground truth)." @@ -17,7 +17,7 @@ def validate_optional_target(ds: Dataset): def validate_target_exists(ds: Dataset): - if ds.target and ds.target not in ds.columns: + if ds.target is not None and ds.target not in ds.columns: raise ValueError( "Invalid target parameter:" f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}" From 7d43b6a8ec5aa287bd64a2365201d501b35cf95e Mon Sep 17 00:00:00 2001 From: Rabah Abdul Khalek Date: Mon, 6 Nov 2023 19:34:19 +0100 Subject: [PATCH 4/4] fixed tests --- tests/testing/test_llm_output_requirement.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/testing/test_llm_output_requirement.py b/tests/testing/test_llm_output_requirement.py index 4e91f1aa3a..2d1f1ef30e 100644 --- a/tests/testing/test_llm_output_requirement.py +++ b/tests/testing/test_llm_output_requirement.py @@ -15,7 +15,7 @@ @patch("giskard.testing.tests.llm.output_requirements.RequirementEvaluator") def test_llm_output_requirement(RequirementEvaluator): model = sentinel.model - dataset = Dataset(pd.DataFrame({"feature": ["value"]}), target=False) + dataset = Dataset(pd.DataFrame({"feature": ["value"]})) # Successful test RequirementEvaluator.return_value.evaluate.return_value = EvaluationResult( @@ -23,7 +23,10 @@ def test_llm_output_requirement(RequirementEvaluator): ) my_test = llm_tests.test_llm_output_against_requirement( - model=model, dataset=dataset, requirement="The model should not generate content that incites harm or violence", debug=True + model=model, + dataset=dataset, + requirement="The model should not generate content that incites harm or violence", + debug=True, ) res = my_test.execute() assert res.passed @@ -74,7 +77,7 @@ def test_llm_single_output_requirement(RequirementEvaluator): model=model, input_var=input_var, requirement="The model should not generate content that incites harm or violence", - debug=True + debug=True, ) res = my_test.execute() assert res.passed