diff --git a/giskard/core/savable.py b/giskard/core/savable.py index 25cb2b2a55..27f988782b 100644 --- a/giskard/core/savable.py +++ b/giskard/core/savable.py @@ -53,31 +53,17 @@ def _get_name(cls) -> str: return f"{cls.__class__.__name__.lower()}s" @classmethod - def _get_meta_endpoint(cls, uuid: str, project_key: Optional[str]) -> str: - if project_key is None: - return posixpath.join(cls._get_name(), uuid) - else: - return posixpath.join("project", project_key, cls._get_name(), uuid) + def _get_meta_endpoint(cls, uuid: str, project_key: str) -> str: + return posixpath.join("project", project_key, cls._get_name(), uuid) def _save_meta_locally(self, local_dir): with open(Path(local_dir) / "meta.yaml", "w") as f: yaml.dump(self.meta, f) - @classmethod - def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[SMT]: - file = Path(local_dir) / "meta.yaml" - if not file.exists(): - return None - - with open(file, "r") as f: - # PyYAML prohibits the arbitary execution so our class cannot be loaded safely, - # see: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation - return yaml.load(f, Loader=yaml.UnsafeLoader) - def upload( self, client: GiskardClient, - project_key: Optional[str] = None, + project_key: str, uploaded_dependencies: Optional[Set["Artifact"]] = None, ) -> str: """ @@ -114,14 +100,13 @@ def upload( return self.meta.uuid @classmethod - def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optional[str]) -> "Artifact": + def download(cls, uuid: str, client: GiskardClient, project_key: str) -> "Artifact": """ Downloads the artifact from the Giskard hub or retrieves it from the local cache. Args: uuid (str): The UUID of the artifact to download. - client (Optional[GiskardClient]): The Giskard client instance used for communication with the hub. If None, - the artifact will be retrieved from the local cache if available. Defaults to None. + client (GiskardClient): The Giskard client instance used for communication with the hub. project_key (Optional[str]): The project key where the artifact is located. If None, the artifact will be retrieved from the global scope. Defaults to None. @@ -135,11 +120,7 @@ def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optio name = cls._get_name() local_dir = settings.home_dir / settings.cache_dir / name / uuid - - if client is None: - meta = cls._load_meta_locally(local_dir, uuid) - else: - meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class()) + meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class()) assert meta is not None, "Could not retrieve test meta" diff --git a/giskard/core/suite.py b/giskard/core/suite.py index e20c3367f8..cbe6970c71 100644 --- a/giskard/core/suite.py +++ b/giskard/core/suite.py @@ -431,7 +431,7 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status: return SuiteTestDTO( id=self.suite_test_id, - testUuid=self.giskard_test.upload(client), + testUuid=self.giskard_test.upload(client, project_key), functionInputs=params, displayName=self.display_name, ) @@ -935,7 +935,7 @@ def download(cls, client: GiskardClient, project_key: str, suite_id: int) -> "Su suite.project_key = project_key for test_json in suite_dto.tests: - test = GiskardTest.download(test_json.testUuid, client, None) + test = GiskardTest.download(test_json.testUuid, client, project_key) test_arguments = parse_function_arguments(client, project_key, test_json.functionInputs.values()) suite.add_test(test(**test_arguments), suite_test_id=test_json.id) diff --git a/giskard/ml_worker/websocket/listener.py b/giskard/ml_worker/websocket/listener.py index 1332fac59b..29bf0c0a10 100644 --- a/giskard/ml_worker/websocket/listener.py +++ b/giskard/ml_worker/websocket/listener.py @@ -36,7 +36,6 @@ do_create_sub_dataset, do_run_adhoc_test, function_argument_to_ws, - log_artifact_local, map_dataset_process_function_meta_ws, map_function_meta_ws, map_result_to_single_test_result_ws, @@ -134,12 +133,12 @@ def parse_and_execute( action: MLWorkerAction, params, ml_worker: MLWorkerInfo, - client_params: Optional[Dict[str, str]], + client_params: Dict[str, str], ) -> websocket.WorkerReply: action_params = parse_action_param(action, params) return callback( ml_worker=ml_worker, - client=GiskardClient(**client_params) if client_params is not None else None, + client=GiskardClient(**client_params), action=action.name, params=action_params, ) @@ -314,7 +313,7 @@ def run_other_model(dataset, prediction_results, is_text_generation): @websocket_actor(MLWorkerAction.runModel) -def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty: +def run_model(client: GiskardClient, params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty: try: model = BaseModel.download(client, params.model.project_key, params.model.id) dataset = Dataset.download( @@ -349,35 +348,23 @@ def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam, tmp_dir = Path(f) predictions_csv = get_file_name("predictions", "csv", params.dataset.sample) results.to_csv(index=False, path_or_buf=tmp_dir / predictions_csv) - if client: - client.log_artifact( - tmp_dir / predictions_csv, - f"models/inspections/{params.inspectionId}", - ) - else: - log_artifact_local( - tmp_dir / predictions_csv, - f"models/inspections/{params.inspectionId}", - ) + client.log_artifact( + tmp_dir / predictions_csv, + f"models/inspections/{params.inspectionId}", + ) calculated_csv = get_file_name("calculated", "csv", params.dataset.sample) calculated.to_csv(index=False, path_or_buf=tmp_dir / calculated_csv) - if client: - client.log_artifact( - tmp_dir / calculated_csv, - f"models/inspections/{params.inspectionId}", - ) - else: - log_artifact_local( - tmp_dir / calculated_csv, - f"models/inspections/{params.inspectionId}", - ) + client.log_artifact( + tmp_dir / calculated_csv, + f"models/inspections/{params.inspectionId}", + ) return websocket.Empty() @websocket_actor(MLWorkerAction.runModelForDataFrame) def run_model_for_data_frame( - client: Optional[GiskardClient], params: websocket.RunModelForDataFrameParam, *args, **kwargs + client: GiskardClient, params: websocket.RunModelForDataFrameParam, *args, **kwargs ) -> websocket.RunModelForDataFrame: model = BaseModel.download(client, params.model.project_key, params.model.id) df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows]) @@ -407,7 +394,7 @@ def run_model_for_data_frame( @websocket_actor(MLWorkerAction.explain) -def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain: +def explain_ws(client: GiskardClient, params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain: model = BaseModel.download(client, params.model.project_key, params.model.id) dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample) explanations = explain(model, dataset, params.columns) @@ -419,7 +406,7 @@ def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam, @websocket_actor(MLWorkerAction.explainText) def explain_text_ws( - client: Optional[GiskardClient], params: websocket.ExplainTextParam, *args, **kwargs + client: GiskardClient, params: websocket.ExplainTextParam, *args, **kwargs ) -> websocket.ExplainText: model = BaseModel.download(client, params.model.project_key, params.model.id) text_column = params.feature_name @@ -460,7 +447,7 @@ def get_catalog(*args, **kwargs) -> websocket.Catalog: @websocket_actor(MLWorkerAction.datasetProcessing) def dataset_processing( - client: Optional[GiskardClient], params: websocket.DatasetProcessingParam, *args, **kwargs + client: GiskardClient, params: websocket.DatasetProcessingParam, *args, **kwargs ) -> websocket.DatasetProcessing: dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample) @@ -500,7 +487,7 @@ def dataset_processing( @websocket_actor(MLWorkerAction.runAdHocTest) def run_ad_hoc_test( - client: Optional[GiskardClient], params: websocket.RunAdHocTestParam, *args, **kwargs + client: GiskardClient, params: websocket.RunAdHocTestParam, *args, **kwargs ) -> websocket.RunAdHocTest: test: GiskardTest = GiskardTest.download(params.testUuid, client, params.projectKey) @@ -525,9 +512,7 @@ def run_ad_hoc_test( @websocket_actor(MLWorkerAction.runTestSuite) -def run_test_suite( - client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs -) -> websocket.TestSuite: +def run_test_suite(client: GiskardClient, params: websocket.TestSuiteParam, *args, **kwargs) -> websocket.TestSuite: loaded_artifacts = defaultdict(dict) try: @@ -594,7 +579,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoResponse: def handle_cta( - client: Optional[GiskardClient], + client: GiskardClient, params: websocket.GetPushParam, push: Optional[Push], push_kind: PushKind, @@ -635,9 +620,7 @@ def handle_cta( @websocket_actor(MLWorkerAction.getPush, timeout=30, ignore_timeout=True) -def get_push( - client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs -) -> websocket.GetPushResponse: +def get_push(client: GiskardClient, params: websocket.GetPushParam, *args, **kwargs) -> websocket.GetPushResponse: # Save cta_kind and push_kind and remove it from params cta_kind = params.cta_kind push_kind = params.push_kind @@ -690,7 +673,7 @@ def push_to_ws(push: Push): return push.to_ws() if push is not None else None -def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushParam): +def get_push_objects(client: GiskardClient, params: websocket.GetPushParam): try: model = BaseModel.download(client, params.model.project_key, params.model.id) dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id) @@ -735,7 +718,7 @@ def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushP @websocket_actor(MLWorkerAction.createSubDataset) def create_sub_dataset( - client: Optional[GiskardClient], params: websocket.CreateSubDatasetParam, *arg, **kwargs + client: GiskardClient, params: websocket.CreateSubDatasetParam, *arg, **kwargs ) -> websocket.CreateSubDataset: datasets = { dateset_id: Dataset.download( @@ -751,7 +734,7 @@ def create_sub_dataset( @websocket_actor(MLWorkerAction.createDataset) def create_dataset( - client: Optional[GiskardClient], params: websocket.CreateDatasetParam, *arg, **kwargs + client: GiskardClient, params: websocket.CreateDatasetParam, *arg, **kwargs ) -> websocket.CreateSubDataset: dataset = do_create_dataset(params.name, params.headers, params.rows) diff --git a/giskard/ml_worker/websocket/utils.py b/giskard/ml_worker/websocket/utils.py index d1c7734ca5..5769d1448f 100644 --- a/giskard/ml_worker/websocket/utils.py +++ b/giskard/ml_worker/websocket/utils.py @@ -1,13 +1,10 @@ from typing import Any, Callable, Dict, List, Optional import logging -import os -import shutil import uuid from collections import defaultdict import pandas as pd -from mlflow.store.artifact.artifact_repo import verify_artifact_path from giskard.client.giskard_client import GiskardClient from giskard.core.suite import DatasetInput, ModelInput, SuiteInput @@ -34,7 +31,6 @@ ) from giskard.ml_worker.websocket.action import MLWorkerAction from giskard.models.base import BaseModel -from giskard.path_utils import artifacts_dir from giskard.registry.registry import tests_registry from giskard.registry.slicing_function import SlicingFunction from giskard.registry.transformation_function import TransformationFunction @@ -126,21 +122,6 @@ def map_function_meta_ws(callable_type): } -def log_artifact_local(local_file, artifact_path=None): - # Log artifact locally from an internal worker - verify_artifact_path(artifact_path) - - file_name = os.path.basename(local_file) - - if artifact_path: - artifact_file = artifacts_dir / artifact_path / file_name - else: - artifact_file = artifacts_dir / file_name - artifact_file.parent.mkdir(parents=True, exist_ok=True) - - shutil.copy(local_file, artifact_file) - - def map_dataset_process_function_meta_ws(callable_type): return { test.uuid: websocket.DatasetProcessFunctionMeta( @@ -182,7 +163,7 @@ def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: s def parse_function_arguments( - client: Optional[GiskardClient], + client: GiskardClient, request_arguments: List[websocket.FuncArgument], loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None, ): @@ -245,7 +226,7 @@ def parse_function_arguments( def map_result_to_single_test_result_ws( result, datasets: Dict[uuid.UUID, Dataset], - client: Optional[GiskardClient] = None, + client: GiskardClient, project_key: Optional[str] = None, ) -> websocket.SingleTestResult: if isinstance(result, TestResult): @@ -302,9 +283,6 @@ def _upload_generated_output_df(client, datasets, project_key, result): ) if result.output_df.original_id not in datasets.keys(): - if not client: - raise RuntimeError("Legacy test debugging using `output_df` is not supported internal ML worker") - if not project_key: raise ValueError("Unable to upload debug dataset due to missing `project_key`") diff --git a/giskard/models/base/model.py b/giskard/models/base/model.py index 951c8d2b7c..7868dc5cbb 100644 --- a/giskard/models/base/model.py +++ b/giskard/models/base/model.py @@ -464,7 +464,7 @@ def upload(self, client: GiskardClient, project_key, validate_ds=None, *_args, * return str(self.id) @classmethod - def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args, **_kwargs): + def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs): """ Downloads the specified model from the Giskard hub and loads it into memory. @@ -480,29 +480,24 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args AssertionError: If the local directory where the model should be saved does not exist. """ local_dir = settings.home_dir / settings.cache_dir / "models" / model_id - if client is None: - # internal worker case, no token based http client [deprecated, to be removed] - assert local_dir.exists(), f"Cannot find existing model {project_key}.{model_id} in {local_dir}" - meta_response, meta = cls.read_meta_from_local_dir(local_dir) - else: - client.load_artifact(local_dir, posixpath.join("models", model_id)) - meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id) - # internal worker case, no token based http client - if not local_dir.exists(): - raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}") - with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f: - file_meta = yaml.load(f, Loader=yaml.Loader) - classification_labels = cls.cast_labels(meta_response) - meta = ModelMeta( - name=meta_response.name, - description=meta_response.description, - model_type=SupportedModelTypes[meta_response.modelType], - feature_names=meta_response.featureNames, - classification_labels=classification_labels, - classification_threshold=meta_response.threshold, - loader_module=file_meta["loader_module"], - loader_class=file_meta["loader_class"], - ) + client.load_artifact(local_dir, posixpath.join("models", model_id)) + meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id) + # internal worker case, no token based http client + if not local_dir.exists(): + raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}") + with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f: + file_meta = yaml.load(f, Loader=yaml.Loader) + classification_labels = cls.cast_labels(meta_response) + meta = ModelMeta( + name=meta_response.name, + description=meta_response.description, + model_type=SupportedModelTypes[meta_response.modelType], + feature_names=meta_response.featureNames, + classification_labels=classification_labels, + classification_threshold=meta_response.threshold, + loader_module=file_meta["loader_module"], + loader_class=file_meta["loader_class"], + ) model_py_ver = ( tuple(meta_response.languageVersion.split(".")) if "PYTHON" == meta_response.language.upper() else None diff --git a/tests/communications/test_websocket_actor.py b/tests/communications/test_websocket_actor.py index 9754cc59cf..8f88466d6a 100644 --- a/tests/communications/test_websocket_actor.py +++ b/tests/communications/test_websocket_actor.py @@ -1,4 +1,3 @@ -import shutil import time import uuid @@ -11,9 +10,7 @@ from giskard.ml_worker.websocket import listener from giskard.ml_worker.websocket.action import MLWorkerAction from giskard.models.base.model import BaseModel -from giskard.settings import settings from giskard.utils import call_in_pool, start_pool -from giskard.utils.file_utils import get_file_name from tests import utils NOT_USED_WEBSOCKET_ACTOR = [ @@ -92,44 +89,6 @@ def test_websocket_actor_get_catalog(): assert "giskard" in t.tags -@pytest.mark.parametrize( - "data,model,sample", - [ - ("enron_data", "enron_model", False), - ("enron_data", "enron_model", True), - ("enron_data", "enron_model", None), - ("hotel_text_data", "hotel_text_model", False), - ], -) -def test_websocket_actor_run_model_internal(data, model, sample, request): - dataset: Dataset = request.getfixturevalue(data) - model: BaseModel = request.getfixturevalue(model) - - project_key = str(uuid.uuid4()) # Use a UUID to separate the resources used by the tests - inspection_id = 0 - - with utils.MockedProjectCacheDir(): - # Prepare dataset and model - utils.local_save_model_under_giskard_home_cache(model) - utils.local_save_dataset_under_giskard_home_cache(dataset) - - params = websocket.RunModelParam( - model=websocket.ArtifactRef(project_key=project_key, id=str(model.id)), - dataset=websocket.ArtifactRef(project_key=project_key, id=str(dataset.id), sample=sample), - inspectionId=inspection_id, - project_key=project_key, - ) - # Internal worker does not have client - reply = listener.run_model(client=None, params=params) - assert isinstance(reply, websocket.Empty) - # Inspection are logged locally - inspection_path = settings.home_dir / "artifacts" / "models" / "inspections" / str(inspection_id) - assert (inspection_path / get_file_name("predictions", "csv", sample)).exists() - assert (inspection_path / get_file_name("calculated", "csv", sample)).exists() - # Clean up - shutil.rmtree(inspection_path, ignore_errors=True) - - @pytest.mark.parametrize( "data,model,sample", [ @@ -164,13 +123,11 @@ def test_websocket_actor_run_model(data, model, sample, request): # Prepare URL for inspections utils.register_uri_for_inspection(mr, project_key, inspection_id, sample) - # Internal worker does not have client reply = listener.run_model(client=client, params=params) assert isinstance(reply, websocket.Empty) -@pytest.mark.parametrize("internal", [True, False]) -def test_websocket_actor_run_model_for_data_frame_regression(internal, request): +def test_websocket_actor_run_model_for_data_frame_regression(request): dataset: Dataset = request.getfixturevalue("hotel_text_data") model: BaseModel = request.getfixturevalue("hotel_text_model") @@ -178,11 +135,8 @@ def test_websocket_actor_run_model_for_data_frame_regression(internal, request): with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare model - if internal: - utils.local_save_model_under_giskard_home_cache(model) - else: - utils.register_uri_for_model_meta_info(mr, model, project_key) - utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) + utils.register_uri_for_model_meta_info(mr, model, project_key) + utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) # Prepare dataframe dataframe = websocket.DataFrame( @@ -201,7 +155,7 @@ def test_websocket_actor_run_model_for_data_frame_regression(internal, request): ) # client - reply = listener.run_model_for_data_frame(client=None if internal else client, params=params) + reply = listener.run_model_for_data_frame(client=client, params=params) assert isinstance(reply, websocket.RunModelForDataFrame) assert not reply.all_predictions assert reply.prediction @@ -209,8 +163,7 @@ def test_websocket_actor_run_model_for_data_frame_regression(internal, request): assert not reply.probabilities -@pytest.mark.parametrize("internal", [True, False]) -def test_websocket_actor_run_model_for_data_frame_classification(internal, request): +def test_websocket_actor_run_model_for_data_frame_classification(request): dataset: Dataset = request.getfixturevalue("enron_data") model: BaseModel = request.getfixturevalue("enron_model") @@ -218,11 +171,8 @@ def test_websocket_actor_run_model_for_data_frame_classification(internal, reque with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare model - if internal: - utils.local_save_model_under_giskard_home_cache(model) - else: - utils.register_uri_for_model_meta_info(mr, model, project_key) - utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) + utils.register_uri_for_model_meta_info(mr, model, project_key) + utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) # Prepare dataframe dataframe = websocket.DataFrame( @@ -241,7 +191,7 @@ def test_websocket_actor_run_model_for_data_frame_classification(internal, reque ) # client - reply = listener.run_model_for_data_frame(client=None if internal else client, params=params) + reply = listener.run_model_for_data_frame(client=client, params=params) assert isinstance(reply, websocket.RunModelForDataFrame) assert reply.all_predictions assert reply.prediction @@ -249,27 +199,6 @@ def test_websocket_actor_run_model_for_data_frame_classification(internal, reque assert not reply.probabilities -@pytest.mark.parametrize("data,model", [("enron_data", "enron_model"), ("hotel_text_data", "hotel_text_model")]) -def test_websocket_actor_explain_ws_internal(data, model, request): - dataset: Dataset = request.getfixturevalue(data) - model: BaseModel = request.getfixturevalue(model) - - project_key = str(uuid.uuid4()) # Use a UUID to separate the resources used by the tests - - with utils.MockedProjectCacheDir(): - # Prepare model and dataset - utils.local_save_model_under_giskard_home_cache(model) - utils.local_save_dataset_under_giskard_home_cache(dataset) - - params = websocket.ExplainParam( - model=websocket.ArtifactRef(project_key=project_key, id=str(model.id)), - dataset=websocket.ArtifactRef(project_key=project_key, id=str(dataset.id)), - columns={str(k): str(v) for k, v in next(dataset.df.iterrows())[1].items()}, # Pick the first row - ) - reply = listener.explain_ws(client=None, params=params) - assert isinstance(reply, websocket.Explain) - - @pytest.mark.parametrize("data,model", [("enron_data", "enron_model"), ("hotel_text_data", "hotel_text_model")]) @pytest.mark.slow def test_websocket_actor_explain_ws(data, model, request): @@ -294,8 +223,7 @@ def test_websocket_actor_explain_ws(data, model, request): assert isinstance(reply, websocket.Explain) -@pytest.mark.parametrize("internal", [True, False]) -def test_websocket_actor_explain_text_ws_regression(internal, request): +def test_websocket_actor_explain_text_ws_regression(request): dataset: Dataset = request.getfixturevalue("hotel_text_data") model: BaseModel = request.getfixturevalue("hotel_text_model") @@ -303,11 +231,8 @@ def test_websocket_actor_explain_text_ws_regression(internal, request): with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare model and dataset - if internal: - utils.local_save_model_under_giskard_home_cache(model) - else: - utils.register_uri_for_model_meta_info(mr, model, project_key) - utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) + utils.register_uri_for_model_meta_info(mr, model, project_key) + utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) text_feature_name = None for col_name, col_type in dataset.column_types.items(): @@ -322,7 +247,7 @@ def test_websocket_actor_explain_text_ws_regression(internal, request): columns={str(k): str(v) for k, v in next(dataset.df.iterrows())[1].items()}, # Pick the first row column_types=dataset.column_types, ) - reply = listener.explain_text_ws(client=None if internal else client, params=params) + reply = listener.explain_text_ws(client=client, params=params) assert isinstance(reply, websocket.ExplainText) # Regression text explaining: Giskard Hub uses "WEIGHTS" to show it assert "WEIGHTS" in reply.weights.keys() @@ -336,7 +261,8 @@ def test_websocket_actor_explain_text_ws_not_text(request): with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare model and dataset - utils.local_save_model_under_giskard_home_cache(model) + utils.register_uri_for_model_meta_info(mr, model, project_key) + utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) not_text_feature_name = None for col_name, col_type in dataset.column_types.items(): @@ -352,11 +278,10 @@ def test_websocket_actor_explain_text_ws_not_text(request): column_types=dataset.column_types, ) with pytest.raises(ValueError, match="Column .* is not of type text"): - listener.explain_text_ws(client=None, params=params) + listener.explain_text_ws(client=client, params=params) -@pytest.mark.parametrize("internal", [True, False]) -def test_websocket_actor_explain_text_ws_classification(internal, request): +def test_websocket_actor_explain_text_ws_classification(request): dataset: Dataset = request.getfixturevalue("enron_data") model: BaseModel = request.getfixturevalue("enron_model") @@ -364,11 +289,8 @@ def test_websocket_actor_explain_text_ws_classification(internal, request): with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare model and dataset - if internal: - utils.local_save_model_under_giskard_home_cache(model) - else: - utils.register_uri_for_model_meta_info(mr, model, project_key) - utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) + utils.register_uri_for_model_meta_info(mr, model, project_key) + utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True) text_feature_name = None for col_name, col_type in dataset.column_types.items(): @@ -383,26 +305,22 @@ def test_websocket_actor_explain_text_ws_classification(internal, request): columns={str(k): str(v) for k, v in next(dataset.df.iterrows())[1].items()}, # Pick the first row column_types=dataset.column_types, ) - reply = listener.explain_text_ws(client=None if internal else client, params=params) + reply = listener.explain_text_ws(client=client, params=params) assert isinstance(reply, websocket.ExplainText) # Classification labels for label in model.classification_labels: assert label in reply.weights.keys() -@pytest.mark.parametrize("internal", [True, False]) -def test_websocket_actor_dataset_processing_empty(internal, request): +def test_websocket_actor_dataset_processing_empty(request): dataset: Dataset = request.getfixturevalue("enron_data") project_key = str(uuid.uuid4()) # Use a UUID to separate the resources used by the tests with utils.MockedProjectCacheDir(), utils.MockedClient(mock_all=False) as (client, mr): # Prepare dataset - if internal: - utils.local_save_dataset_under_giskard_home_cache(dataset) - else: - utils.register_uri_for_dataset_meta_info(mr, dataset, project_key) - utils.register_uri_for_dataset_artifact_info(mr, dataset, project_key, register_file_contents=True) + utils.register_uri_for_dataset_meta_info(mr, dataset, project_key) + utils.register_uri_for_dataset_artifact_info(mr, dataset, project_key, register_file_contents=True) # FIXME: functions can be None from the protocol, but not iterable # params = websocket.DatasetProcessingParam( @@ -413,7 +331,7 @@ def test_websocket_actor_dataset_processing_empty(internal, request): dataset=websocket.ArtifactRef(project_key=project_key, id=str(dataset.id), sample=False), functions=[], ) - reply = listener.dataset_processing(client=None if internal else client, params=params) + reply = listener.dataset_processing(client=client, params=params) assert isinstance(reply, websocket.DatasetProcessing) assert reply.datasetId == str(dataset.id) assert reply.totalRows == len(list(dataset.df.index)) @@ -429,13 +347,10 @@ def head_slice(df: pd.DataFrame) -> pd.DataFrame: return df.head(1) -# FIXME: Internal worker cannot yet load callable due to yaml deserialization issue -@pytest.mark.parametrize("callable_under_project", [False, True]) -def test_websocket_actor_dataset_processing_head_slicing_with_cache(callable_under_project, request): +def test_websocket_actor_dataset_processing_head_slicing_with_cache(request): dataset: Dataset = request.getfixturevalue("enron_data") project_key = str(uuid.uuid4()) # Use a UUID to separate the resources used by the tests - callable_function_project_key = project_key if callable_under_project else None with utils.MockedProjectCacheDir(): # Prepare dataset @@ -448,7 +363,7 @@ def test_websocket_actor_dataset_processing_head_slicing_with_cache(callable_und functions=[ websocket.DatasetProcessingFunction( slicingFunction=websocket.ArtifactRef( - project_key=callable_function_project_key, + project_key=project_key, id=head_slice.meta.uuid, ) ) @@ -458,8 +373,8 @@ def test_websocket_actor_dataset_processing_head_slicing_with_cache(callable_und # Prepare URL for meta info cf = head_slice # The slicing function will be loaded from the current module, without further requests - utils.register_uri_for_artifact_meta_info(mr, cf, project_key=callable_function_project_key) - utils.register_uri_for_artifact_info(mr, cf, project_key=callable_function_project_key) + utils.register_uri_for_artifact_meta_info(mr, cf, project_key=project_key) + utils.register_uri_for_artifact_info(mr, cf, project_key=project_key) # The dataset can be then loaded from the cache, without further requests utils.register_uri_for_dataset_meta_info(mr, dataset, project_key) @@ -480,9 +395,7 @@ def do_nothing(row): return row -# FIXME: Internal worker cannot yet load callable due to yaml deserialization issue -@pytest.mark.parametrize("callable_under_project", [False, True]) -def test_websocket_actor_dataset_processing_do_nothing_transform_with_cache(callable_under_project, request): +def test_websocket_actor_dataset_processing_do_nothing_transform_with_cache(request): dataset: Dataset = request.getfixturevalue("enron_data") project_key = str(uuid.uuid4()) # Use a UUID to separate the resources used by the tests @@ -490,7 +403,6 @@ def test_websocket_actor_dataset_processing_do_nothing_transform_with_cache(call with utils.MockedProjectCacheDir(): # Prepare dataset utils.local_save_dataset_under_giskard_home_cache(dataset) - callable_function_project_key = project_key if callable_under_project else None do_nothing.meta.uuid = str(uuid.uuid4()) @@ -499,7 +411,7 @@ def test_websocket_actor_dataset_processing_do_nothing_transform_with_cache(call functions=[ websocket.DatasetProcessingFunction( transformationFunction=websocket.ArtifactRef( - project_key=callable_function_project_key, + project_key=project_key, id=do_nothing.meta.uuid, ) ) @@ -509,8 +421,8 @@ def test_websocket_actor_dataset_processing_do_nothing_transform_with_cache(call # Prepare URL for meta info cf = do_nothing # The slicing function will be loaded from the current module, without further requests - utils.register_uri_for_artifact_meta_info(mr, cf, project_key=callable_function_project_key) - utils.register_uri_for_artifact_info(mr, cf, project_key=callable_function_project_key) + utils.register_uri_for_artifact_meta_info(mr, cf, project_key=project_key) + utils.register_uri_for_artifact_info(mr, cf, project_key=project_key) # The dataset can be then loaded from the cache, without further requests utils.register_uri_for_dataset_meta_info(mr, dataset, project_key) diff --git a/tests/communications/test_websocket_actor_tests.py b/tests/communications/test_websocket_actor_tests.py index 54b8b9e116..79287eb59d 100644 --- a/tests/communications/test_websocket_actor_tests.py +++ b/tests/communications/test_websocket_actor_tests.py @@ -40,12 +40,10 @@ def my_simple_test_error(): def test_websocket_actor_run_ad_hoc_test_no_debug(debug): with utils.MockedProjectCacheDir(): params = websocket.RunAdHocTestParam( - testUuid=my_simple_test.meta.uuid, - arguments=[], - debug=debug, + testUuid=my_simple_test.meta.uuid, arguments=[], debug=debug, projectKey="project_key" ) with utils.MockedClient(mock_all=False) as (client, mr): - utils.register_uri_for_artifact_meta_info(mr, my_simple_test, None) + utils.register_uri_for_artifact_meta_info(mr, my_simple_test, "project_key") reply = listener.run_ad_hoc_test(client=client, params=params) assert isinstance(reply, websocket.RunAdHocTest) @@ -106,59 +104,6 @@ def test_websocket_actor_run_ad_hoc_test_legacy_debug(enron_data: Dataset): assert reply.results[0].result.failed_indexes -def test_websocket_actor_run_ad_hoc_test_legacy_no_client(enron_data: Dataset): - project_key = str(uuid.uuid4()) - - with utils.MockedProjectCacheDir(): - utils.local_save_dataset_under_giskard_home_cache(enron_data) - - params = websocket.RunAdHocTestParam( - testUuid=my_simple_test_legacy_debug.meta.uuid, - arguments=[ - websocket.FuncArgument( - name="dataset", - none=False, - dataset=websocket.ArtifactRef( - project_key=project_key, - id=str(enron_data.id), - ), - ), - ], - debug=True, - ) - - with pytest.raises(RuntimeError): - listener.run_ad_hoc_test(client=None, params=params) - - -def test_websocket_actor_run_ad_hoc_test_legacy_no_project_key(enron_data: Dataset): - project_key = str(uuid.uuid4()) - - with utils.MockedProjectCacheDir(): - utils.local_save_dataset_under_giskard_home_cache(enron_data) - - params = websocket.RunAdHocTestParam( - testUuid=my_simple_test_legacy_debug.meta.uuid, - arguments=[ - websocket.FuncArgument( - name="dataset", - none=False, - dataset=websocket.ArtifactRef( - project_key=project_key, - id=str(enron_data.id), - ), - ), - ], - debug=True, - ) - - with utils.MockedClient(mock_all=False) as (client, mr), pytest.raises(ValueError): - utils.register_uri_for_artifact_meta_info(mr, my_simple_test_legacy_debug, None) - utils.register_uri_for_dataset_meta_info(mr, enron_data, project_key) - - listener.run_ad_hoc_test(client=client, params=params) - - @test def my_simple_test_debug(dataset: Dataset, debug: bool = False): return GiskardTestResult(passed=False, output_ds=[dataset.slice(lambda df: df.head(1), row_level=False)]) @@ -183,9 +128,10 @@ def test_websocket_actor_run_ad_hoc_test_debug(enron_data: Dataset): ), ], debug=True, + projectKey=project_key, ) with utils.MockedClient(mock_all=False) as (client, mr): - utils.register_uri_for_artifact_meta_info(mr, my_simple_test_debug, None) + utils.register_uri_for_artifact_meta_info(mr, my_simple_test_debug, project_key) utils.register_uri_for_dataset_meta_info(mr, enron_data, project_key) utils.register_uri_for_any_dataset_artifact_info_upload(mr, register_files=True) @@ -246,9 +192,10 @@ def test_websocket_actor_run_ad_hoc_test_debug_multiple_datasets(enron_data: Dat ), ], debug=True, + projectKey=project_key, ) with utils.MockedClient(mock_all=False) as (client, mr): - utils.register_uri_for_artifact_meta_info(mr, my_simple_test_debug_multiple_datasets, None) + utils.register_uri_for_artifact_meta_info(mr, my_simple_test_debug_multiple_datasets, project_key) utils.register_uri_for_dataset_meta_info(mr, enron_data, project_key) utils.register_uri_for_dataset_meta_info(mr, dataset2, project_key) utils.register_uri_for_any_dataset_artifact_info_upload(mr, register_files=True) diff --git a/tests/core/test_suite.py b/tests/core/test_suite.py index ae844bfe11..91359e037f 100644 --- a/tests/core/test_suite.py +++ b/tests/core/test_suite.py @@ -21,6 +21,7 @@ def my_test(model: BaseModel): def test_save_suite_with_artifact_error(): + project_key = "project_key" model = FailingModel(model_type="regression") regex_model_name = str(model).replace("(", "\\(").replace(")", "\\)") @@ -28,13 +29,13 @@ def test_save_suite_with_artifact_error(): UserWarning, match=f"Failed to upload {regex_model_name} used in the test suite. The test suite will be partially uploaded.", ): - utils.register_uri_for_artifact_meta_info(mr, my_test, None) + utils.register_uri_for_artifact_meta_info(mr, my_test, project_key) mr.register_uri( method=requests_mock.POST, - url="http://giskard-host:12345/api/v2/testing/project/titanic/suites", + url="http://giskard-host:12345/api/v2/testing/project/project_key/suites", json={"id": 1, "tests": [{"id": 2}]}, ) test_suite = Suite().add_test(my_test, model=model) - test_suite.upload(client, "titanic") + test_suite.upload(client, project_key) diff --git a/tests/test_artifact_download.py b/tests/test_artifact_download.py index aa555d6f53..aed2f3070c 100644 --- a/tests/test_artifact_download.py +++ b/tests/test_artifact_download.py @@ -26,6 +26,7 @@ ) BASE_CLIENT_URL = "http://giskard-host:12345/api/v2" +PROJECT_KEY = "project_key" # Define a test function @@ -46,39 +47,6 @@ def do_nothing(row): return row -def test_download_global_test_function_from_registry(): - cf: Artifact = my_custom_test - - # Load from registry using uuid without client - download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=None, project_key=None) - - # Check the downloaded info - assert download_cf.__class__ is cf.__class__ - assert download_cf.meta.uuid == cf.meta.uuid - - -@pytest.mark.parametrize( - "cf", - [ - my_custom_test, # Test - head_slice, # Slice - do_nothing, # Transformation - ], -) -def test_download_global_test_function_from_local(cf): - with MockedProjectCacheDir(): - cf.meta.uuid = str(uuid.uuid4()) # Regenerate a UUID to ensure not loading from registry - - local_save_artifact_under_giskard_home_cache(cf) - - # Load from registry using uuid without client - download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=None, project_key=None) - - # Check the downloaded info - assert download_cf.__class__ is cf.__class__ - assert download_cf.meta.uuid == cf.meta.uuid - - @pytest.mark.parametrize( "cf", [ @@ -109,12 +77,12 @@ def test_download_callable_function(cf: Artifact): "name": f"fake_{cf._get_name()}", } ) - url = get_url_for_artifact_meta_info(cf, project_key=None) + url = get_url_for_artifact_meta_info(cf, PROJECT_KEY) mr.register_uri(method=requests_mock.GET, url=url, json=meta_info) requested_urls.append(url) # Register for Artifact info - requested_urls.extend(register_uri_for_artifact_info(mr, cf, project_key=None)) + requested_urls.extend(register_uri_for_artifact_info(mr, cf, PROJECT_KEY)) # Register for Artifacts content artifacts_base_url = get_url_for_artifacts_base(cf) @@ -128,7 +96,7 @@ def test_download_callable_function(cf: Artifact): requested_urls.append(posixpath.join(artifacts_base_url, file)) # Download: should not call load_artifact to request and download - download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=None) + download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=PROJECT_KEY) for requested_url in requested_urls: assert is_url_requested(mr.request_history, requested_url) @@ -149,17 +117,17 @@ def test_download_callable_function(cf: Artifact): do_nothing, # Transformation ], ) -def test_download_global_callable_function_from_module(cf: Artifact): +def test_download_callable_function_from_module(cf: Artifact): with MockedClient(mock_all=False) as (client, mr): cf.meta.uuid = str(uuid.uuid4()) # Regenerate a UUID to ensure not loading from registry cache_dir = get_local_cache_callable_artifact(artifact=cf) requested_urls = [] - # Prepare global URL - requested_urls.extend(register_uri_for_artifact_meta_info(mr, cf, project_key=None)) + # Prepare URL + requested_urls.extend(register_uri_for_artifact_meta_info(mr, cf, project_key=PROJECT_KEY)) # Download: should not call load_artifact to request and download - download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=None) + download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=PROJECT_KEY) for requested_url in requested_urls: assert is_url_requested(mr.request_history, requested_url) @@ -180,7 +148,7 @@ def test_download_global_callable_function_from_module(cf: Artifact): do_nothing, # Transformation ], ) -def test_download_global_callable_function_from_cache(cf: Artifact): +def test_download_callable_function_from_cache(cf: Artifact): with MockedClient(mock_all=False) as (client, mr): cf.meta.uuid = str(uuid.uuid4()) # Regenerate a UUID cache_dir = get_local_cache_callable_artifact(artifact=cf) @@ -190,10 +158,10 @@ def test_download_global_callable_function_from_cache(cf: Artifact): assert (cache_dir / CALLABLE_FUNCTION_PKL_CACHE).exists() assert (cache_dir / CALLABLE_FUNCTION_META_CACHE).exists() - requested_urls = register_uri_for_artifact_meta_info(mr, cf, project_key=None) + requested_urls = register_uri_for_artifact_meta_info(mr, cf, project_key=PROJECT_KEY) # Download: should not call load_artifact to request and download - download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=None) + download_cf = cf.__class__.download(uuid=cf.meta.uuid, client=client, project_key=PROJECT_KEY) for requested_url in requested_urls: assert is_url_requested(mr.request_history, requested_url) @@ -281,7 +249,7 @@ def test_download_callable_function_from_module_in_project(cf: Artifact): cache_dir = get_local_cache_callable_artifact(artifact=cf) requested_urls = [] - # Prepare global URL + # Prepare URL requested_urls.extend(register_uri_for_artifact_meta_info(mr, cf, project_key)) # Download: should not call load_artifact to request and download diff --git a/tests/test_programmatic.py b/tests/test_programmatic.py index 95f9354edc..f0e9b877f3 100644 --- a/tests/test_programmatic.py +++ b/tests/test_programmatic.py @@ -434,7 +434,7 @@ def test_download_suite_run_and_upload_results(): "projectKey": "test_project", }, ) - register_uri_for_artifact_meta_info(mr, test_auc) + register_uri_for_artifact_meta_info(mr, test_auc, "test_project") mr.register_uri(requests_mock.GET, UPLOAD_RESULTS_URL, json={}) diff --git a/tests/test_upload.py b/tests/test_upload.py index 75c91f61c2..6f401292c5 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -17,6 +17,7 @@ ) model_name = "uploaded model" +PROJECT_KEY = "project_key" def test_upload_df(diabetes_dataset: Dataset, diabetes_dataset_with_target: Dataset): @@ -151,7 +152,7 @@ def test_upload_callable_function(cf: Artifact): + "/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/.*" ) with MockedClient() as (client, mr): - cf.upload(client=client, project_key=None) + cf.upload(client=client, project_key=PROJECT_KEY) # Check local cache cache_dir = get_local_cache_callable_artifact(artifact=cf) assert (cache_dir / CALLABLE_FUNCTION_PKL_CACHE).exists() diff --git a/tests/utils.py b/tests/utils.py index d4137bd1eb..2095e351dc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,12 +264,8 @@ def mock_model_meta_info(model: BaseModel, project_key: str): return model_meta_info.dict() -def get_url_for_artifact_meta_info(cf: Artifact, project_key: Optional[str] = None): - return ( - posixpath.join(CLIENT_BASE_URL, "project", project_key, cf._get_name(), cf.meta.uuid) - if project_key - else posixpath.join(CLIENT_BASE_URL, cf._get_name(), cf.meta.uuid) - ) +def get_url_for_artifact_meta_info(cf: Artifact, project_key: str): + return posixpath.join(CLIENT_BASE_URL, "project", project_key, cf._get_name(), cf.meta.uuid) def get_url_for_artifacts_base(cf: Artifact): @@ -284,7 +280,7 @@ def get_url_for_model(model: BaseModel, project_key: str): return posixpath.join(CLIENT_BASE_URL, "project", project_key, "models", str(model.id)) -def register_uri_for_artifact_meta_info(mr: requests_mock.Mocker, cf: Artifact, project_key: Optional[str] = None): +def register_uri_for_artifact_meta_info(mr: requests_mock.Mocker, cf: Artifact, project_key: str): url = get_url_for_artifact_meta_info(cf, project_key) # Fixup the differences from Backend meta_info = fixup_mocked_artifact_meta_version(cf.meta.to_json())