Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions giskard/registry/slicing_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Dict, List, Optional, Set, Type, Union

import copy
import functools
import inspect
from pathlib import Path
Expand Down Expand Up @@ -73,13 +74,15 @@ def __init__(self, func: Optional[SlicingFunctionType], row_level=True, cell_lev
super().__init__(meta)

def __call__(self, *args, **kwargs) -> "SlicingFunction":
self.is_initialized = True
self.params = kwargs
instance = copy.deepcopy(self)

instance.is_initialized = True
instance.params = kwargs

for idx, arg in enumerate(args):
self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg
instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg

return self
return instance

@property
def dependencies(self) -> Set[Artifact]:
Expand Down
11 changes: 7 additions & 4 deletions giskard/registry/transformation_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, List, Optional, Set, Type, Union

import copy
import functools
import inspect

Expand Down Expand Up @@ -49,13 +50,15 @@ def __init__(
super().__init__(meta)

def __call__(self, *args, **kwargs) -> "TransformationFunction":
self.is_initialized = True
self.params = kwargs
instance = copy.deepcopy(self)

instance.is_initialized = True
instance.params = kwargs

for idx, arg in enumerate(args):
self.params[next(iter([arg.name for arg in self.meta.args.values() if arg.argOrder == idx]))] = arg
instance.params[next(iter([arg.name for arg in instance.meta.args.values() if arg.argOrder == idx]))] = arg

return self
return instance

@property
def dependencies(self) -> Set[Artifact]:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_data_processing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,39 @@ def add_positive_sentence(row):
transformed_dataset = dataset.transform(add_positive_sentence)

assert transformed_dataset.df.iloc[0].text == "testing. I love this!"


def test_slicing_function_multiple_instances():
@slicing_function(name="slice cell level", cell_level=True)
def filter_cell_level_by(amount: int, min_value: int) -> bool:
return amount >= min_value

df = pd.DataFrame({"quantity": [1, 2, 3, 5, 7, 11, 13]})
dataset = Dataset(df, cat_columns=[])

min_five = filter_cell_level_by(min_value=5)
min_six = filter_cell_level_by(min_value=6)

dataset_greater_equals_five = dataset.slice(min_five, column_name="quantity")
dataset_greater_equals_six = dataset.slice(min_six, column_name="quantity")

assert list(dataset_greater_equals_five.df["quantity"]) == [5, 7, 11, 13]
assert list(dataset_greater_equals_six.df["quantity"]) == [7, 11, 13]


def test_transformation_multiple_instances():
@transformation_function(cell_level=True)
def column_level_divide(nb: float, amount: int) -> float:
return nb / amount

df = pd.DataFrame({"quantity": [100, 200, 300]})
dataset = Dataset(df, cat_columns=[])

divide_by_ten = column_level_divide(amount=10)
divide_by_a_hundred = column_level_divide(amount=100)

dataset_by_ten = dataset.transform(divide_by_ten, column_name="quantity")
dataset_by_a_hundred = dataset.transform(divide_by_a_hundred, column_name="quantity")

assert list(dataset_by_ten.df["quantity"]) == [10, 20, 30]
assert list(dataset_by_a_hundred.df["quantity"]) == [1, 2, 3]