Skip to content
31 changes: 6 additions & 25 deletions giskard/core/savable.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,17 @@ def _get_name(cls) -> str:
return f"{cls.__class__.__name__.lower()}s"

@classmethod
def _get_meta_endpoint(cls, uuid: str, project_key: Optional[str]) -> str:
if project_key is None:
return posixpath.join(cls._get_name(), uuid)
else:
return posixpath.join("project", project_key, cls._get_name(), uuid)
def _get_meta_endpoint(cls, uuid: str, project_key: str) -> str:
return posixpath.join("project", project_key, cls._get_name(), uuid)

def _save_meta_locally(self, local_dir):
with open(Path(local_dir) / "meta.yaml", "w") as f:
yaml.dump(self.meta, f)

@classmethod
def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[SMT]:
file = Path(local_dir) / "meta.yaml"
if not file.exists():
return None

with open(file, "r") as f:
# PyYAML prohibits the arbitary execution so our class cannot be loaded safely,
# see: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
return yaml.load(f, Loader=yaml.UnsafeLoader)

def upload(
self,
client: GiskardClient,
project_key: Optional[str] = None,
project_key: str,
uploaded_dependencies: Optional[Set["Artifact"]] = None,
) -> str:
"""
Expand Down Expand Up @@ -114,14 +100,13 @@ def upload(
return self.meta.uuid

@classmethod
def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optional[str]) -> "Artifact":
def download(cls, uuid: str, client: GiskardClient, project_key: str) -> "Artifact":
"""
Downloads the artifact from the Giskard hub or retrieves it from the local cache.

Args:
uuid (str): The UUID of the artifact to download.
client (Optional[GiskardClient]): The Giskard client instance used for communication with the hub. If None,
the artifact will be retrieved from the local cache if available. Defaults to None.
client (GiskardClient): The Giskard client instance used for communication with the hub.
project_key (Optional[str]): The project key where the artifact is located. If None, the artifact will be
retrieved from the global scope. Defaults to None.

Expand All @@ -135,11 +120,7 @@ def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optio
name = cls._get_name()

local_dir = settings.home_dir / settings.cache_dir / name / uuid

if client is None:
meta = cls._load_meta_locally(local_dir, uuid)
else:
meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class())
meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class())

assert meta is not None, "Could not retrieve test meta"

Expand Down
4 changes: 2 additions & 2 deletions giskard/core/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status:

return SuiteTestDTO(
id=self.suite_test_id,
testUuid=self.giskard_test.upload(client),
testUuid=self.giskard_test.upload(client, project_key),
functionInputs=params,
displayName=self.display_name,
)
Expand Down Expand Up @@ -935,7 +935,7 @@ def download(cls, client: GiskardClient, project_key: str, suite_id: int) -> "Su
suite.project_key = project_key

for test_json in suite_dto.tests:
test = GiskardTest.download(test_json.testUuid, client, None)
test = GiskardTest.download(test_json.testUuid, client, project_key)
test_arguments = parse_function_arguments(client, project_key, test_json.functionInputs.values())
suite.add_test(test(**test_arguments), suite_test_id=test_json.id)

Expand Down
61 changes: 22 additions & 39 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
do_create_sub_dataset,
do_run_adhoc_test,
function_argument_to_ws,
log_artifact_local,
map_dataset_process_function_meta_ws,
map_function_meta_ws,
map_result_to_single_test_result_ws,
Expand Down Expand Up @@ -134,12 +133,12 @@ def parse_and_execute(
action: MLWorkerAction,
params,
ml_worker: MLWorkerInfo,
client_params: Optional[Dict[str, str]],
client_params: Dict[str, str],
) -> websocket.WorkerReply:
action_params = parse_action_param(action, params)
return callback(
ml_worker=ml_worker,
client=GiskardClient(**client_params) if client_params is not None else None,
client=GiskardClient(**client_params),
action=action.name,
params=action_params,
)
Expand Down Expand Up @@ -314,7 +313,7 @@ def run_other_model(dataset, prediction_results, is_text_generation):


@websocket_actor(MLWorkerAction.runModel)
def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty:
def run_model(client: GiskardClient, params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty:
try:
model = BaseModel.download(client, params.model.project_key, params.model.id)
dataset = Dataset.download(
Expand Down Expand Up @@ -349,35 +348,23 @@ def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam,
tmp_dir = Path(f)
predictions_csv = get_file_name("predictions", "csv", params.dataset.sample)
results.to_csv(index=False, path_or_buf=tmp_dir / predictions_csv)
if client:
client.log_artifact(
tmp_dir / predictions_csv,
f"models/inspections/{params.inspectionId}",
)
else:
log_artifact_local(
tmp_dir / predictions_csv,
f"models/inspections/{params.inspectionId}",
)
client.log_artifact(
tmp_dir / predictions_csv,
f"models/inspections/{params.inspectionId}",
)

calculated_csv = get_file_name("calculated", "csv", params.dataset.sample)
calculated.to_csv(index=False, path_or_buf=tmp_dir / calculated_csv)
if client:
client.log_artifact(
tmp_dir / calculated_csv,
f"models/inspections/{params.inspectionId}",
)
else:
log_artifact_local(
tmp_dir / calculated_csv,
f"models/inspections/{params.inspectionId}",
)
client.log_artifact(
tmp_dir / calculated_csv,
f"models/inspections/{params.inspectionId}",
)
return websocket.Empty()


@websocket_actor(MLWorkerAction.runModelForDataFrame)
def run_model_for_data_frame(
client: Optional[GiskardClient], params: websocket.RunModelForDataFrameParam, *args, **kwargs
client: GiskardClient, params: websocket.RunModelForDataFrameParam, *args, **kwargs
) -> websocket.RunModelForDataFrame:
model = BaseModel.download(client, params.model.project_key, params.model.id)
df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows])
Expand Down Expand Up @@ -407,7 +394,7 @@ def run_model_for_data_frame(


@websocket_actor(MLWorkerAction.explain)
def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain:
def explain_ws(client: GiskardClient, params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain:
model = BaseModel.download(client, params.model.project_key, params.model.id)
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample)
explanations = explain(model, dataset, params.columns)
Expand All @@ -419,7 +406,7 @@ def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam,

@websocket_actor(MLWorkerAction.explainText)
def explain_text_ws(
client: Optional[GiskardClient], params: websocket.ExplainTextParam, *args, **kwargs
client: GiskardClient, params: websocket.ExplainTextParam, *args, **kwargs
) -> websocket.ExplainText:
model = BaseModel.download(client, params.model.project_key, params.model.id)
text_column = params.feature_name
Expand Down Expand Up @@ -460,7 +447,7 @@ def get_catalog(*args, **kwargs) -> websocket.Catalog:

@websocket_actor(MLWorkerAction.datasetProcessing)
def dataset_processing(
client: Optional[GiskardClient], params: websocket.DatasetProcessingParam, *args, **kwargs
client: GiskardClient, params: websocket.DatasetProcessingParam, *args, **kwargs
) -> websocket.DatasetProcessing:
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample)

Expand Down Expand Up @@ -500,7 +487,7 @@ def dataset_processing(

@websocket_actor(MLWorkerAction.runAdHocTest)
def run_ad_hoc_test(
client: Optional[GiskardClient], params: websocket.RunAdHocTestParam, *args, **kwargs
client: GiskardClient, params: websocket.RunAdHocTestParam, *args, **kwargs
) -> websocket.RunAdHocTest:
test: GiskardTest = GiskardTest.download(params.testUuid, client, params.projectKey)

Expand All @@ -525,9 +512,7 @@ def run_ad_hoc_test(


@websocket_actor(MLWorkerAction.runTestSuite)
def run_test_suite(
client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs
) -> websocket.TestSuite:
def run_test_suite(client: GiskardClient, params: websocket.TestSuiteParam, *args, **kwargs) -> websocket.TestSuite:
loaded_artifacts = defaultdict(dict)

try:
Expand Down Expand Up @@ -594,7 +579,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoResponse:


def handle_cta(
client: Optional[GiskardClient],
client: GiskardClient,
params: websocket.GetPushParam,
push: Optional[Push],
push_kind: PushKind,
Expand Down Expand Up @@ -635,9 +620,7 @@ def handle_cta(


@websocket_actor(MLWorkerAction.getPush, timeout=30, ignore_timeout=True)
def get_push(
client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs
) -> websocket.GetPushResponse:
def get_push(client: GiskardClient, params: websocket.GetPushParam, *args, **kwargs) -> websocket.GetPushResponse:
# Save cta_kind and push_kind and remove it from params
cta_kind = params.cta_kind
push_kind = params.push_kind
Expand Down Expand Up @@ -690,7 +673,7 @@ def push_to_ws(push: Push):
return push.to_ws() if push is not None else None


def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushParam):
def get_push_objects(client: GiskardClient, params: websocket.GetPushParam):
try:
model = BaseModel.download(client, params.model.project_key, params.model.id)
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id)
Expand Down Expand Up @@ -735,7 +718,7 @@ def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushP

@websocket_actor(MLWorkerAction.createSubDataset)
def create_sub_dataset(
client: Optional[GiskardClient], params: websocket.CreateSubDatasetParam, *arg, **kwargs
client: GiskardClient, params: websocket.CreateSubDatasetParam, *arg, **kwargs
) -> websocket.CreateSubDataset:
datasets = {
dateset_id: Dataset.download(
Expand All @@ -751,7 +734,7 @@ def create_sub_dataset(

@websocket_actor(MLWorkerAction.createDataset)
def create_dataset(
client: Optional[GiskardClient], params: websocket.CreateDatasetParam, *arg, **kwargs
client: GiskardClient, params: websocket.CreateDatasetParam, *arg, **kwargs
) -> websocket.CreateSubDataset:
dataset = do_create_dataset(params.name, params.headers, params.rows)

Expand Down
26 changes: 2 additions & 24 deletions giskard/ml_worker/websocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from typing import Any, Callable, Dict, List, Optional

import logging
import os
import shutil
import uuid
from collections import defaultdict

import pandas as pd
from mlflow.store.artifact.artifact_repo import verify_artifact_path

from giskard.client.giskard_client import GiskardClient
from giskard.core.suite import DatasetInput, ModelInput, SuiteInput
Expand All @@ -34,7 +31,6 @@
)
from giskard.ml_worker.websocket.action import MLWorkerAction
from giskard.models.base import BaseModel
from giskard.path_utils import artifacts_dir
from giskard.registry.registry import tests_registry
from giskard.registry.slicing_function import SlicingFunction
from giskard.registry.transformation_function import TransformationFunction
Expand Down Expand Up @@ -126,21 +122,6 @@ def map_function_meta_ws(callable_type):
}


def log_artifact_local(local_file, artifact_path=None):
# Log artifact locally from an internal worker
verify_artifact_path(artifact_path)

file_name = os.path.basename(local_file)

if artifact_path:
artifact_file = artifacts_dir / artifact_path / file_name
else:
artifact_file = artifacts_dir / file_name
artifact_file.parent.mkdir(parents=True, exist_ok=True)

shutil.copy(local_file, artifact_file)


def map_dataset_process_function_meta_ws(callable_type):
return {
test.uuid: websocket.DatasetProcessFunctionMeta(
Expand Down Expand Up @@ -182,7 +163,7 @@ def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: s


def parse_function_arguments(
client: Optional[GiskardClient],
client: GiskardClient,
request_arguments: List[websocket.FuncArgument],
loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None,
):
Expand Down Expand Up @@ -245,7 +226,7 @@ def parse_function_arguments(
def map_result_to_single_test_result_ws(
result,
datasets: Dict[uuid.UUID, Dataset],
client: Optional[GiskardClient] = None,
client: GiskardClient,
project_key: Optional[str] = None,
) -> websocket.SingleTestResult:
if isinstance(result, TestResult):
Expand Down Expand Up @@ -302,9 +283,6 @@ def _upload_generated_output_df(client, datasets, project_key, result):
)

if result.output_df.original_id not in datasets.keys():
if not client:
raise RuntimeError("Legacy test debugging using `output_df` is not supported internal ML worker")

if not project_key:
raise ValueError("Unable to upload debug dataset due to missing `project_key`")

Expand Down
43 changes: 19 additions & 24 deletions giskard/models/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def upload(self, client: GiskardClient, project_key, validate_ds=None, *_args, *
return str(self.id)

@classmethod
def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args, **_kwargs):
def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):
"""
Downloads the specified model from the Giskard hub and loads it into memory.

Expand All @@ -480,29 +480,24 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args
AssertionError: If the local directory where the model should be saved does not exist.
"""
local_dir = settings.home_dir / settings.cache_dir / "models" / model_id
if client is None:
# internal worker case, no token based http client [deprecated, to be removed]
assert local_dir.exists(), f"Cannot find existing model {project_key}.{model_id} in {local_dir}"
meta_response, meta = cls.read_meta_from_local_dir(local_dir)
else:
client.load_artifact(local_dir, posixpath.join("models", model_id))
meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)
# internal worker case, no token based http client
if not local_dir.exists():
raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}")
with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f:
file_meta = yaml.load(f, Loader=yaml.Loader)
classification_labels = cls.cast_labels(meta_response)
meta = ModelMeta(
name=meta_response.name,
description=meta_response.description,
model_type=SupportedModelTypes[meta_response.modelType],
feature_names=meta_response.featureNames,
classification_labels=classification_labels,
classification_threshold=meta_response.threshold,
loader_module=file_meta["loader_module"],
loader_class=file_meta["loader_class"],
)
client.load_artifact(local_dir, posixpath.join("models", model_id))
meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)
# internal worker case, no token based http client
if not local_dir.exists():
raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}")
with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f:
file_meta = yaml.load(f, Loader=yaml.Loader)
classification_labels = cls.cast_labels(meta_response)
meta = ModelMeta(
name=meta_response.name,
description=meta_response.description,
model_type=SupportedModelTypes[meta_response.modelType],
feature_names=meta_response.featureNames,
classification_labels=classification_labels,
classification_threshold=meta_response.threshold,
loader_module=file_meta["loader_module"],
loader_class=file_meta["loader_class"],
)

model_py_ver = (
tuple(meta_response.languageVersion.split(".")) if "PYTHON" == meta_response.language.upper() else None
Expand Down
Loading