1717import argparse
1818
1919
20- def map_code_template (attrs_str , attrs_checker_str ):
20+ def map_code_template (default_attrs_str , dynamic_attr_str , attrs_checker_str ):
2121 return f"""// This file is generated by paddle/phi/api/yaml/generator/ops_extra_info_gen.py
2222#include "paddle/fluid/operators/ops_extra_info.h"
2323
@@ -28,7 +28,11 @@ def map_code_template(attrs_str, attrs_checker_str):
2828
2929ExtraInfoUtils::ExtraInfoUtils() {{
3030 g_extra_attrs_map_ = {{
31- { attrs_str }
31+ { default_attrs_str }
32+ }};
33+
34+ g_extra_dynamic_attrs_map_ = {{
35+ { dynamic_attr_str }
3236 }};
3337
3438 g_extra_attrs_checker_ = {{
@@ -64,10 +68,7 @@ def parse_attr(attr_str):
6468 'name' ), result .group ('default_val' )
6569
6670
67- def generate_extra_info (op_compat_yaml_path , ops_extra_info_path ):
68- compat_apis = []
69- with open (op_compat_yaml_path , 'rt' ) as f :
70- compat_apis = yaml .safe_load (f )
71+ def generate_attr_info (attr_type , op_compat_args ):
7172
7273 def get_op_name (api_item ):
7374 names = api_item .split ('(' )
@@ -76,49 +77,62 @@ def get_op_name(api_item):
7677 else :
7778 return names [1 ].split (')' )[0 ].strip ()
7879
79- extra_map_str_list = []
80- extra_checker_str_list = []
80+ attr_map_str_list = []
81+ attr_checker_str_list = []
82+ extra_args_map = op_compat_args ['extra' ]
83+ if attr_type in extra_args_map :
84+ attr_map_list = []
85+ attr_checker_func_list = []
86+ for attr in extra_args_map [attr_type ]:
87+ attr_type , attr_name , default_val = parse_attr (attr )
88+ attr_checker_func_list .append (
89+ f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{ attr_type } >(\" { attr_name } \" , { default_val } )(attr_map, only_check_exist_value);}}"
90+ )
91+ if attr_type .startswith ("std::vector" ):
92+ attr_map_list .append (
93+ f"{{\" { attr_name } \" , { attr_type } { default_val } }}" )
94+ else :
95+ attr_map_list .append (
96+ f"{{\" { attr_name } \" , { attr_type } {{{ default_val } }}}}" )
97+ api_extra_attr_map = ", " .join (attr_map_list )
98+ api_extra_attr_checkers = ",\n " .join (attr_checker_func_list )
99+ attr_map_str_list .append (
100+ f"{{\" { get_op_name (op_compat_args ['op' ])} \" , {{ { api_extra_attr_map } }}}}"
101+ )
102+ attr_checker_str_list .append (
103+ f"{{\" { get_op_name (op_compat_args ['op' ])} \" , {{ { api_extra_attr_checkers } }}}}"
104+ )
105+ if 'backward' in op_compat_args :
106+ for bw_item in op_compat_args ['backward' ].split (',' ):
107+ bw_op_name = get_op_name (bw_item )
108+ attr_map_str_list .append (
109+ f"{{\" { bw_op_name } \" , {{ { api_extra_attr_map } }}}}" )
110+ attr_checker_str_list .append (
111+ f"{{\" { bw_op_name } \" , {{ { api_extra_attr_checkers } }}}}" )
112+ return attr_map_str_list , attr_checker_str_list
81113
114+
115+ def generate_extra_info (op_compat_yaml_path , ops_extra_info_path ):
116+ compat_apis = []
117+ with open (op_compat_yaml_path , 'rt' ) as f :
118+ compat_apis = yaml .safe_load (f )
119+ extra_default_attr_str_list = []
120+ extra_dynamic_attr_str_list = []
121+ extra_checker_str_list = []
82122 for op_compat_args in compat_apis :
83123 if 'extra' in op_compat_args :
84- extra_args_map = op_compat_args ['extra' ]
85124 # TODO(chenweihang): add inputs and outputs
86- if 'attrs' in extra_args_map :
87- attr_map_list = []
88- attr_checker_func_list = []
89- for attr in extra_args_map ['attrs' ]:
90- attr_type , attr_name , default_val = parse_attr (attr )
91- attr_checker_func_list .append (
92- f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{ attr_type } >(\" { attr_name } \" , { default_val } )(attr_map, only_check_exist_value);}}"
93- )
94- if attr_type .startswith ("std::vector" ):
95- attr_map_list .append (
96- f"{{\" { attr_name } \" , { attr_type } { default_val } }}" )
97- else :
98- attr_map_list .append (
99- f"{{\" { attr_name } \" , { attr_type } {{{ default_val } }}}}"
100- )
101- api_extra_attr_map = ", " .join (attr_map_list )
102- api_extra_attr_checkers = ",\n " .join (
103- attr_checker_func_list )
104- extra_map_str_list .append (
105- f"{{\" { get_op_name (op_compat_args ['op' ])} \" , {{ { api_extra_attr_map } }}}}"
106- )
107- extra_checker_str_list .append (
108- f"{{\" { get_op_name (op_compat_args ['op' ])} \" , {{ { api_extra_attr_checkers } }}}}"
109- )
110- if 'backward' in op_compat_args :
111- for bw_item in op_compat_args ['backward' ].split (',' ):
112- bw_op_name = get_op_name (bw_item )
113- extra_map_str_list .append (
114- f"{{\" { bw_op_name } \" , {{ { api_extra_attr_map } }}}}" )
115- extra_checker_str_list .append (
116- f"{{\" { bw_op_name } \" , {{ { api_extra_attr_checkers } }}}}"
117- )
118-
125+ default_attr_map_str , default_attr_checker_str = generate_attr_info (
126+ 'attrs' , op_compat_args )
127+ dynamic_attr_map_str , _ = generate_attr_info (
128+ 'dynamic_attrs' , op_compat_args )
129+ extra_default_attr_str_list .extend (default_attr_map_str )
130+ extra_dynamic_attr_str_list .extend (dynamic_attr_map_str )
131+ extra_checker_str_list .extend (default_attr_checker_str )
119132 ops_extra_info_file = open (ops_extra_info_path , 'w' )
120133 ops_extra_info_file .write (
121- map_code_template (",\n " .join (extra_map_str_list ),
134+ map_code_template (",\n " .join (extra_default_attr_str_list ),
135+ ",\n " .join (extra_dynamic_attr_str_list ),
122136 ",\n " .join (extra_checker_str_list )))
123137 ops_extra_info_file .close ()
124138
0 commit comments