Skip to content
8 changes: 6 additions & 2 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import time
import traceback
from collections import defaultdict
from concurrent.futures import CancelledError, Future
from copy import copy
from dataclasses import dataclass
Expand Down Expand Up @@ -520,17 +521,20 @@ def run_test_suite(
client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs
) -> websocket.TestSuite:
log_listener = LogListener()

loaded_artifacts = defaultdict(dict)

try:
tests = [
{
"test": GiskardTest.download(t.testUuid, client, None),
"arguments": parse_function_arguments(client, t.arguments),
"arguments": parse_function_arguments(client, t.arguments, loaded_artifacts),
"id": t.id,
}
for t in params.tests
]

global_arguments = parse_function_arguments(client, params.globalArguments)
global_arguments = parse_function_arguments(client, params.globalArguments, loaded_artifacts)

datasets = {arg.original_id: arg for arg in global_arguments.values() if isinstance(arg, Dataset)}
for test in tests:
Expand Down
43 changes: 34 additions & 9 deletions giskard/ml_worker/websocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from collections import defaultdict

import pandas as pd
from mlflow.store.artifact.artifact_repo import verify_artifact_path
from typing import Any, Dict, List, Optional, Callable

from giskard.client.giskard_client import GiskardClient
from giskard.core.suite import DatasetInput, ModelInput, SuiteInput
Expand Down Expand Up @@ -158,7 +159,21 @@ def map_dataset_process_function_meta_ws(callable_type):
}


def parse_function_arguments(client: Optional[GiskardClient], request_arguments: List[websocket.FuncArgument]):
def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: str, load_fn: Callable[[], Any]) -> Any:
    """Return the artifact cached under ``(type, uuid)``, loading it on first use.

    Args:
        loaded_artifacts: Two-level mapping ``{type: {uuid: artifact}}`` used as the cache.
            It is indexed directly by ``type``, so the entry must exist or the mapping
            must create it on access (callers pass a ``defaultdict(dict)``).
        type: First-level cache key, e.g. ``"Dataset"`` or ``"BaseModel"``.
        uuid: Second-level cache key identifying the artifact.
        load_fn: Zero-argument callable invoked only on a cache miss; its return
            value is stored in the cache and returned.

    Returns:
        The cached artifact, i.e. the same object on every call with the same keys.
    """
    per_type_cache = loaded_artifacts[type]

    # Cache hit: hand back the instance stored by an earlier call.
    if uuid in per_type_cache:
        return per_type_cache[uuid]

    artifact = load_fn()
    per_type_cache[uuid] = artifact
    return artifact


def parse_function_arguments(
client: Optional[GiskardClient],
request_arguments: List[websocket.FuncArgument],
loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None,
):
if loaded_artifacts is None:
loaded_artifacts = defaultdict(dict)

arguments = dict()

# Processing empty list
Expand All @@ -169,22 +184,32 @@ def parse_function_arguments(client: Optional[GiskardClient], request_arguments:
if arg.is_none:
continue
if arg.dataset is not None:
arguments[arg.name] = Dataset.download(
client,
arg.dataset.project_key,
arguments[arg.name] = _get_or_load(
loaded_artifacts,
"Dataset",
arg.dataset.id,
arg.dataset.sample,
lambda: Dataset.download(
client,
arg.dataset.project_key,
arg.dataset.id,
arg.dataset.sample,
),
)
elif arg.model is not None:
arguments[arg.name] = BaseModel.download(client, arg.model.project_key, arg.model.id)
arguments[arg.name] = _get_or_load(
loaded_artifacts,
"BaseModel",
arg.model.id,
lambda: BaseModel.download(client, arg.model.project_key, arg.model.id),
)
elif arg.slicingFunction is not None:
arguments[arg.name] = SlicingFunction.download(
arg.slicingFunction.id, client, arg.slicingFunction.project_key
)(**parse_function_arguments(client, arg.args))
)(**parse_function_arguments(client, arg.args, loaded_artifacts))
elif arg.transformationFunction is not None:
arguments[arg.name] = TransformationFunction.download(
arg.transformationFunction.id, client, arg.transformationFunction.project_key
)(**parse_function_arguments(client, arg.args))
)(**parse_function_arguments(client, arg.args, loaded_artifacts))
elif arg.float_arg is not None:
arguments[arg.name] = float(arg.float_arg)
elif arg.int_arg is not None:
Expand Down
7 changes: 6 additions & 1 deletion giskard/models/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ def __init__(
if len(classification_labels) != len(set(classification_labels)):
raise ValueError("Duplicates are found in 'classification_labels', please only provide unique values.")

self._cache = ModelCache(model_type, str(self.id), cache_dir=kwargs.get("prediction_cache_dir"))
self._cache = ModelCache(
model_type,
str(self.id),
persist_cache=kwargs.get("persist_cache", False),
cache_dir=kwargs.get("prediction_cache_dir"),
)

# sklearn and catboost will fill classification_labels before this check
if model_type == SupportedModelTypes.CLASSIFICATION and not classification_labels:
Expand Down
19 changes: 14 additions & 5 deletions giskard/models/cache/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import csv
from pathlib import Path
from typing import Any, Iterable, List, Optional

import numpy as np
import pandas as pd
from typing import Any, Iterable, List, Optional

from ...client.python_utils import warning
from ...core.core import SupportedModelTypes
Expand All @@ -26,14 +26,23 @@ def flatten(xs):
class ModelCache:
_default_cache_dir_prefix = Path(settings.home_dir / settings.cache_dir / "global" / "prediction_cache")

def __init__(self, model_type: SupportedModelTypes, id: Optional[str] = None, cache_dir: Optional[Path] = None):
def __init__(
self,
model_type: SupportedModelTypes,
id: Optional[str] = None,
persist_cache: bool = False,
cache_dir: Optional[Path] = None,
):
self.id = id
self.prediction_cache = dict()

if cache_dir is None and self.id:
cache_dir = self._default_cache_dir_prefix.joinpath(self.id)
if persist_cache:
if cache_dir is None and self.id:
cache_dir = self._default_cache_dir_prefix / self.id

self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
else:
self.cache_file = None

self.vectorized_get_cache_or_na = np.vectorize(self.get_cache_or_na, otypes=[object])
self.model_type = model_type
Expand Down
101 changes: 101 additions & 0 deletions tests/communications/test_websocket_actor_tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import random
import uuid

import pandas as pd
import pytest

import giskard
from giskard import test
from giskard.datasets.base import Dataset
from giskard.ml_worker import websocket
from giskard.ml_worker.testing.test_result import TestResult as GiskardTestResult, TestMessage, TestMessageLevel
from giskard.ml_worker.websocket import listener
from giskard.models.base import BaseModel
from giskard.testing.tests import debug_prefix
from tests import utils

Expand Down Expand Up @@ -62,6 +65,13 @@ def my_simple_test_legacy_debug(dataset: Dataset, debug: bool = False):
return GiskardTestResult(passed=False, output_df=output_ds)


# Giskard test helper: passes when both models give identical raw predictions on ``ds``.
@giskard.test()
def same_prediction(left: BaseModel, right: BaseModel, ds: giskard.Dataset):
    # Run both predictions first (left, then right) before comparing anything.
    predictions = [model.predict(ds) for model in (left, right)]
    left_raw, right_raw = (list(prediction.raw_prediction) for prediction in predictions)
    return giskard.TestResult(passed=left_raw == right_raw)


def test_websocket_actor_run_ad_hoc_test_legacy_debug(enron_data: Dataset):
project_key = str(uuid.uuid4())

Expand Down Expand Up @@ -343,6 +353,97 @@ def test_websocket_actor_run_test_suite():
assert not reply.results[2].result.passed


def test_websocket_actor_run_test_suite_share_models_and_dataset_instance():
    """Run a suite whose tests all reference the same model and dataset artifacts.

    The model returns random predictions, and ``same_prediction`` compares the
    predictions of its ``left`` and ``right`` arguments, which both reference the
    same model id. The suite is expected to pass; presumably that only happens when
    both references resolve to one loaded instance sharing a prediction cache.

    The four suite entries cover: all arguments given locally, only ``left`` local,
    only ``right`` local, and everything taken from the global arguments.
    """

    def random_prediction(df):
        return [random.randint(0, 9) for _ in df.index]

    # Use random model to ensure model prediction cache is shared (same instance loaded)
    random_model = giskard.Model(random_prediction, "regression", feature_names=["feature"])
    mock_dataset = giskard.Dataset(pd.DataFrame({"feature": range(100)}))

    def model_argument(name):
        # Every model argument points at the very same uploaded model.
        return websocket.FuncArgument(
            name=name,
            model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)),
            none=False,
        )

    def dataset_argument(name):
        return websocket.FuncArgument(
            name=name,
            dataset=websocket.ArtifactRef(project_key="project_key", id=str(mock_dataset.id)),
            none=False,
        )

    def full_argument_set():
        # Fresh list (and fresh FuncArgument objects) on every call.
        return [model_argument("left"), model_argument("right"), dataset_argument("ds")]

    with utils.MockedClient(mock_all=False) as (client, mr):
        params = websocket.TestSuiteParam(
            projectKey=str(uuid.uuid4()),
            tests=[
                websocket.SuiteTestArgument(
                    id=0,
                    testUuid=same_prediction.meta.uuid,
                    arguments=full_argument_set(),
                ),
                websocket.SuiteTestArgument(
                    id=1,
                    testUuid=same_prediction.meta.uuid,
                    arguments=[model_argument("left")],
                ),
                websocket.SuiteTestArgument(
                    id=2,
                    testUuid=same_prediction.meta.uuid,
                    arguments=[model_argument("right")],
                ),
                # Distinct id: this entry previously reused id=2 from the entry above.
                websocket.SuiteTestArgument(id=3, testUuid=same_prediction.meta.uuid, arguments=[]),
            ],
            globalArguments=full_argument_set(),
        )
        utils.register_uri_for_artifact_meta_info(mr, same_prediction, None)

        utils.register_uri_for_model_meta_info(mr, random_model, "project_key")
        utils.register_uri_for_model_artifact_info(mr, random_model, "project_key", register_file_contents=True)

        utils.register_uri_for_dataset_meta_info(mr, mock_dataset, "project_key")
        utils.register_uri_for_dataset_artifact_info(mr, mock_dataset, "project_key", register_file_contents=True)

        reply = listener.run_test_suite(client, params)

        assert isinstance(reply, websocket.TestSuite)
        assert not reply.is_error
        assert reply.is_pass
        assert 4 == len(reply.results)


def test_websocket_actor_run_test_suite_raise_error():
with utils.MockedClient(mock_all=False) as (client, mr):
params = websocket.TestSuiteParam(
Expand Down
5 changes: 4 additions & 1 deletion tests/models/test_model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import pandas as pd
import pytest
import xxhash
from langchain import LLMChain, PromptTemplate
from langchain.llms.fake import FakeListLLM

import giskard
from giskard import Dataset, Model
from giskard.core.core import SupportedModelTypes
from giskard.models.cache import ModelCache
from langchain import LLMChain, PromptTemplate


# https://symbl.cc/fr/unicode/blocks/

Expand All @@ -31,6 +32,7 @@ def test_unicode_prediction(keys, values):
with TemporaryDirectory() as temp_cache_dir:
cache = ModelCache(
model_type=SupportedModelTypes.TEXT_GENERATION,
persist_cache=True,
cache_dir=Path(temp_cache_dir),
)
key_series = pd.Series(keys)
Expand All @@ -43,6 +45,7 @@ def test_unicode_prediction(keys, values):
warmed_up_cache = ModelCache(
id="warmed_up",
model_type=SupportedModelTypes.TEXT_GENERATION,
persist_cache=True,
cache_dir=Path(temp_cache_dir),
)
# Ensure warm up works fine
Expand Down