diff --git a/giskard/ml_worker/websocket/__init__.py b/giskard/ml_worker/websocket/__init__.py index d767e944ce..65fcc338c3 100644 --- a/giskard/ml_worker/websocket/__init__.py +++ b/giskard/ml_worker/websocket/__init__.py @@ -430,5 +430,12 @@ class CreateSubDatasetParam(ConfiguredBaseModel): copiedRows: Dict[str, List[int]] +class CreateDatasetParam(ConfiguredBaseModel): + projectKey: str + name: str + headers: List[str] + rows: List[List[str]] + + class CreateSubDataset(WorkerReply): datasetUuid: str diff --git a/giskard/ml_worker/websocket/action.py b/giskard/ml_worker/websocket/action.py index a9741a9320..15fd5a52d6 100644 --- a/giskard/ml_worker/websocket/action.py +++ b/giskard/ml_worker/websocket/action.py @@ -30,6 +30,7 @@ class MLWorkerAction(Enum): createSubDataset = 14 abort = 15 getLogs = 16 + createDataset = 17 @classmethod def __get_validators__(cls): diff --git a/giskard/ml_worker/websocket/listener.py b/giskard/ml_worker/websocket/listener.py index e6d56604c1..4e41a0e086 100644 --- a/giskard/ml_worker/websocket/listener.py +++ b/giskard/ml_worker/websocket/listener.py @@ -32,6 +32,7 @@ from giskard.ml_worker.websocket import CallToActionKind, GetInfoParam, PushKind from giskard.ml_worker.websocket.action import ActionPayload, MLWorkerAction from giskard.ml_worker.websocket.utils import ( + do_create_dataset, do_create_sub_dataset, do_run_adhoc_test, function_argument_to_ws, @@ -750,6 +751,15 @@ def create_sub_dataset( return websocket.CreateSubDataset(datasetUuid=sub_dataset.upload(client=client, project_key=params.projectKey)) +@websocket_actor(MLWorkerAction.createDataset) +def create_dataset( + client: Optional[GiskardClient], params: websocket.CreateDatasetParam, *arg, **kwargs +) -> websocket.CreateSubDataset: + dataset = do_create_dataset(params.name, params.headers, params.rows) + + return websocket.CreateSubDataset(datasetUuid=dataset.upload(client=client, project_key=params.projectKey)) + + def tail_file(file_path: Path, n_lines: int): if not file_path.exists(): raise FileNotFoundError(f"File {file_path.name} does not exist") diff --git a/giskard/ml_worker/websocket/utils.py b/giskard/ml_worker/websocket/utils.py index 5df6fecfb2..fc989aabd4 100644 --- a/giskard/ml_worker/websocket/utils.py +++ b/giskard/ml_worker/websocket/utils.py @@ -17,6 +17,7 @@ from giskard.ml_worker import websocket from giskard.ml_worker.websocket import ( AbortParams, + CreateDatasetParam, CreateSubDatasetParam, DatasetProcessingParam, Documentation, @@ -67,6 +68,8 @@ def parse_action_param(action: MLWorkerAction, params): return GetPushParam.parse_obj(params) elif action == MLWorkerAction.createSubDataset: return CreateSubDatasetParam.parse_obj(params) + elif action == MLWorkerAction.createDataset: + return CreateDatasetParam.parse_obj(params) elif action == MLWorkerAction.getLogs: return GetLogsParams.parse_obj(params) return params @@ -404,3 +407,7 @@ def do_create_sub_dataset(datasets: Dict[str, Dataset], name: Optional[str], row column_types=dataset_list[0].column_types, validation=False, ) + + +def do_create_dataset(name: Optional[str], headers: List[str], rows: List[List[str]]): + return Dataset(pd.DataFrame(rows, columns=headers), name=name, validation=False) diff --git a/tests/communications/fixtures/with_alias.json b/tests/communications/fixtures/with_alias.json index d2fd30b9c6..afff39af11 100644 --- a/tests/communications/fixtures/with_alias.json +++ b/tests/communications/fixtures/with_alias.json @@ -302,6 +302,51 @@ } } ], + "CreateDatasetParam": [ + { + "projectKey": "gSNnzxHPebRwNaxLlXUb", + "name": "HaTJYLBNWCrJNXWvVVfG", + "headers": [], + "rows": [] + }, + { + "projectKey": "bCWtlvblwzmmlxpbRCHj", + "name": "HaTJYLBNWCrJNXWvVVfG", + "headers": [ + "bCWtlvblwzmmlxpbRCHj" + ], + "rows": [ + [ + "bCWtlvblwzmmlxpbRCHj" + ], + [ + "bCWtlvblwzmmlxpbRCHj" + ] + ] + }, + { + "projectKey": "ZSgrLKWvRIpETSAEbmPv", + "name": "HaTJYLBNWCrJNXWvVVfG", + "headers": [ + "first", + "second" + ], + "rows": [ + [ + "1", + "2" + ], + [ + "3", + "4" + ], + [ + "5", + "6" + ] + ] + } + ], "CreateSubDataset": [ { "datasetUuid": "gSNnzxHPebRwNaxLlXUb" @@ -3227,4 +3272,4 @@ {}, {} ] -} \ No newline at end of file +} diff --git a/tests/communications/test_dto_serialization.py b/tests/communications/test_dto_serialization.py index ebc1a3904c..105a99e5b7 100644 --- a/tests/communications/test_dto_serialization.py +++ b/tests/communications/test_dto_serialization.py @@ -25,6 +25,7 @@ "AbortParams": ["job_id"], "ArtifactRef": ["id"], "Catalog": ["tests", "slices", "transformations"], + "CreateDatasetParam": ["projectKey", "name", "headers", "rows"], "CreateSubDataset": ["datasetUuid"], "CreateSubDatasetParam": ["projectKey", "sample", "name", "copiedRows"], "DataFrame": ["rows"], @@ -95,6 +96,7 @@ "AbortParams": [], "ArtifactRef": ["project_key", "sample"], "Catalog": [], + "CreateDatasetParam": [], "CreateSubDataset": [], "CreateSubDatasetParam": [], "DataFrame": [],