@@ -128,6 +128,12 @@ def FindParsingFunctionFromAttributeType(atype):
128128"""
129129CALL_PRE_PROCESS_TEMPLATE = """ {};
130130"""
131+ PARAMS_DECLARE_TEMPLE = """ {type} {name};\n """
132+ CALL_ARGS_MAPPER_TEMPLATE = """ {func_name}(args,kwargs{params});
133+ """
134+ DISABLE_TIPS = (
135+ " // This part of the function will be performed by a custom args mapper"
136+ )
131137RECORD_EVENT_TEMPLATE = (
132138 'phi::RecordEvent {}("{} {}", phi::TracerEventType::UserDefined, 1);'
133139)
@@ -152,6 +158,10 @@ def FindParsingFunctionFromAttributeType(atype):
152158 // Parse Attributes if needed
153159{}
154160 // Check Reminding Params validity if needed
161+ {}
162+ // Custom Args Mapper if need
163+ {}
164+ // Convert to Dist
155165{}
156166 // Call Pre_Process before calling dygraph function if needed
157167{}
@@ -234,7 +244,7 @@ def FindParsingFunctionFromAttributeType(atype):
234244#include "paddle/fluid/pybind/eager.h"
235245#include "paddle/fluid/pybind/eager_op_function.h"
236246#include "paddle/fluid/pybind/arg_pre_process.h"
237-
247+ #include "paddle/fluid/pybind/args_mapper.h"
238248namespace paddle {{
239249namespace pybind {{
240250
@@ -384,10 +394,10 @@ def GeneratePythonCFunction(self, no_input_out_tensor=False):
384394 forward_inputs_position_map
385395 )
386396 dygraph_pre_process = self .dygraph_pre_process
387-
397+ args_mapper_func = self . args_mapper_func_name
388398 inplace_args_pos_map = {}
389399 inplace_returns_pos_map = {}
390- get_params_nums_and_check_str = "// NO NEED"
400+ get_params_nums_and_check_str = " // NO NEED"
391401 if need_parse_python_api_args :
392402 get_params_nums_and_check_str = (
393403 PARSE_PYTHON_C_NUM_ARGS_TEMPLATE .format (max_args )
@@ -480,52 +490,7 @@ def _get_keywords(name, alias_map):
480490 keywords ,
481491 "false" ,
482492 )
483- # No inputs, skip convert to DistTensor
484- if len (input_names ) > 0 :
485- optional_and_vector_convert_code = ""
486- for name , (ttype , pos ) in forward_inputs_position_map .items ():
487- is_optional = name in optional_inputs
488- if IsVectorTensorType (ttype ):
489- if is_optional :
490- optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
491- name ,
492- "GetOptionalTensorListFromArgs" ,
493- forward_api_name ,
494- name ,
495- pos ,
496- "true" ,
497- )
498- else :
499- optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
500- name ,
501- "GetTensorListFromArgs" ,
502- forward_api_name ,
503- name ,
504- pos ,
505- "false" ,
506- )
507- else :
508- if is_optional :
509- optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
510- name ,
511- "GetOptionalTensorFromArgs" ,
512- forward_api_name ,
513- name ,
514- pos ,
515- "true" ,
516- )
517493
518- if len (input_single_tensor_names ) > 0 :
519- get_eager_tensor_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE .format (
520- input_names = input_names ,
521- input_single_tensor_names = input_single_tensor_names ,
522- optional_and_vector_convert_code = optional_and_vector_convert_code ,
523- )
524- else :
525- get_eager_tensor_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITHOUT_SINGLE_TENSOR_TEMPLATE .format (
526- input_names = input_names ,
527- optional_and_vector_convert_code = optional_and_vector_convert_code ,
528- )
529494 if forward_inplace_map :
530495 for name , (ttype , pos ) in forward_outputs_position_map .items ():
531496 if name in forward_inplace_map .values ():
@@ -593,7 +558,7 @@ def _get_keywords(name, alias_map):
593558 check_remaining_params_validity_str = (
594559 CHECK_REMAINING_ARGS_VALID_TEMPLATE
595560 )
596- pre_process_str = " //NO NEED"
561+ pre_process_str = " // NO NEED"
597562 if need_parse_python_api_args and len (dygraph_pre_process ) > 0 :
598563
599564 def pre_process_add_ampersand (s ):
@@ -602,6 +567,77 @@ def pre_process_add_ampersand(s):
602567 pre_process_str = CALL_PRE_PROCESS_TEMPLATE .format (
603568 pre_process_add_ampersand (dygraph_pre_process )
604569 )
570+ args_mapper_str = " // NO NEED"
571+ if args_mapper_func is not None :
572+ all_params_list = []
573+ args_mapper_str = ""
574+ for name , (ttype , pos ) in forward_inputs_position_map .items ():
575+ args_mapper_str += PARAMS_DECLARE_TEMPLE .format (
576+ type = ttype , name = name
577+ )
578+ all_params_list .append (name )
579+ for name , atype , default_value , pos in orig_forward_attrs_list :
580+ args_mapper_str += PARAMS_DECLARE_TEMPLE .format (
581+ type = atype , name = name
582+ )
583+ all_params_list .append (name )
584+ params = ',&' + ',&' .join (all_params_list )
585+ args_mapper_str += CALL_ARGS_MAPPER_TEMPLATE .format (
586+ func_name = args_mapper_func , params = params
587+ )
588+ # disable the generated args parser
589+ get_params_nums_and_check_str = DISABLE_TIPS
590+ get_eager_tensor_str = DISABLE_TIPS
591+ parse_attributes_str = DISABLE_TIPS
592+ check_remaining_params_validity_str = DISABLE_TIPS
593+
594+ convert_to_dist_str = ""
595+ # No inputs, skip convert to DistTensor
596+ if len (input_names ) > 0 :
597+ optional_and_vector_convert_code = ""
598+ for name , (ttype , pos ) in forward_inputs_position_map .items ():
599+ is_optional = name in optional_inputs
600+ if IsVectorTensorType (ttype ):
601+ if is_optional :
602+ optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
603+ name ,
604+ "GetOptionalTensorListFromArgs" ,
605+ forward_api_name ,
606+ name ,
607+ pos ,
608+ "true" ,
609+ )
610+ else :
611+ optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
612+ name ,
613+ "GetTensorListFromArgs" ,
614+ forward_api_name ,
615+ name ,
616+ pos ,
617+ "false" ,
618+ )
619+ else :
620+ if is_optional :
621+ optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE .format (
622+ name ,
623+ "GetOptionalTensorFromArgs" ,
624+ forward_api_name ,
625+ name ,
626+ pos ,
627+ "true" ,
628+ )
629+ if len (input_single_tensor_names ) > 0 :
630+ convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE .format (
631+ input_names = input_names ,
632+ input_single_tensor_names = input_single_tensor_names ,
633+ optional_and_vector_convert_code = optional_and_vector_convert_code ,
634+ )
635+ else :
636+ convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITHOUT_SINGLE_TENSOR_TEMPLATE .format (
637+ input_names = input_names ,
638+ optional_and_vector_convert_code = optional_and_vector_convert_code ,
639+ )
640+
605641 set_device_str = FUNCTION_SET_DEVICE_TEMPLATE .format (expected_place_str )
606642
607643 # Generate Dygraph Function Call Logic
@@ -658,6 +694,8 @@ def pre_process_add_ampersand(s):
658694 get_eager_tensor_str ,
659695 parse_attributes_str ,
660696 check_remaining_params_validity_str ,
697+ args_mapper_str ,
698+ convert_to_dist_str ,
661699 pre_process_str ,
662700 get_input_out_str ,
663701 set_device_str ,
@@ -720,6 +758,8 @@ def pre_process_add_ampersand(s):
720758 get_eager_tensor_str ,
721759 parse_attributes_str ,
722760 check_remaining_params_validity_str ,
761+ args_mapper_str ,
762+ convert_to_dist_str ,
723763 pre_process_str ,
724764 "" ,
725765 set_device_str ,
0 commit comments