Skip to content

Commit 632615c

Browse files
authored
[API Compatiblity] Support the Args Mapper mechanism when the Python API is integrated into the C++ layer (#74750)
* support args mapper * format * fix none * add test time out
1 parent 2f156c7 commit 632615c

File tree

8 files changed

+256
-78
lines changed

8 files changed

+256
-78
lines changed

paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -499,10 +499,8 @@ def __init__(self, forward_api_contents, namespace):
499499
self.dygraph_pre_process = (
500500
"" # The pre_process function calling code for dygraph
501501
)
502-
self.static_pre_process = (
503-
"" # The pre_process function calling code for static graph
504-
)
505-
self.args_parser_func_name = "" # The custom args parser function
502+
503+
self.args_mapper_func_name = None # The custom args parser function
506504
self.python_api_names = ""
507505

508506
def ParseForwardInplaceInfo(self):
@@ -535,20 +533,19 @@ def ParsePythonAPIInfo(self):
535533
self.args_alias_map = args_alias
536534
if 'pre_process' in python_api_info.keys():
537535
pre_process = python_api_info['pre_process']
538-
if 'func' in pre_process.keys():
539-
self.dygraph_pre_process = pre_process['func']
540-
self.static_pre_process = pre_process['func']
541-
# TODO check len(pre_process) > 1
542-
543-
if 'dygraph_func' in pre_process.keys():
544-
self.dygraph_pre_process = pre_process['dygraph_func']
545-
if 'static_func' in pre_process.keys():
546-
self.static_pre_process = pre_process['static_func']
547-
if (
548-
'args_parser' in python_api_info.keys()
549-
and 'func' in python_api_info['args_parser']
550-
):
551-
self.args_parser_func_name = python_api_info['args_parser']['func']
536+
if pre_process is not None:
537+
if 'dygraph_func' in pre_process.keys():
538+
self.dygraph_pre_process = pre_process['dygraph_func']
539+
elif 'func' in pre_process.keys():
540+
self.dygraph_pre_process = pre_process['func']
541+
542+
if 'args_mapper' in python_api_info.keys():
543+
args_mapper = python_api_info['args_mapper']
544+
if args_mapper is not None:
545+
if 'dygraph_func' in args_mapper.keys():
546+
self.args_mapper_func_name = args_mapper['dygraph_func']
547+
elif 'func' in args_mapper.keys():
548+
self.args_mapper_func_name = args_mapper['func']
552549

553550
def ParseNoNeedBuffer(self):
554551
grad_api_contents = self.grad_api_contents

paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py

Lines changed: 89 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ def FindParsingFunctionFromAttributeType(atype):
128128
"""
129129
CALL_PRE_PROCESS_TEMPLATE = """ {};
130130
"""
131+
PARAMS_DECLARE_TEMPLE = """ {type} {name};\n"""
132+
CALL_ARGS_MAPPER_TEMPLATE = """ {func_name}(args,kwargs{params});
133+
"""
134+
DISABLE_TIPS = (
135+
" // This part of the function will be performed by a custom args mapper"
136+
)
131137
RECORD_EVENT_TEMPLATE = (
132138
'phi::RecordEvent {}("{} {}", phi::TracerEventType::UserDefined, 1);'
133139
)
@@ -152,6 +158,10 @@ def FindParsingFunctionFromAttributeType(atype):
152158
// Parse Attributes if needed
153159
{}
154160
// Check Reminding Params validity if needed
161+
{}
162+
// Custom Args Mapper if need
163+
{}
164+
// Convert to Dist
155165
{}
156166
// Call Pre_Process before calling dygraph function if needed
157167
{}
@@ -234,7 +244,7 @@ def FindParsingFunctionFromAttributeType(atype):
234244
#include "paddle/fluid/pybind/eager.h"
235245
#include "paddle/fluid/pybind/eager_op_function.h"
236246
#include "paddle/fluid/pybind/arg_pre_process.h"
237-
247+
#include "paddle/fluid/pybind/args_mapper.h"
238248
namespace paddle {{
239249
namespace pybind {{
240250
@@ -384,10 +394,10 @@ def GeneratePythonCFunction(self, no_input_out_tensor=False):
384394
forward_inputs_position_map
385395
)
386396
dygraph_pre_process = self.dygraph_pre_process
387-
397+
args_mapper_func = self.args_mapper_func_name
388398
inplace_args_pos_map = {}
389399
inplace_returns_pos_map = {}
390-
get_params_nums_and_check_str = "// NO NEED"
400+
get_params_nums_and_check_str = " // NO NEED"
391401
if need_parse_python_api_args:
392402
get_params_nums_and_check_str = (
393403
PARSE_PYTHON_C_NUM_ARGS_TEMPLATE.format(max_args)
@@ -480,52 +490,7 @@ def _get_keywords(name, alias_map):
480490
keywords,
481491
"false",
482492
)
483-
# No inputs, skip convert to DistTensor
484-
if len(input_names) > 0:
485-
optional_and_vector_convert_code = ""
486-
for name, (ttype, pos) in forward_inputs_position_map.items():
487-
is_optional = name in optional_inputs
488-
if IsVectorTensorType(ttype):
489-
if is_optional:
490-
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
491-
name,
492-
"GetOptionalTensorListFromArgs",
493-
forward_api_name,
494-
name,
495-
pos,
496-
"true",
497-
)
498-
else:
499-
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
500-
name,
501-
"GetTensorListFromArgs",
502-
forward_api_name,
503-
name,
504-
pos,
505-
"false",
506-
)
507-
else:
508-
if is_optional:
509-
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
510-
name,
511-
"GetOptionalTensorFromArgs",
512-
forward_api_name,
513-
name,
514-
pos,
515-
"true",
516-
)
517493

518-
if len(input_single_tensor_names) > 0:
519-
get_eager_tensor_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE.format(
520-
input_names=input_names,
521-
input_single_tensor_names=input_single_tensor_names,
522-
optional_and_vector_convert_code=optional_and_vector_convert_code,
523-
)
524-
else:
525-
get_eager_tensor_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITHOUT_SINGLE_TENSOR_TEMPLATE.format(
526-
input_names=input_names,
527-
optional_and_vector_convert_code=optional_and_vector_convert_code,
528-
)
529494
if forward_inplace_map:
530495
for name, (ttype, pos) in forward_outputs_position_map.items():
531496
if name in forward_inplace_map.values():
@@ -593,7 +558,7 @@ def _get_keywords(name, alias_map):
593558
check_remaining_params_validity_str = (
594559
CHECK_REMAINING_ARGS_VALID_TEMPLATE
595560
)
596-
pre_process_str = " //NO NEED"
561+
pre_process_str = " // NO NEED"
597562
if need_parse_python_api_args and len(dygraph_pre_process) > 0:
598563

599564
def pre_process_add_ampersand(s):
@@ -602,6 +567,77 @@ def pre_process_add_ampersand(s):
602567
pre_process_str = CALL_PRE_PROCESS_TEMPLATE.format(
603568
pre_process_add_ampersand(dygraph_pre_process)
604569
)
570+
args_mapper_str = " // NO NEED"
571+
if args_mapper_func is not None:
572+
all_params_list = []
573+
args_mapper_str = ""
574+
for name, (ttype, pos) in forward_inputs_position_map.items():
575+
args_mapper_str += PARAMS_DECLARE_TEMPLE.format(
576+
type=ttype, name=name
577+
)
578+
all_params_list.append(name)
579+
for name, atype, default_value, pos in orig_forward_attrs_list:
580+
args_mapper_str += PARAMS_DECLARE_TEMPLE.format(
581+
type=atype, name=name
582+
)
583+
all_params_list.append(name)
584+
params = ',&' + ',&'.join(all_params_list)
585+
args_mapper_str += CALL_ARGS_MAPPER_TEMPLATE.format(
586+
func_name=args_mapper_func, params=params
587+
)
588+
# disable the generated args parser
589+
get_params_nums_and_check_str = DISABLE_TIPS
590+
get_eager_tensor_str = DISABLE_TIPS
591+
parse_attributes_str = DISABLE_TIPS
592+
check_remaining_params_validity_str = DISABLE_TIPS
593+
594+
convert_to_dist_str = ""
595+
# No inputs, skip convert to DistTensor
596+
if len(input_names) > 0:
597+
optional_and_vector_convert_code = ""
598+
for name, (ttype, pos) in forward_inputs_position_map.items():
599+
is_optional = name in optional_inputs
600+
if IsVectorTensorType(ttype):
601+
if is_optional:
602+
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
603+
name,
604+
"GetOptionalTensorListFromArgs",
605+
forward_api_name,
606+
name,
607+
pos,
608+
"true",
609+
)
610+
else:
611+
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
612+
name,
613+
"GetTensorListFromArgs",
614+
forward_api_name,
615+
name,
616+
pos,
617+
"false",
618+
)
619+
else:
620+
if is_optional:
621+
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
622+
name,
623+
"GetOptionalTensorFromArgs",
624+
forward_api_name,
625+
name,
626+
pos,
627+
"true",
628+
)
629+
if len(input_single_tensor_names) > 0:
630+
convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE.format(
631+
input_names=input_names,
632+
input_single_tensor_names=input_single_tensor_names,
633+
optional_and_vector_convert_code=optional_and_vector_convert_code,
634+
)
635+
else:
636+
convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITHOUT_SINGLE_TENSOR_TEMPLATE.format(
637+
input_names=input_names,
638+
optional_and_vector_convert_code=optional_and_vector_convert_code,
639+
)
640+
605641
set_device_str = FUNCTION_SET_DEVICE_TEMPLATE.format(expected_place_str)
606642

607643
# Generate Dygraph Function Call Logic
@@ -658,6 +694,8 @@ def pre_process_add_ampersand(s):
658694
get_eager_tensor_str,
659695
parse_attributes_str,
660696
check_remaining_params_validity_str,
697+
args_mapper_str,
698+
convert_to_dist_str,
661699
pre_process_str,
662700
get_input_out_str,
663701
set_device_str,
@@ -720,6 +758,8 @@ def pre_process_add_ampersand(s):
720758
get_eager_tensor_str,
721759
parse_attributes_str,
722760
check_remaining_params_validity_str,
761+
args_mapper_str,
762+
convert_to_dist_str,
723763
pre_process_str,
724764
"",
725765
set_device_str,

0 commit comments

Comments
 (0)