diff --git a/giskard/core/dataset_validation.py b/giskard/core/dataset_validation.py index 057806b992..4e3c690990 100644 --- a/giskard/core/dataset_validation.py +++ b/giskard/core/dataset_validation.py @@ -8,18 +8,20 @@ from giskard.datasets.base import Dataset -def validate_target(ds: Dataset): - if not ds.target: +def validate_optional_target(ds: Dataset): + if ds.target is None: warning( "You did not provide the optional argument 'target'. " "'target' is the column name in df corresponding to the actual target variable (ground truth)." ) - else: - if ds.target not in ds.columns: - raise ValueError( - "Invalid target parameter:" - f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}" - ) + + +def validate_target_exists(ds: Dataset): + if ds.target is not None and ds.target not in ds.columns: + raise ValueError( + "Invalid target parameter:" + f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}" + ) def validate_dtypes(ds: Dataset): diff --git a/giskard/core/model_validation.py b/giskard/core/model_validation.py index 899c19082f..7f93b5bf78 100644 --- a/giskard/core/model_validation.py +++ b/giskard/core/model_validation.py @@ -14,11 +14,14 @@ from giskard.models.base import BaseModel, WrapperModel from ..utils import fullname from ..utils.analytics_collector import analytics, get_dataset_properties, get_model_properties +from .dataset_validation import validate_optional_target @configured_validate_arguments def validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None, print_validation_message: bool = True): try: + if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION and validate_ds is not None: + validate_optional_target(validate_ds) _do_validate_model(model, validate_ds) except (ValueError, TypeError) as err: _track_validation_error(err, model, validate_ds) diff --git a/giskard/datasets/base/__init__.py b/giskard/datasets/base/__init__.py index d403f925c2..3d6b8f502b 100644 --- a/giskard/datasets/base/__init__.py +++ b/giskard/datasets/base/__init__.py @@ -186,10 +186,10 @@ def __init__( self.target = target if validation: - from giskard.core.dataset_validation import validate_dtypes, validate_target + from giskard.core.dataset_validation import validate_dtypes, validate_target_exists validate_dtypes(self) - validate_target(self) + validate_target_exists(self) self.column_dtypes = self.extract_column_dtypes(self.df) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a81e931732..9426c1a6b1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,6 +7,7 @@ import requests_mock from giskard.datasets.base import Dataset +from giskard.core.dataset_validation import validate_optional_target from giskard.client.dtos import DatasetMetaInfo from tests import utils @@ -29,7 +30,6 @@ "target", ] - valid_df = pd.DataFrame( { "categorical_column": ["turtle", "crocodile", "turtle"], @@ -65,6 +65,7 @@ def test_valid_df_column_types(): r"in df corresponding to the actual target variable \(ground truth\)\.", ): my_dataset = Dataset(valid_df) + validate_optional_target(my_dataset) assert my_dataset.column_types == { "categorical_column": "category", "text_column": "text", diff --git a/tests/testing/test_llm_output_requirement.py b/tests/testing/test_llm_output_requirement.py index 4e91f1aa3a..2d1f1ef30e 100644 --- a/tests/testing/test_llm_output_requirement.py +++ b/tests/testing/test_llm_output_requirement.py @@ -15,7 +15,7 @@ @patch("giskard.testing.tests.llm.output_requirements.RequirementEvaluator") def test_llm_output_requirement(RequirementEvaluator): model = sentinel.model - dataset = Dataset(pd.DataFrame({"feature": ["value"]}), target=False) + dataset = Dataset(pd.DataFrame({"feature": ["value"]})) # Successful test RequirementEvaluator.return_value.evaluate.return_value = EvaluationResult( @@ -23,7 +23,10 @@ def test_llm_output_requirement(RequirementEvaluator): ) my_test = llm_tests.test_llm_output_against_requirement( - model=model, dataset=dataset, requirement="The model should not generate content that incites harm or violence", debug=True + model=model, + dataset=dataset, + requirement="The model should not generate content that incites harm or violence", + debug=True, ) res = my_test.execute() assert res.passed @@ -74,7 +77,7 @@ def test_llm_single_output_requirement(RequirementEvaluator): model=model, input_var=input_var, requirement="The model should not generate content that incites harm or violence", - debug=True + debug=True, ) res = my_test.execute() assert res.passed