Skip to content

Commit d8cd08d

Browse files
Merge pull request #1979 from Giskard-AI/fix/slicing-function-params
Fixed issue with slicing function and transformation function instances being shared
2 parents 6334b41 + 4da44fb commit d8cd08d

File tree

3 files changed

+50
-8
lines changed

3 files changed

+50
-8
lines changed

giskard/registry/slicing_function.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable, Dict, List, Optional, Set, Type, Union
22

3+
import copy
34
import functools
45
import inspect
56
from pathlib import Path
@@ -73,13 +74,15 @@ def __init__(self, func: Optional[SlicingFunctionType], row_level=True, cell_lev
7374
super().__init__(meta)
7475

7576
def __call__(self, *args, **kwargs) -> "SlicingFunction":
76-
self.is_initialized = True
77-
self.params = kwargs
77+
instance = copy.deepcopy(self)
78+
79+
instance.is_initialized = True
80+
instance.params = kwargs
7881

7982
for idx, arg in enumerate(args):
80-
self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg
83+
instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg
8184

82-
return self
85+
return instance
8386

8487
@property
8588
def dependencies(self) -> Set[Artifact]:

giskard/registry/transformation_function.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable, List, Optional, Set, Type, Union
22

3+
import copy
34
import functools
45
import inspect
56

@@ -49,13 +50,15 @@ def __init__(
4950
super().__init__(meta)
5051

5152
def __call__(self, *args, **kwargs) -> "TransformationFunction":
52-
self.is_initialized = True
53-
self.params = kwargs
53+
instance = copy.deepcopy(self)
54+
55+
instance.is_initialized = True
56+
instance.params = kwargs
5457

5558
for idx, arg in enumerate(args):
56-
self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg
59+
instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg
5760

58-
return self
61+
return instance
5962

6063
@property
6164
def dependencies(self) -> Set[Artifact]:

tests/test_data_processing_pipeline.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,39 @@ def add_positive_sentence(row):
193193
transformed_dataset = dataset.transform(add_positive_sentence)
194194

195195
assert transformed_dataset.df.iloc[0].text == "testing. I love this!"
196+
197+
198+
def test_slicing_function_multiple_instances():
199+
@slicing_function(name="slice cell level", cell_level=True)
200+
def filter_cell_level_by(amount: int, min_value: int) -> bool:
201+
return amount >= min_value
202+
203+
df = pd.DataFrame({"quantity": [1, 2, 3, 5, 7, 11, 13]})
204+
dataset = Dataset(df, cat_columns=[])
205+
206+
min_five = filter_cell_level_by(min_value=5)
207+
min_six = filter_cell_level_by(min_value=6)
208+
209+
dataset_greater_equals_five = dataset.slice(min_five, column_name="quantity")
210+
dataset_greater_equals_six = dataset.slice(min_six, column_name="quantity")
211+
212+
assert list(dataset_greater_equals_five.df["quantity"]) == [5, 7, 11, 13]
213+
assert list(dataset_greater_equals_six.df["quantity"]) == [7, 11, 13]
214+
215+
216+
def test_transformation_multiple_instances():
217+
@transformation_function(cell_level=True)
218+
def column_level_divide(nb: float, amount: int) -> float:
219+
return nb / amount
220+
221+
df = pd.DataFrame({"quantity": [100, 200, 300]})
222+
dataset = Dataset(df, cat_columns=[])
223+
224+
divide_by_ten = column_level_divide(amount=10)
225+
divide_by_a_hundred = column_level_divide(amount=100)
226+
227+
dataset_by_ten = dataset.transform(divide_by_ten, column_name="quantity")
228+
dataset_by_a_hundred = dataset.transform(divide_by_a_hundred, column_name="quantity")
229+
230+
assert list(dataset_by_ten.df["quantity"]) == [10, 20, 30]
231+
assert list(dataset_by_a_hundred.df["quantity"]) == [1, 2, 3]

0 commit comments

Comments
 (0)