Skip to content

Commit d7d077b

Browse files
committed
add dist attribute for mutable attribute.
1 parent 1007c39 commit d7d077b

5 files changed

Lines changed: 67 additions & 52 deletions

File tree

paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
7979
pir::IrContext* ctx,
8080
ProcessMeshAttribute mesh,
8181
const std::vector<int64_t>& dims_mapping,
82-
const flat_hash_map<int64_t, phi::ReduceType>& partial_status);
82+
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
8383
static TensorDistAttribute get(
8484
pir::IrContext* ctx,
8585
const phi::distributed::ProcessMesh& mesh,
8686
const std::vector<int64_t>& dims_mapping,
87-
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
87+
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {}) {
8888
return get(ctx,
8989
ProcessMeshAttribute::get(ctx, mesh),
9090
dims_mapping,

paddle/fluid/pir/dialect/distributed/ir/dist_type.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ class DistDenseTensorType
6666
InferLocalDDim(dense_tensor_type.dims(), tensor_dist_attr);
6767
return get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
6868
}
69+
70+
// return the replicated dist dense tensor type.
71+
static DistDenseTensorType get(pir::IrContext* ctx,
72+
pir::DenseTensorType dense_tensor_type,
73+
ProcessMeshAttribute process_mesh_attr) {
74+
auto& ddim = dense_tensor_type.dims();
75+
auto attr = TensorDistAttribute::get(
76+
ctx, process_mesh_attr, std::vector<int64_t>(ddim.size(), -1));
77+
return get(ctx, dense_tensor_type, attr, ddim);
78+
}
6979
};
7080

7181
} // namespace dialect

paddle/fluid/pir/dialect/op_generator/op_build_gen.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def GenBuildInputArgsStr(
249249

250250

251251
def GenBuildInsertFullForMutableAttribute(
252-
op_class_name,
252+
args,
253+
op_info,
253254
op_attribute_name_list,
254255
op_attribute_build_arg_type_list,
255256
op_mutable_attribute_name_list,
@@ -286,6 +287,39 @@ def GenBuildInsertFullForMutableAttribute(
286287
build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format(
287288
attr_name=attr_name, phi_dtype=phi_dtype
288289
)
290+
if (
291+
args.with_distributed
292+
and len(op_info.input_name_list) > 0
293+
and len(op_mutable_attribute_name_list) > 0
294+
):
295+
TEMPLATE = """
296+
bool is_dist_input = {input_name}_.type().isa<DistDenseTensorType>();
297+
if(is_dist_input) {{
298+
auto ctx = pir::IrContext::Instance();
299+
auto dist_type = {input_name}_.type().dyn_cast<DistDenseTensorType>();
300+
auto mesh_attr = dist_type.process_mesh_attr();
301+
"""
302+
build_mutable_attribute += TEMPLATE.format(
303+
input_name=op_info.input_name_list[0]
304+
)
305+
TEMPLATE = """
306+
auto {attr_name}_type = DistDenseTensorType::get(ctx, {attr_name}_.type().dyn_cast<DenseTensorType>(), mesh_attr);
307+
{attr_name}_.set_type({attr_name}_type);
308+
full_{attr_name}_op->set_attribute(
309+
kAttrOpDistAttr,
310+
OperationDistAttribute::get(
311+
pir::IrContext::Instance(),
312+
mesh_attr,
313+
{{{attr_name}_type.tensor_dist_attr() }},
314+
{{}}
315+
)
316+
);
317+
"""
318+
for mutable_attr_name in op_mutable_attribute_name_list:
319+
build_mutable_attribute += TEMPLATE.format(
320+
attr_name=mutable_attr_name
321+
)
322+
build_mutable_attribute += "\n }"
289323
return build_mutable_attribute
290324

291325

@@ -757,10 +791,8 @@ def GenBuildOutputs(
757791

758792

759793
def gen_build_func_str(
760-
op_class_name,
761-
op_input_name_list,
762-
op_input_type_list,
763-
op_input_optional_list,
794+
args,
795+
op_info,
764796
op_attribute_name_list,
765797
op_attribute_type_list,
766798
op_attribute_build_arg_type_list,
@@ -771,18 +803,13 @@ def gen_build_func_str(
771803
op_non_mutable_attribute_type_list,
772804
op_non_mutable_attribute_build_arg_type_list,
773805
op_non_mutable_attribute_default_value_list,
774-
op_output_name_list,
775-
op_output_type_list,
776-
op_output_size_list,
777-
op_output_optional_list,
778-
op_infer_meta_map,
779-
op_inplace_map,
780806
muta_attr_is_input=False,
781807
attr_args_is_map=False,
782808
):
809+
op_input_name_list = op_info.input_name_list
783810
build_args_for_declare = ""
784811
build_func = ""
785-
build_info_str = OP_INFO_TEMPLATE.format(op_name=op_class_name)
812+
build_info_str = OP_INFO_TEMPLATE.format(op_name=op_info.class_name)
786813

787814
build_args_for_declare = GenBuildInputArgsStr(
788815
op_input_name_list,
@@ -815,7 +842,8 @@ def gen_build_func_str(
815842
if not muta_attr_is_input:
816843
inset_full_for_mutable_attributes_str = (
817844
GenBuildInsertFullForMutableAttribute(
818-
op_class_name,
845+
args,
846+
op_info,
819847
op_attribute_name_list,
820848
op_attribute_build_arg_type_list,
821849
op_mutable_attribute_name_list,
@@ -836,7 +864,7 @@ def gen_build_func_str(
836864
argument.AddAttributes(argument_attributes);
837865
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
838866
::pir::PassStopGradientsDefaultly(argument);""".format(
839-
op_name=op_class_name
867+
op_name=op_info.class_name
840868
)
841869

842870
GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
@@ -912,7 +940,7 @@ def gen_build_func_str(
912940
data_name = "AsString"
913941
get_attributes_str += (
914942
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
915-
op_name=op_class_name,
943+
op_name=op_info.class_name,
916944
attr_type=attr_type,
917945
attribute_name=attr_names[idx],
918946
inner_type=inner_type,
@@ -922,38 +950,38 @@ def gen_build_func_str(
922950
elif "paddle::dialect::IntArrayAttribute" in attr_types[idx]:
923951
get_attributes_str += (
924952
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
925-
op_name=op_class_name,
953+
op_name=op_info.class_name,
926954
attr_type=attr_type,
927955
attribute_name=attr_names[idx],
928956
)
929957
)
930958
elif "paddle::dialect::ScalarAttribute" in attr_types[idx]:
931959
get_attributes_str += (
932960
GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
933-
op_name=op_class_name,
961+
op_name=op_info.class_name,
934962
attr_type=attr_type,
935963
attribute_name=attr_names[idx],
936964
)
937965
)
938966
elif "pir::StrAttribute" in attr_types[idx]:
939967
get_attributes_str += (
940968
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
941-
op_name=op_class_name,
969+
op_name=op_info.class_name,
942970
attr_type=attr_type,
943971
attribute_name=attr_names[idx],
944972
attr_ir_type=attr_types[idx],
945973
)
946974
)
947975
else:
948976
get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
949-
op_name=op_class_name,
977+
op_name=op_info.class_name,
950978
attr_type=attr_type,
951979
attribute_name=attr_names[idx],
952980
attr_ir_type=attr_types[idx],
953981
)
954982

955983
build_func = OP_BUILD_TEMPLATE.format(
956-
op_name=op_class_name,
984+
op_name=op_info.class_name,
957985
build_info=build_info_str,
958986
build_args=build_args_for_define,
959987
build_mutable_attributes=inset_full_for_mutable_attributes_str,

paddle/fluid/pir/dialect/op_generator/op_gen.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,10 +1451,8 @@ def AutoCodeGen(
14511451
build_args_with_muta_attr_not_input_for_declare,
14521452
build_func_with_muta_attr_not_input,
14531453
) = gen_build_func_str(
1454-
op_class_name,
1455-
op_input_name_list,
1456-
op_input_type_list,
1457-
op_input_optional_list,
1454+
args,
1455+
op_info,
14581456
op_attribute_name_list,
14591457
op_attribute_type_list,
14601458
op_attribute_build_arg_type_list,
@@ -1465,23 +1463,15 @@ def AutoCodeGen(
14651463
op_non_mutable_attribute_type_list,
14661464
op_non_mutable_attribute_build_arg_type_list,
14671465
op_non_mutable_attribute_default_value_list,
1468-
op_output_name_list,
1469-
op_output_type_list,
1470-
op_output_size_list,
1471-
op_output_optional_list,
1472-
op_infer_meta_map,
1473-
op_inplace_map,
14741466
muta_attr_is_input=False,
14751467
)
14761468
if len(op_attribute_name_list) > 0:
14771469
(
14781470
build_args_with_attr_is_map_for_declare,
14791471
build_func_with_attr_is_map,
14801472
) = gen_build_func_str(
1481-
op_class_name,
1482-
op_input_name_list,
1483-
op_input_type_list,
1484-
op_input_optional_list,
1473+
args,
1474+
op_info,
14851475
op_attribute_name_list,
14861476
op_attribute_type_list,
14871477
op_attribute_build_arg_type_list,
@@ -1492,12 +1482,6 @@ def AutoCodeGen(
14921482
op_non_mutable_attribute_type_list,
14931483
op_non_mutable_attribute_build_arg_type_list,
14941484
op_non_mutable_attribute_default_value_list,
1495-
op_output_name_list,
1496-
op_output_type_list,
1497-
op_output_size_list,
1498-
op_output_optional_list,
1499-
op_infer_meta_map,
1500-
op_inplace_map,
15011485
muta_attr_is_input=False,
15021486
attr_args_is_map=True,
15031487
)
@@ -1508,10 +1492,8 @@ def AutoCodeGen(
15081492
build_args_with_muta_attr_is_input_for_declare,
15091493
build_func_with_muta_attr_is_input,
15101494
) = gen_build_func_str(
1511-
op_class_name,
1512-
op_input_name_list,
1513-
op_input_type_list,
1514-
op_input_optional_list,
1495+
args,
1496+
op_info,
15151497
op_attribute_name_list,
15161498
op_attribute_type_list,
15171499
op_attribute_build_arg_type_list,
@@ -1522,12 +1504,6 @@ def AutoCodeGen(
15221504
op_non_mutable_attribute_type_list,
15231505
op_non_mutable_attribute_build_arg_type_list,
15241506
op_non_mutable_attribute_default_value_list,
1525-
op_output_name_list,
1526-
op_output_type_list,
1527-
op_output_size_list,
1528-
op_output_optional_list,
1529-
op_infer_meta_map,
1530-
op_inplace_map,
15311507
muta_attr_is_input=True,
15321508
)
15331509

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@
728728
infer_meta :
729729
func : CreateLikeInferMeta
730730
param : [x, dtype]
731+
spmd_rule : FullLikeInferSpmd
731732
kernel :
732733
func : full_like
733734
param : [x, value, dtype]

0 commit comments

Comments
 (0)