diff --git a/giskard/registry/slicing_function.py b/giskard/registry/slicing_function.py index 6eae148e2b..63b4b7075c 100644 --- a/giskard/registry/slicing_function.py +++ b/giskard/registry/slicing_function.py @@ -1,5 +1,6 @@ from typing import Callable, Dict, List, Optional, Set, Type, Union +import copy import functools import inspect from pathlib import Path @@ -73,13 +74,15 @@ def __init__(self, func: Optional[SlicingFunctionType], row_level=True, cell_lev super().__init__(meta) def __call__(self, *args, **kwargs) -> "SlicingFunction": - self.is_initialized = True - self.params = kwargs + instance = copy.deepcopy(self) + + instance.is_initialized = True + instance.params = kwargs for idx, arg in enumerate(args): - self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg + instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg - return self + return instance @property def dependencies(self) -> Set[Artifact]: diff --git a/giskard/registry/transformation_function.py b/giskard/registry/transformation_function.py index c75a1285e3..ec0bd2e4dd 100644 --- a/giskard/registry/transformation_function.py +++ b/giskard/registry/transformation_function.py @@ -1,5 +1,6 @@ from typing import Callable, List, Optional, Set, Type, Union +import copy import functools import inspect @@ -49,13 +50,15 @@ def __init__( super().__init__(meta) def __call__(self, *args, **kwargs) -> "TransformationFunction": - self.is_initialized = True - self.params = kwargs + instance = copy.deepcopy(self) + + instance.is_initialized = True + instance.params = kwargs for idx, arg in enumerate(args): - self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg + instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg - return self + return instance @property def dependencies(self) -> Set[Artifact]: diff --git a/tests/test_data_processing_pipeline.py b/tests/test_data_processing_pipeline.py index 0fd56d3d69..4843d6e443 100644 --- a/tests/test_data_processing_pipeline.py +++ b/tests/test_data_processing_pipeline.py @@ -193,3 +193,39 @@ def add_positive_sentence(row): transformed_dataset = dataset.transform(add_positive_sentence) assert transformed_dataset.df.iloc[0].text == "testing. I love this!" + + +def test_slicing_function_multiple_instances(): + @slicing_function(name="slice cell level", cell_level=True) + def filter_cell_level_by(amount: int, min_value: int) -> bool: + return amount >= min_value + + df = pd.DataFrame({"quantity": [1, 2, 3, 5, 7, 11, 13]}) + dataset = Dataset(df, cat_columns=[]) + + min_five = filter_cell_level_by(min_value=5) + min_six = filter_cell_level_by(min_value=6) + + dataset_greater_equals_five = dataset.slice(min_five, column_name="quantity") + dataset_greater_equals_six = dataset.slice(min_six, column_name="quantity") + + assert list(dataset_greater_equals_five.df["quantity"]) == [5, 7, 11, 13] + assert list(dataset_greater_equals_six.df["quantity"]) == [7, 11, 13] + + +def test_transformation_multiple_instances(): + @transformation_function(cell_level=True) + def column_level_divide(nb: float, amount: int) -> float: + return nb / amount + + df = pd.DataFrame({"quantity": [100, 200, 300]}) + dataset = Dataset(df, cat_columns=[]) + + divide_by_ten = column_level_divide(amount=10) + divide_by_a_hundred = column_level_divide(amount=100) + + dataset_by_ten = dataset.transform(divide_by_ten, column_name="quantity") + dataset_by_a_hundred = dataset.transform(divide_by_a_hundred, column_name="quantity") + + assert list(dataset_by_ten.df["quantity"]) == [10, 20, 30] + assert list(dataset_by_a_hundred.df["quantity"]) == [1, 2, 3]