diff --git a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc
index 4191eaa4bce50e..5d1a9b87431f1a 100644
--- a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc
+++ b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc
@@ -35,7 +35,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
     const auto &c_allreduce_sum_ =
         pat.Op(paddle::dialect::CAllreduceSum_Op::name(),
                {{"ring_id", pat.Attr("ring_id")},
-                {"use_calc_stream", pat.Attr("use_calc_stream")}});
+                {"use_calc_stream", pat.Attr("use_calc_stream")},
+                {"execution_stream", pat.Attr("execution_stream")},
+                {"force_record_event", pat.Attr("force_record_event")},
+                {"event_to_record", pat.Attr("event_to_record")},
+                {"events_to_wait", pat.Attr("events_to_wait")}});
     const auto &assign = pat.Op(paddle::dialect::AssignOp::name());
     const auto &full = pat.Op(paddle::dialect::FullOp::name());
     const auto &split_with_num = pat.Op(paddle::dialect::SplitWithNumOp::name(),
@@ -74,7 +78,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
         res.Op(paddle::dialect::CReducescatterOp::name(),
                {{"ring_id", pat.Attr("ring_id")},
                 {"nranks", pat.Attr("num")},
-                {"use_calc_stream", pat.Attr("use_calc_stream")}});
+                {"use_calc_stream", pat.Attr("use_calc_stream")}},
+               {{"execution_stream", pat.Attr("execution_stream")},
+                {"force_record_event", pat.Attr("force_record_event")},
+                {"event_to_record", pat.Attr("event_to_record")},
+                {"events_to_wait", pat.Attr("events_to_wait")}});
 
     c_reducescatter({&res.Tensor("input_grad_partial")}, {&res.Tensor("out")});
   }
diff --git a/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py b/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py
index b36a5121d2e820..5127589c36396d 100644
--- a/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py
+++ b/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py
@@ -22,7 +22,7 @@
     (%38) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"linear_0.tmp_0",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16>
     (%48) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"input",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16>
     (%50) = "pd_op.matmul" (%48, %2) {persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:true} : (builtin.tensor<4096x1x28672xbf16>, builtin.tensor<8192x28672xbf16>) -> builtin.tensor<4096x1x8192xbf16>
-    (%57) = "pd_op.c_allreduce_sum_" (%50) {persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
+    (%57) = "pd_op.c_allreduce_sum_" (%50) {event_to_record:"event_7989",events_to_wait:[],execution_stream:"auto_parallel_mp",force_record_event:false,persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
     (%63) = "pd_op.assign" (%57) {persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
     (%64) = "pd_op.full" () {dtype:(pd_op.DataType)int32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)0} : () -> builtin.tensor<1xi32>
     (%65) = "pd_op.split_with_num" (%63, %64) {num:(Int32)2,persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>, builtin.tensor<1xi32>) -> vec[builtin.tensor<2048x1x8192xbf16>,builtin.tensor<2048x1x8192xbf16>]