Skip to content
18 changes: 10 additions & 8 deletions giskard/core/dataset_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
from giskard.datasets.base import Dataset


def validate_target(ds: Dataset):
if not ds.target:
def validate_optional_target(ds: Dataset):
if ds.target is None:
warning(
"You did not provide the optional argument 'target'. "
"'target' is the column name in df corresponding to the actual target variable (ground truth)."
)
else:
if ds.target not in ds.columns:
raise ValueError(
"Invalid target parameter:"
f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}"
)


def validate_target_exists(ds: Dataset):
if ds.target is not None and ds.target not in ds.columns:
raise ValueError(
"Invalid target parameter:"
f" '{ds.target}' column is not present in the dataset with columns: {list(ds.columns)}"
)


def validate_dtypes(ds: Dataset):
Expand Down
3 changes: 3 additions & 0 deletions giskard/core/model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
from giskard.models.base import BaseModel, WrapperModel
from ..utils import fullname
from ..utils.analytics_collector import analytics, get_dataset_properties, get_model_properties
from .dataset_validation import validate_optional_target


@configured_validate_arguments
def validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None, print_validation_message: bool = True):
try:
if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION and validate_ds is not None:
validate_optional_target(validate_ds)
Comment on lines +23 to +24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this in validate_model???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should refactor model_validation and dataset_validation into one common validation, but the way the code is structured right now, I needed model.meta.model_type to do validate_optional_target for non-LLM models.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just fix the llm-warning issue for now and deal with the refactoring post-release, I added a card here: https://linear.app/giskard/issue/GSK-2064/refactor-model-and-dataset-validation

_do_validate_model(model, validate_ds)
except (ValueError, TypeError) as err:
_track_validation_error(err, model, validate_ds)
Expand Down
4 changes: 2 additions & 2 deletions giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,10 @@ def __init__(
self.target = target

if validation:
from giskard.core.dataset_validation import validate_dtypes, validate_target
from giskard.core.dataset_validation import validate_dtypes, validate_target_exists

validate_dtypes(self)
validate_target(self)
validate_target_exists(self)

self.column_dtypes = self.extract_column_dtypes(self.df)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import requests_mock

from giskard.datasets.base import Dataset
from giskard.core.dataset_validation import validate_optional_target
from giskard.client.dtos import DatasetMetaInfo

from tests import utils
Expand All @@ -29,7 +30,6 @@
"target",
]


valid_df = pd.DataFrame(
{
"categorical_column": ["turtle", "crocodile", "turtle"],
Expand Down Expand Up @@ -65,6 +65,7 @@ def test_valid_df_column_types():
r"in df corresponding to the actual target variable \(ground truth\)\.",
):
my_dataset = Dataset(valid_df)
validate_optional_target(my_dataset)
assert my_dataset.column_types == {
"categorical_column": "category",
"text_column": "text",
Expand Down
9 changes: 6 additions & 3 deletions tests/testing/test_llm_output_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
@patch("giskard.testing.tests.llm.output_requirements.RequirementEvaluator")
def test_llm_output_requirement(RequirementEvaluator):
model = sentinel.model
dataset = Dataset(pd.DataFrame({"feature": ["value"]}), target=False)
dataset = Dataset(pd.DataFrame({"feature": ["value"]}))

# Successful test
RequirementEvaluator.return_value.evaluate.return_value = EvaluationResult(
failure_examples=[], success_examples=_demo_samples, errors=[]
)

my_test = llm_tests.test_llm_output_against_requirement(
model=model, dataset=dataset, requirement="The model should not generate content that incites harm or violence", debug=True
model=model,
dataset=dataset,
requirement="The model should not generate content that incites harm or violence",
debug=True,
)
res = my_test.execute()
assert res.passed
Expand Down Expand Up @@ -74,7 +77,7 @@ def test_llm_single_output_requirement(RequirementEvaluator):
model=model,
input_var=input_var,
requirement="The model should not generate content that incites harm or violence",
debug=True
debug=True,
)
res = my_test.execute()
assert res.passed
Expand Down