Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion giskard/core/savable.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[SMT]:
if meta is not None:
return meta

return super()._load_meta_locally(local_dir, uuid)
with open(local_dir / "meta.yaml", "r") as f:
return yaml.load(f, Loader=yaml.Loader)

@classmethod
def load(cls, local_dir: Path, uuid: str, meta: SMT):
Expand Down
181 changes: 180 additions & 1 deletion giskard/core/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass
from datetime import datetime
from functools import singledispatchmethod
from pathlib import Path
from xml.dom import minidom
from xml.etree.ElementTree import Element, SubElement, tostring

Expand Down Expand Up @@ -59,6 +60,45 @@
]


def _parse_function_arguments(folder: Path, function_inputs: List[Dict[str, Any]]):
arguments = dict()

for value in function_inputs:
if value.get("isAlias", False) or value.get("isDefaultValue", False):
continue
if value["type"] == "Dataset":
arguments[value["name"]] = Dataset.load(folder / value["value"])
elif value["type"] == "BaseModel":
arguments[value["name"]] = BaseModel.load(folder / value["value"])
elif value["type"] == "SlicingFunction":
sf_uuid = value["value"]
sf_folder = folder / sf_uuid
arguments[value["name"]] = SlicingFunction.load(
sf_folder, sf_uuid, SlicingFunction._load_meta_locally(sf_folder, sf_uuid)
)(**_parse_function_arguments(folder, value["params"]))
elif value["type"] == "TransformationFunction":
tf_uuid = value["value"]
tf_folder = folder / tf_uuid
arguments[value["name"]] = TransformationFunction.load(
tf_folder, tf_uuid, TransformationFunction._load_meta_locally(tf_folder, tf_uuid)
)(**_parse_function_arguments(folder, value["params"]))
elif value["type"] == "float":
arguments[value["name"]] = float(value["value"])
elif value["type"] == "int":
arguments[value["name"]] = int(value["value"])
elif value["type"] == "str":
arguments[value["name"]] = str(value["value"])
elif value["type"] == "bool":
arguments[value["name"]] = bool(value["value"])
elif value["type"] == "Kwargs":
kwargs = dict()
exec(value["value"], {"kwargs": kwargs})
arguments.update(kwargs)
else:
raise IllegalArgumentError(f"Unknown argument type: {value['type']}")
return arguments


def parse_function_arguments(client, project_key, function_inputs):
arguments = dict()

Expand Down Expand Up @@ -450,6 +490,41 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status:
displayName=self.display_name,
)

def _to_json(self, folder: Path, saved_uuid_status: Dict[str, bool]):
params = dict(
{
pname: _build_test_input_json(
folder,
p,
pname,
self.giskard_test.meta.args[pname].type,
saved_uuid_status,
)
for pname, p in self.provided_inputs.items()
if pname in self.giskard_test.meta.args
}
)

kwargs_params = [
f"{get_imports_code(value)}\nkwargs[{repr(pname)}] = {repr(value)}"
for pname, value in self.provided_inputs.items()
if pname not in self.giskard_test.meta.args
]
if len(kwargs_params) > 0:
params["kwargs"] = {"name": "kwargs", "value": "\n".join(kwargs_params), "type": "Kwargs"}

if self.giskard_test.meta.uuid not in saved_uuid_status:
test_folder = folder / str(self.giskard_test.meta.uuid)
test_folder.mkdir(exist_ok=True)
self.giskard_test.save(test_folder)

return {
"id": self.suite_test_id,
"testUuid": str(self.giskard_test.meta.uuid),
"functionInputs": params,
"displayName": self.display_name,
}


def single_binary_result(test_results: List):
return all(res.passed for res in test_results)
Expand All @@ -471,7 +546,7 @@ def build_test_input_dto(client, p, pname, ptype, project_key, uploaded_uuid_sta
kwargs_param = (
[]
if len(kwargs_params) == 0
else (TestInputDTO(name="kwargs", value="\n".join(kwargs_params), type="Kwargs"))
else [TestInputDTO(name="kwargs", value="\n".join(kwargs_params), type="Kwargs")]
)

return TestInputDTO(
Expand All @@ -498,6 +573,46 @@ def build_test_input_dto(client, p, pname, ptype, project_key, uploaded_uuid_sta
return TestInputDTO(name=pname, value=str(p), type=ptype)


def _build_test_input_json(folder, p, pname, ptype, uploaded_uuid_status: Dict[str, bool]):
if issubclass(type(p), Dataset) or issubclass(type(p), BaseModel):
if _try_save_artifact(p, folder, uploaded_uuid_status):
return {"name": pname, "value": str(p.id), "type": ptype}
else:
return {"name": pname, "value": pname, "is_alias": True, "type": ptype}
elif issubclass(type(p), Artifact):
if not _try_save_artifact(p, folder, uploaded_uuid_status):
return {"name": pname, "value": pname, "is_alias": True, "type": ptype}

kwargs_params = [
f"kwargs[{pname}] = {repr(value)}" for pname, value in p.params.items() if pname not in p.meta.args
]
kwargs_param = (
[] if len(kwargs_params) == 0 else [{"name": "kwargs", "value": "\n".join(kwargs_params), "type": "Kwargs"}]
)

return {
"name": pname,
"value": str(p.meta.uuid),
"type": ptype,
"params": [
_build_test_input_json(
folder,
value,
pname,
p.meta.args[pname].type,
uploaded_uuid_status,
)
for pname, value in p.params.items()
if pname in p.meta.args
]
+ kwargs_param,
}
elif isinstance(p, SuiteInput):
return {"name": pname, "value": p.name, "is_alias": True, "type": ptype}
else:
return {"name": pname, "value": str(p), "type": ptype}


def generate_test_partial(
test_fn: Test,
test_id: Optional[Union[int, str]] = None,
Expand Down Expand Up @@ -679,6 +794,27 @@ def create_test_params(test_partial, kwargs) -> TestParams:
test_params[pname] = kwargs[pname]
return test_params

def save(self, folder: str):
folder_path = Path(folder)
if folder_path.exists() and folder_path.is_file():
raise ValueError(f"{folder_path} is a file, please provide a folder")

folder_path.mkdir(parents=True, exist_ok=True)

if self.name is None:
self.name = "Unnamed test suite"

saved_uuid_status: Dict[str, bool] = dict()

json_content = self._to_json(folder_path, saved_uuid_status)

with open(folder_path / "suite.json", "w") as f:
json.dump(json_content, f)

analytics.track("lib:test_suite:saved")

return self

def upload(self, client: GiskardClient, project_key: Optional[str] = None):
"""Saves the test suite to the Giskard backend and sets its ID.

Expand Down Expand Up @@ -733,6 +869,13 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status:

return TestSuiteDTO(name=self.name, project_key=project_key, tests=suite_tests, function_inputs=list())

def _to_json(self, folder: Path, saved_uuid_status: Dict[str, bool] = None):
return {
"name": self.name,
"tests": [test._to_json(folder, saved_uuid_status) for test in self.tests],
"function_inputs": [],
}

def add_test(
self,
test_fn: Test,
Expand Down Expand Up @@ -955,6 +1098,26 @@ def download(cls, client: GiskardClient, project_key: str, suite_id: int) -> "Su

return suite

@classmethod
def load(cls, folder: str) -> "Suite":
folder_path = Path(folder)

with open(folder_path / "suite.json", "r") as f:
suite_json = json.load(f)

suite = Suite(name=suite_json.get("name", "Unnamed test suite"))

for test_json in suite_json.get("tests", []):
test_uuid = test_json.get("testUuid")
test_folder = folder_path / test_uuid

test = GiskardTest.load(test_folder, test_uuid, GiskardTest._load_meta_locally(test_folder, test_uuid))

test_arguments = _parse_function_arguments(folder_path, test_json.get("functionInputs").values())
suite.add_test(test(**test_arguments), suite_test_id=test_json.get("id"))

return suite


def contains_tag(func: TestFunctionMeta, tag: str):
return any([t for t in func.tags if t.upper() == tag.upper()])
Expand All @@ -981,3 +1144,19 @@ def _try_upload_artifact(artifact, client, project_key: str, uploaded_uuid_statu
uploaded_uuid_status[artifact_id] = False

return uploaded_uuid_status[artifact_id]


def _try_save_artifact(artifact, path: Path, saved_uuid_status: Dict[str, bool]) -> bool:
artifact_id = serialize_parameter(artifact)

if artifact_id not in saved_uuid_status:
try:
artifact_path = path / artifact_id
artifact_path.mkdir(exist_ok=True)
artifact.save(artifact_path)
saved_uuid_status[artifact_id] = True
except: # noqa NOSONAR
warning(f"Failed to save {str(artifact)} used in the test suite. The test suite will be partially saved.")
saved_uuid_status[artifact_id] = False

return saved_uuid_status[artifact_id]
5 changes: 4 additions & 1 deletion giskard/registry/giskard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from abc import ABC, abstractmethod
from pathlib import Path

import yaml

from giskard.core.core import SMT, TestFunctionMeta
from giskard.core.savable import Artifact
from giskard.core.test_result import TestResult
Expand Down Expand Up @@ -71,7 +73,8 @@ def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[TestFunctionMeta]:
if meta is not None:
return meta

return super()._load_meta_locally(local_dir, uuid)
with open(local_dir / "meta.yaml", "r") as f:
return yaml.load(f, Loader=yaml.Loader)

@classmethod
def load(cls, local_dir: Path, uuid: str, meta: TestFunctionMeta):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_suite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
import uuid
from datetime import datetime

Expand Down Expand Up @@ -128,3 +129,22 @@ def test_suite_result_to_dto():

assert dto.results[0].inputs["dataset"] == str(dataset.id)
assert dto.results[0].inputs["threshold"] == str(0.5)


def test_suite_save_and_load(german_credit_data, german_credit_model):
my_test = test_accuracy(threshold=0.7)

suite = Suite()
suite.add_test(my_test)

with tempfile.TemporaryDirectory() as tmp_dirname:
suite.save(tmp_dirname)
loaded_suite = Suite.load(tmp_dirname)

result = loaded_suite.run(model=german_credit_model, dataset=german_credit_data)

assert result.passed
assert len(result.results) == 1
_, test_result, _ = result.results[0]
assert not test_result.is_error
assert test_result.passed