Skip to content

Commit 348233c

Browse files
[GSK-2346] Avoid copying whole dataset when doing slicing (#1673)
Avoid copying whole dataset when doing slicing Co-authored-by: Kevin Messiaen <[email protected]>
1 parent 4cce93c commit 348233c

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

giskard/datasets/base/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict, Hashable, List, Optional, Union
2+
13
import inspect
24
import logging
35
import posixpath
@@ -12,14 +14,13 @@
1214
import yaml
1315
from mlflow import MlflowClient
1416
from pandas.api.types import is_list_like, is_numeric_dtype
15-
from typing import Dict, Hashable, List, Optional, Union
1617
from xxhash import xxh3_128_hexdigest
1718
from zstandard import ZstdDecompressor
1819

1920
from giskard.client.giskard_client import GiskardClient
2021
from giskard.client.io_utils import compress, save_df
2122
from giskard.client.python_utils import warning
22-
from giskard.core.core import DatasetMeta, SupportedColumnTypes, NOT_GIVEN, NotGivenOr
23+
from giskard.core.core import NOT_GIVEN, DatasetMeta, NotGivenOr, SupportedColumnTypes
2324
from giskard.core.errors import GiskardImportError
2425
from giskard.core.validation import configured_validate_arguments
2526
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType
@@ -28,8 +29,9 @@
2829
TransformationFunctionType,
2930
)
3031
from giskard.settings import settings
31-
from ..metadata.indexing import ColumnMetadataMixin
32+
3233
from ...ml_worker.utils.file_utils import get_file_name
34+
from ..metadata.indexing import ColumnMetadataMixin
3335

3436
try:
3537
import wandb # noqa
@@ -77,8 +79,11 @@ def add_step(self, processor: Union[SlicingFunction, TransformationFunction]):
7779
self.pipeline.append(processor)
7880
return self
7981

80-
def apply(self, dataset: "Dataset", apply_only_last=False, get_mask: bool = False):
81-
ds = dataset.copy()
82+
def apply(self, dataset: "Dataset", apply_only_last=False, get_mask: bool = False, copy: bool = True):
83+
if copy:
84+
ds = dataset.copy()
85+
else:
86+
ds = dataset
8287
is_slicing_only = True
8388

8489
while len(self.pipeline):
@@ -330,7 +335,9 @@ def slice(
330335
**{key: value for key, value in slicing_function.params.items() if key != "column_name"},
331336
)
332337

333-
return self.data_processor.add_step(slicing_function).apply(self, apply_only_last=True, get_mask=get_mask)
338+
return self.data_processor.add_step(slicing_function).apply(
339+
self, apply_only_last=True, get_mask=get_mask, copy=False
340+
)
334341

335342
@configured_validate_arguments
336343
def transform(

0 commit comments

Comments
 (0)