Skip to content

Commit 496a2b4

Browse files
kevinmessiaenHartorn
authored andcommitted
Made kwargs default value persistent in the Hub
1 parent 4abcaa5 commit 496a2b4

File tree

3 files changed

+61
-37
lines changed

3 files changed

+61
-37
lines changed

giskard/core/core.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,27 @@ def __init__(
169169
self.tags = self.populate_tags(tags)
170170

171171
parameters = self.extract_parameters(callable_obj)
172+
for param in parameters:
173+
param.default = serialize_parameter(param.default)
172174

173-
self.args = {
174-
parameter.name: FunctionArgument(
175-
name=parameter.name,
176-
type=extract_optional(parameter.annotation).__qualname__,
177-
optional=parameter.default != inspect.Parameter.empty,
178-
default=serialize_parameter(parameter.default),
179-
argOrder=idx,
180-
)
181-
for idx, parameter in enumerate(parameters.values())
182-
if name != "self"
183-
}
175+
self.args = {param.name: param for param in parameters}
184176

185-
def extract_parameters(self, callable_obj):
177+
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
186178
if inspect.isclass(callable_obj):
187179
parameters = list(inspect.signature(callable_obj.__init__).parameters.values())[1:]
188180
else:
189181
parameters = list(inspect.signature(callable_obj).parameters.values())
190182

191-
return parameters
183+
return [
184+
FunctionArgument(
185+
name=parameter.name,
186+
type=extract_optional(parameter.annotation).__qualname__,
187+
optional=parameter.default != inspect.Parameter.empty,
188+
default=parameter.default,
189+
argOrder=idx,
190+
)
191+
for idx, parameter in enumerate(parameters)
192+
]
192193

193194
@staticmethod
194195
def extract_module_doc(func_doc):
@@ -293,10 +294,8 @@ def __init__(
293294
super().__init__(callable_obj, name, tags, version, type)
294295
self.debug_description = debug_description
295296

296-
def extract_parameters(self, callable_obj):
297-
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))
298-
299-
return {p.name: p for p in parameters}
297+
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
298+
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))
300299

301300
def to_json(self):
302301
json = super().to_json()
@@ -346,10 +345,8 @@ def __init__(
346345
else:
347346
self.column_type = None
348347

349-
def extract_parameters(self, callable_obj):
350-
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])
351-
352-
return {p.name: p for p in parameters}
348+
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
349+
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])
353350

354351
def to_json(self):
355352
json = super().to_json()
@@ -373,25 +370,37 @@ def init_from_json(self, json: Dict[str, Any]):
373370
SMT = TypeVar("SMT", bound=SavableMeta)
374371

375372

376-
def unknown_annotations_to_kwargs(parameters: List[inspect.Parameter]) -> List[inspect.Parameter]:
373+
def unknown_annotations_to_kwargs(parameters: List[FunctionArgument]) -> List[FunctionArgument]:
377374
from giskard.models.base import BaseModel
378375
from giskard.datasets.base import Dataset
379376
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
380377
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction
381378

382379
allowed_types = [str, bool, int, float, BaseModel, Dataset, SlicingFunction, TransformationFunction]
383-
allowed_types = allowed_types + list(map(lambda x: Optional[x], allowed_types))
380+
allowed_types = list(map(lambda x: x.__qualname__, allowed_types))
384381

385-
has_kwargs = any(
386-
[param for param in parameters if not any([param.annotation == allowed_type for allowed_type in allowed_types])]
387-
)
382+
kwargs = [param for param in parameters if not any([param.type == allowed_type for allowed_type in allowed_types])]
388383

389-
parameters = [
390-
param for param in parameters if any([param.annotation == allowed_type for allowed_type in allowed_types])
391-
]
384+
parameters = [param for param in parameters if any([param.type == allowed_type for allowed_type in allowed_types])]
392385

393-
if has_kwargs:
394-
parameters.append(inspect.Parameter(name="kwargs", kind=4, annotation=Kwargs))
386+
for idx, parameter in enumerate(parameters):
387+
parameter.argOrder = idx
388+
389+
if any(kwargs) > 0:
390+
kwargs_with_default = [param for param in kwargs if param.default != inspect.Parameter.empty]
391+
default_value = (
392+
dict({param.name: param.default for param in kwargs_with_default}) if any(kwargs_with_default) else None
393+
)
394+
395+
parameters.append(
396+
FunctionArgument(
397+
name="kwargs",
398+
type="Kwargs",
399+
default=default_value,
400+
optional=len(kwargs_with_default) == len(kwargs),
401+
argOrder=len(parameters),
402+
)
403+
)
395404

396405
return parameters
397406

giskard/testing/tests/performance.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Performance tests"""
2-
import inspect
32
from typing import Optional
43

4+
import inspect
5+
56
import numpy as np
67
import pandas as pd
78
from sklearn.metrics import (
@@ -22,7 +23,10 @@
2223
from giskard.ml_worker.testing.utils import Direction, check_slice_not_empty
2324
from giskard.models.base import BaseModel
2425
from giskard.models.utils import np_type_to_native
25-
from giskard.testing.tests.debug_slicing_functions import incorrect_rows_slicing_fn, nlargest_abs_err_rows_slicing_fn
26+
from giskard.testing.tests.debug_slicing_functions import (
27+
incorrect_rows_slicing_fn,
28+
nlargest_abs_err_rows_slicing_fn,
29+
)
2630

2731
from . import debug_description_prefix, debug_prefix
2832

@@ -149,11 +153,11 @@ def _test_diff_prediction(
149153
" reference_dataset is equal to zero"
150154
)
151155

152-
if direction == Direction.Invariant:
156+
if direction == Direction.Invariant or direction == Direction.Invariant.value:
153157
passed = abs(rel_change) < threshold
154-
elif direction == Direction.Decreasing:
158+
elif direction == Direction.Decreasing or direction == Direction.Decreasing.value:
155159
passed = rel_change < threshold
156-
elif direction == Direction.Increasing:
160+
elif direction == Direction.Increasing or direction == Direction.Increasing.value:
157161
passed = rel_change > threshold
158162
else:
159163
raise ValueError(f"Invalid direction: {direction}")

giskard/utils/artifacts.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import uuid
3-
from typing import Any, Optional, Union
3+
from enum import Enum
4+
from typing import Any, Optional, Union, Dict
45

56
try:
67
from types import NoneType
@@ -18,13 +19,23 @@ def _serialize_artifact(artifact, artifact_uuid: Optional[Union[str, uuid.UUID]]
1819
return str(artifact_uuid)
1920

2021

22+
def repr_parameter(value: Any) -> str:
23+
if isinstance(value, Enum):
24+
return repr(value.value)
25+
26+
return repr(value)
27+
28+
2129
def serialize_parameter(default_value: Any) -> PRIMITIVES:
2230
if default_value == inspect.Parameter.empty:
2331
return None
2432

2533
if isinstance(default_value, PRIMITIVES.__args__):
2634
return default_value
2735

36+
if isinstance(default_value, Dict):
37+
return "\n".join(f"kwargs[{repr(key)}] = {repr_parameter(value)}" for key, value in default_value.items())
38+
2839
from ..ml_worker.core.savable import Artifact
2940

3041
if isinstance(default_value, Artifact):

0 commit comments

Comments
 (0)