|
| 1 | +from typing import Dict, Hashable, List, Optional, Union |
| 2 | + |
1 | 3 | import inspect |
2 | 4 | import logging |
3 | 5 | import posixpath |
|
12 | 14 | import yaml |
13 | 15 | from mlflow import MlflowClient |
14 | 16 | from pandas.api.types import is_list_like, is_numeric_dtype |
15 | | -from typing import Dict, Hashable, List, Optional, Union |
16 | 17 | from xxhash import xxh3_128_hexdigest |
17 | 18 | from zstandard import ZstdDecompressor |
18 | 19 |
|
19 | 20 | from giskard.client.giskard_client import GiskardClient |
20 | 21 | from giskard.client.io_utils import compress, save_df |
21 | 22 | from giskard.client.python_utils import warning |
22 | | -from giskard.core.core import DatasetMeta, SupportedColumnTypes, NOT_GIVEN, NotGivenOr |
| 23 | +from giskard.core.core import NOT_GIVEN, DatasetMeta, NotGivenOr, SupportedColumnTypes |
23 | 24 | from giskard.core.errors import GiskardImportError |
24 | 25 | from giskard.core.validation import configured_validate_arguments |
25 | 26 | from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType |
|
28 | 29 | TransformationFunctionType, |
29 | 30 | ) |
30 | 31 | from giskard.settings import settings |
31 | | -from ..metadata.indexing import ColumnMetadataMixin |
| 32 | + |
32 | 33 | from ...ml_worker.utils.file_utils import get_file_name |
| 34 | +from ..metadata.indexing import ColumnMetadataMixin |
33 | 35 |
|
34 | 36 | try: |
35 | 37 | import wandb # noqa |
@@ -77,8 +79,11 @@ def add_step(self, processor: Union[SlicingFunction, TransformationFunction]): |
77 | 79 | self.pipeline.append(processor) |
78 | 80 | return self |
79 | 81 |
|
80 | | - def apply(self, dataset: "Dataset", apply_only_last=False, get_mask: bool = False): |
81 | | - ds = dataset.copy() |
| 82 | + def apply(self, dataset: "Dataset", apply_only_last=False, get_mask: bool = False, copy: bool = True): |
| 83 | + if copy: |
| 84 | + ds = dataset.copy() |
| 85 | + else: |
| 86 | + ds = dataset |
82 | 87 | is_slicing_only = True |
83 | 88 |
|
84 | 89 | while len(self.pipeline): |
@@ -330,7 +335,9 @@ def slice( |
330 | 335 | **{key: value for key, value in slicing_function.params.items() if key != "column_name"}, |
331 | 336 | ) |
332 | 337 |
|
333 | | - return self.data_processor.add_step(slicing_function).apply(self, apply_only_last=True, get_mask=get_mask) |
| 338 | + return self.data_processor.add_step(slicing_function).apply( |
| 339 | + self, apply_only_last=True, get_mask=get_mask, copy=False |
| 340 | + ) |
334 | 341 |
|
335 | 342 | @configured_validate_arguments |
336 | 343 | def transform( |
|
0 commit comments