Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions giskard/ml_worker/websocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,5 +430,12 @@ class CreateSubDatasetParam(ConfiguredBaseModel):
copiedRows: Dict[str, List[int]]


class CreateDatasetParam(ConfiguredBaseModel):
projectKey: str
name: str
headers: List[str]
rows: List[List[str]]


class CreateSubDataset(WorkerReply):
datasetUuid: str
1 change: 1 addition & 0 deletions giskard/ml_worker/websocket/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MLWorkerAction(Enum):
createSubDataset = 14
abort = 15
getLogs = 16
createDataset = 17

@classmethod
def __get_validators__(cls):
Expand Down
10 changes: 10 additions & 0 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from giskard.ml_worker.websocket import CallToActionKind, GetInfoParam, PushKind
from giskard.ml_worker.websocket.action import ActionPayload, MLWorkerAction
from giskard.ml_worker.websocket.utils import (
do_create_dataset,
do_create_sub_dataset,
do_run_adhoc_test,
function_argument_to_ws,
Expand Down Expand Up @@ -750,6 +751,15 @@ def create_sub_dataset(
return websocket.CreateSubDataset(datasetUuid=sub_dataset.upload(client=client, project_key=params.projectKey))


@websocket_actor(MLWorkerAction.createDataset)
def create_dataset(
client: Optional[GiskardClient], params: websocket.CreateDatasetParam, *arg, **kwargs
) -> websocket.CreateSubDataset:
dataset = do_create_dataset(params.name, params.headers, params.rows)

return websocket.CreateSubDataset(datasetUuid=dataset.upload(client=client, project_key=params.projectKey))


def tail_file(file_path: Path, n_lines: int):
if not file_path.exists():
raise FileNotFoundError(f"File {file_path.name} does not exist")
Expand Down
7 changes: 7 additions & 0 deletions giskard/ml_worker/websocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from giskard.ml_worker import websocket
from giskard.ml_worker.websocket import (
AbortParams,
CreateDatasetParam,
CreateSubDatasetParam,
DatasetProcessingParam,
Documentation,
Expand Down Expand Up @@ -67,6 +68,8 @@ def parse_action_param(action: MLWorkerAction, params):
return GetPushParam.parse_obj(params)
elif action == MLWorkerAction.createSubDataset:
return CreateSubDatasetParam.parse_obj(params)
elif action == MLWorkerAction.createDataset:
return CreateDatasetParam.parse_obj(params)
elif action == MLWorkerAction.getLogs:
return GetLogsParams.parse_obj(params)
return params
Expand Down Expand Up @@ -404,3 +407,7 @@ def do_create_sub_dataset(datasets: Dict[str, Dataset], name: Optional[str], row
column_types=dataset_list[0].column_types,
validation=False,
)


def do_create_dataset(name: Optional[str], headers: List[str], rows: List[List[str]]):
return Dataset(pd.DataFrame(rows, columns=headers), name=name, validation=False)
47 changes: 46 additions & 1 deletion tests/communications/fixtures/with_alias.json
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,51 @@
}
}
],
"CreateDatasetParam": [
{
"projectKey": "gSNnzxHPebRwNaxLlXUb",
"name": "HaTJYLBNWCrJNXWvVVfG",
"headers": [],
"rows": []
},
{
"projectKey": "bCWtlvblwzmmlxpbRCHj",
"name": "HaTJYLBNWCrJNXWvVVfG",
"headers": [
"bCWtlvblwzmmlxpbRCHj"
],
"rows": [
[
"bCWtlvblwzmmlxpbRCHj"
],
[
"bCWtlvblwzmmlxpbRCHj"
]
]
},
{
"projectKey": "ZSgrLKWvRIpETSAEbmPv",
"name": "HaTJYLBNWCrJNXWvVVfG",
"headers": [
"first",
"second"
],
"rows": [
[
"1",
"2"
],
[
"3",
"4"
],
[
"5",
"6"
]
]
}
],
"CreateSubDataset": [
{
"datasetUuid": "gSNnzxHPebRwNaxLlXUb"
Expand Down Expand Up @@ -3227,4 +3272,4 @@
{},
{}
]
}
}
2 changes: 2 additions & 0 deletions tests/communications/test_dto_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"AbortParams": ["job_id"],
"ArtifactRef": ["id"],
"Catalog": ["tests", "slices", "transformations"],
"CreateDatasetParam": ["projectKey", "name", "headers", "rows"],
"CreateSubDataset": ["datasetUuid"],
"CreateSubDatasetParam": ["projectKey", "sample", "name", "copiedRows"],
"DataFrame": ["rows"],
Expand Down Expand Up @@ -95,6 +96,7 @@
"AbortParams": [],
"ArtifactRef": ["project_key", "sample"],
"Catalog": [],
"CreateDatasetParam": [],
"CreateSubDataset": [],
"CreateSubDatasetParam": [],
"DataFrame": [],
Expand Down