@@ -249,7 +249,8 @@ def GenBuildInputArgsStr(
249249
250250
251251def GenBuildInsertFullForMutableAttribute (
252- op_class_name ,
252+ args ,
253+ op_info ,
253254 op_attribute_name_list ,
254255 op_attribute_build_arg_type_list ,
255256 op_mutable_attribute_name_list ,
@@ -286,6 +287,39 @@ def GenBuildInsertFullForMutableAttribute(
286287 build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE .format (
287288 attr_name = attr_name , phi_dtype = phi_dtype
288289 )
290+ if (
291+ args .with_distributed
292+ and len (op_info .input_name_list ) > 0
293+ and len (op_mutable_attribute_name_list ) > 0
294+ ):
295+ TEMPLATE = """
296+ bool is_dist_input = {input_name}_.type().isa<DistDenseTensorType>();
297+ if(is_dist_input) {{
298+ auto ctx = pir::IrContext::Instance();
299+ auto dist_type = {input_name}_.type().dyn_cast<DistDenseTensorType>();
300+ auto mesh_attr = dist_type.process_mesh_attr();
301+ """
302+ build_mutable_attribute += TEMPLATE .format (
303+ input_name = op_info .input_name_list [0 ]
304+ )
305+ TEMPLATE = """
306+ auto {attr_name}_type = DistDenseTensorType::get(ctx, {attr_name}_.type().dyn_cast<DenseTensorType>(), mesh_attr);
307+ {attr_name}_.set_type({attr_name}_type);
308+ full_{attr_name}_op->set_attribute(
309+ kAttrOpDistAttr,
310+ OperationDistAttribute::get(
311+ pir::IrContext::Instance(),
312+ mesh_attr,
313+ {{{attr_name}_type.tensor_dist_attr() }},
314+ {{}}
315+ )
316+ );
317+ """
318+ for mutable_attr_name in op_mutable_attribute_name_list :
319+ build_mutable_attribute += TEMPLATE .format (
320+ attr_name = mutable_attr_name
321+ )
322+ build_mutable_attribute += "\n }"
289323 return build_mutable_attribute
290324
291325
@@ -757,10 +791,8 @@ def GenBuildOutputs(
757791
758792
759793def gen_build_func_str (
760- op_class_name ,
761- op_input_name_list ,
762- op_input_type_list ,
763- op_input_optional_list ,
794+ args ,
795+ op_info ,
764796 op_attribute_name_list ,
765797 op_attribute_type_list ,
766798 op_attribute_build_arg_type_list ,
@@ -771,18 +803,13 @@ def gen_build_func_str(
771803 op_non_mutable_attribute_type_list ,
772804 op_non_mutable_attribute_build_arg_type_list ,
773805 op_non_mutable_attribute_default_value_list ,
774- op_output_name_list ,
775- op_output_type_list ,
776- op_output_size_list ,
777- op_output_optional_list ,
778- op_infer_meta_map ,
779- op_inplace_map ,
780806 muta_attr_is_input = False ,
781807 attr_args_is_map = False ,
782808):
809+ op_input_name_list = op_info .input_name_list
783810 build_args_for_declare = ""
784811 build_func = ""
785- build_info_str = OP_INFO_TEMPLATE .format (op_name = op_class_name )
812+ build_info_str = OP_INFO_TEMPLATE .format (op_name = op_info . class_name )
786813
787814 build_args_for_declare = GenBuildInputArgsStr (
788815 op_input_name_list ,
@@ -815,7 +842,8 @@ def gen_build_func_str(
815842 if not muta_attr_is_input :
816843 inset_full_for_mutable_attributes_str = (
817844 GenBuildInsertFullForMutableAttribute (
818- op_class_name ,
845+ args ,
846+ op_info ,
819847 op_attribute_name_list ,
820848 op_attribute_build_arg_type_list ,
821849 op_mutable_attribute_name_list ,
@@ -836,7 +864,7 @@ def gen_build_func_str(
836864 argument.AddAttributes(argument_attributes);
837865 argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
838866 ::pir::PassStopGradientsDefaultly(argument);""" .format (
839- op_name = op_class_name
867+ op_name = op_info . class_name
840868 )
841869
842870 GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
@@ -912,7 +940,7 @@ def gen_build_func_str(
912940 data_name = "AsString"
913941 get_attributes_str += (
914942 GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE .format (
915- op_name = op_class_name ,
943+ op_name = op_info . class_name ,
916944 attr_type = attr_type ,
917945 attribute_name = attr_names [idx ],
918946 inner_type = inner_type ,
@@ -922,38 +950,38 @@ def gen_build_func_str(
922950 elif "paddle::dialect::IntArrayAttribute" in attr_types [idx ]:
923951 get_attributes_str += (
924952 GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE .format (
925- op_name = op_class_name ,
953+ op_name = op_info . class_name ,
926954 attr_type = attr_type ,
927955 attribute_name = attr_names [idx ],
928956 )
929957 )
930958 elif "paddle::dialect::ScalarAttribute" in attr_types [idx ]:
931959 get_attributes_str += (
932960 GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE .format (
933- op_name = op_class_name ,
961+ op_name = op_info . class_name ,
934962 attr_type = attr_type ,
935963 attribute_name = attr_names [idx ],
936964 )
937965 )
938966 elif "pir::StrAttribute" in attr_types [idx ]:
939967 get_attributes_str += (
940968 GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE .format (
941- op_name = op_class_name ,
969+ op_name = op_info . class_name ,
942970 attr_type = attr_type ,
943971 attribute_name = attr_names [idx ],
944972 attr_ir_type = attr_types [idx ],
945973 )
946974 )
947975 else :
948976 get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE .format (
949- op_name = op_class_name ,
977+ op_name = op_info . class_name ,
950978 attr_type = attr_type ,
951979 attribute_name = attr_names [idx ],
952980 attr_ir_type = attr_types [idx ],
953981 )
954982
955983 build_func = OP_BUILD_TEMPLATE .format (
956- op_name = op_class_name ,
984+ op_name = op_info . class_name ,
957985 build_info = build_info_str ,
958986 build_args = build_args_for_define ,
959987 build_mutable_attributes = inset_full_for_mutable_attributes_str ,
0 commit comments