@@ -169,26 +169,27 @@ def __init__(
169169 self .tags = self .populate_tags (tags )
170170
171171 parameters = self .extract_parameters (callable_obj )
172+ for param in parameters :
173+ param .default = serialize_parameter (param .default )
172174
173- self .args = {
174- parameter .name : FunctionArgument (
175- name = parameter .name ,
176- type = extract_optional (parameter .annotation ).__qualname__ ,
177- optional = parameter .default != inspect .Parameter .empty ,
178- default = serialize_parameter (parameter .default ),
179- argOrder = idx ,
180- )
181- for idx , parameter in enumerate (parameters .values ())
182- if name != "self"
183- }
175+ self .args = {param .name : param for param in parameters }
184176
185- def extract_parameters (self , callable_obj ):
177+ def extract_parameters (self , callable_obj ) -> List [ FunctionArgument ] :
186178 if inspect .isclass (callable_obj ):
187179 parameters = list (inspect .signature (callable_obj .__init__ ).parameters .values ())[1 :]
188180 else :
189181 parameters = list (inspect .signature (callable_obj ).parameters .values ())
190182
191- return parameters
183+ return [
184+ FunctionArgument (
185+ name = parameter .name ,
186+ type = extract_optional (parameter .annotation ).__qualname__ ,
187+ optional = parameter .default != inspect .Parameter .empty ,
188+ default = parameter .default ,
189+ argOrder = idx ,
190+ )
191+ for idx , parameter in enumerate (parameters )
192+ ]
192193
193194 @staticmethod
194195 def extract_module_doc (func_doc ):
@@ -293,10 +294,8 @@ def __init__(
293294 super ().__init__ (callable_obj , name , tags , version , type )
294295 self .debug_description = debug_description
295296
296- def extract_parameters (self , callable_obj ):
297- parameters = unknown_annotations_to_kwargs (CallableMeta .extract_parameters (self , callable_obj ))
298-
299- return {p .name : p for p in parameters }
297+ def extract_parameters (self , callable_obj ) -> List [FunctionArgument ]:
298+ return unknown_annotations_to_kwargs (CallableMeta .extract_parameters (self , callable_obj ))
300299
301300 def to_json (self ):
302301 json = super ().to_json ()
@@ -346,10 +345,8 @@ def __init__(
346345 else :
347346 self .column_type = None
348347
349- def extract_parameters (self , callable_obj ):
350- parameters = unknown_annotations_to_kwargs (CallableMeta .extract_parameters (self , callable_obj )[1 :])
351-
352- return {p .name : p for p in parameters }
348+ def extract_parameters (self , callable_obj ) -> List [FunctionArgument ]:
349+ return unknown_annotations_to_kwargs (CallableMeta .extract_parameters (self , callable_obj )[1 :])
353350
354351 def to_json (self ):
355352 json = super ().to_json ()
@@ -373,25 +370,37 @@ def init_from_json(self, json: Dict[str, Any]):
373370SMT = TypeVar ("SMT" , bound = SavableMeta )
374371
375372
376- def unknown_annotations_to_kwargs (parameters : List [inspect . Parameter ]) -> List [inspect . Parameter ]:
373+ def unknown_annotations_to_kwargs (parameters : List [FunctionArgument ]) -> List [FunctionArgument ]:
377374 from giskard .models .base import BaseModel
378375 from giskard .datasets .base import Dataset
379376 from giskard .ml_worker .testing .registry .slicing_function import SlicingFunction
380377 from giskard .ml_worker .testing .registry .transformation_function import TransformationFunction
381378
382379 allowed_types = [str , bool , int , float , BaseModel , Dataset , SlicingFunction , TransformationFunction ]
383- allowed_types = allowed_types + list (map (lambda x : Optional [ x ] , allowed_types ))
380+ allowed_types = list (map (lambda x : x . __qualname__ , allowed_types ))
384381
385- has_kwargs = any (
386- [param for param in parameters if not any ([param .annotation == allowed_type for allowed_type in allowed_types ])]
387- )
382+ kwargs = [param for param in parameters if not any ([param .type == allowed_type for allowed_type in allowed_types ])]
388383
389- parameters = [
390- param for param in parameters if any ([param .annotation == allowed_type for allowed_type in allowed_types ])
391- ]
384+ parameters = [param for param in parameters if any ([param .type == allowed_type for allowed_type in allowed_types ])]
392385
393- if has_kwargs :
394- parameters .append (inspect .Parameter (name = "kwargs" , kind = 4 , annotation = Kwargs ))
386+ for idx , parameter in enumerate (parameters ):
387+ parameter .argOrder = idx
388+
389+ if any (kwargs ) > 0 :
390+ kwargs_with_default = [param for param in kwargs if param .default != inspect .Parameter .empty ]
391+ default_value = (
392+ dict ({param .name : param .default for param in kwargs_with_default }) if any (kwargs_with_default ) else None
393+ )
394+
395+ parameters .append (
396+ FunctionArgument (
397+ name = "kwargs" ,
398+ type = "Kwargs" ,
399+ default = default_value ,
400+ optional = len (kwargs_with_default ) == len (kwargs ),
401+ argOrder = len (parameters ),
402+ )
403+ )
395404
396405 return parameters
397406
0 commit comments