From 597c2f91d3f1b9f6389595402e80d1a89e357e4b Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 31 May 2024 03:23:51 +0000 Subject: [PATCH 1/5] add --- .../fuse_allreduce_split_to_reducescatter_pass.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc index 4191eaa4bce50e..49156fb4fbf775 100644 --- a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc +++ b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc @@ -35,7 +35,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase { const auto &c_allreduce_sum_ = pat.Op(paddle::dialect::CAllreduceSum_Op::name(), {{"ring_id", pat.Attr("ring_id")}, - {"use_calc_stream", pat.Attr("use_calc_stream")}}); + {"use_calc_stream", pat.Attr("use_calc_stream")}, + {"execution_stream", pat.Attr("execution_stream")}, + {"force_record_event", pat.Attr("force_record_event")}, + {"event_to_record", pat.Attr("event_to_record")}, + {"events_to_wait", pat.Attr("events_to_wait")}}); const auto &assign = pat.Op(paddle::dialect::AssignOp::name()); const auto &full = pat.Op(paddle::dialect::FullOp::name()); const auto &split_with_num = pat.Op(paddle::dialect::SplitWithNumOp::name(), @@ -74,7 +78,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase { res.Op(paddle::dialect::CReducescatterOp::name(), {{"ring_id", pat.Attr("ring_id")}, {"nranks", pat.Attr("num")}, - {"use_calc_stream", pat.Attr("use_calc_stream")}}); + {"use_calc_stream", pat.Attr("use_calc_stream")}, + {"execution_stream", pat.Attr("execution_stream")}, + {"force_record_event", pat.Attr("force_record_event")}, + {"event_to_record", pat.Attr("event_to_record")}, + {"events_to_wait", pat.Attr("events_to_wait")}}); c_reducescatter({&res.Tensor("input_grad_partial")}, {&res.Tensor("out")}); } From 7523be741e20d7967455cb6e7f82572b45bcebcf Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 31 May 2024 05:05:16 +0000 Subject: [PATCH 2/5] add --- ...e_allreduce_split_to_reducescatter_pass.cc | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc index 49156fb4fbf775..c8919014667434 100644 --- a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc +++ b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc @@ -57,28 +57,12 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase { split_with_num(pat.Tensor("input_grad_tmp"), pat.Tensor("split_num")); pat.Tensor("out") = builtin_slice(pat.Tensor("input_grad_group")); - pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { - const auto &x_trans = match_ctx.Attr("trans_x"); - const auto &y_trans = match_ctx.Attr("trans_y"); - auto input_grad_partial_count = - match_ctx.Tensor("input_grad_partial").use_count(); - auto input_grad_count = match_ctx.Tensor("input_grad").use_count(); - auto input_grad_tmp_count = - match_ctx.Tensor("input_grad_tmp").use_count(); - auto input_grad_group_count = - match_ctx.Tensor("input_grad_group").use_count(); - return (x_trans == false && y_trans == true && - input_grad_partial_count == 1 && input_grad_count == 1 && - input_grad_tmp_count == 1 && input_grad_group_count == 1); - }); - paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &c_reducescatter = res.Op(paddle::dialect::CReducescatterOp::name(), - {{"ring_id", pat.Attr("ring_id")}, - {"nranks", pat.Attr("num")}, - {"use_calc_stream", pat.Attr("use_calc_stream")}, + {{"ring_id", pat.Attr("ring_id")}, {"nranks", pat.Attr("num")}}, + {{"use_calc_stream", pat.Attr("use_calc_stream")}, {"execution_stream", pat.Attr("execution_stream")}, {"force_record_event", pat.Attr("force_record_event")}, {"event_to_record", pat.Attr("event_to_record")}, From 47478da87203aa90c4d836bf0710e3a05ee172ac Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 31 May 2024 05:11:35 +0000 Subject: [PATCH 3/5] add --- .../fuse_allreduce_split_to_reducescatter_pass.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc index c8919014667434..0960f687ff77ee 100644 --- a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc +++ b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc @@ -61,9 +61,10 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase { const auto &c_reducescatter = res.Op(paddle::dialect::CReducescatterOp::name(), - {{"ring_id", pat.Attr("ring_id")}, {"nranks", pat.Attr("num")}}, - {{"use_calc_stream", pat.Attr("use_calc_stream")}, - {"execution_stream", pat.Attr("execution_stream")}, + {{"ring_id", pat.Attr("ring_id")}, + {"nranks", pat.Attr("num")}, + {"use_calc_stream", pat.Attr("use_calc_stream")}}, + {{"execution_stream", pat.Attr("execution_stream")}, {"force_record_event", pat.Attr("force_record_event")}, {"event_to_record", pat.Attr("event_to_record")}, {"events_to_wait", pat.Attr("events_to_wait")}}); From 303c6e113cfda978c0cc54ddb9c2835763986957 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 31 May 2024 05:13:16 +0000 Subject: [PATCH 4/5] add --- .../fuse_allreduce_split_to_reducescatter_pass.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc index 0960f687ff77ee..5d1a9b87431f1a 100644 --- a/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc +++ b/paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc @@ -57,6 +57,21 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase { split_with_num(pat.Tensor("input_grad_tmp"), pat.Tensor("split_num")); pat.Tensor("out") = builtin_slice(pat.Tensor("input_grad_group")); + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + const auto &x_trans = match_ctx.Attr("trans_x"); + const auto &y_trans = match_ctx.Attr("trans_y"); + auto input_grad_partial_count = + match_ctx.Tensor("input_grad_partial").use_count(); + auto input_grad_count = match_ctx.Tensor("input_grad").use_count(); + auto input_grad_tmp_count = + match_ctx.Tensor("input_grad_tmp").use_count(); + auto input_grad_group_count = + match_ctx.Tensor("input_grad_group").use_count(); + return (x_trans == false && y_trans == true && + input_grad_partial_count == 1 && input_grad_count == 1 && + input_grad_tmp_count == 1 && input_grad_group_count == 1); + }); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &c_reducescatter = From ae8ece84d505373c66a1245d2db9694b228045b5 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 31 May 2024 09:00:53 +0000 Subject: [PATCH 5/5] add --- .../test_fuse_allreduce_split_to_reducescatter_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py b/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py index b36a5121d2e820..5127589c36396d 100644 --- a/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py +++ b/test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py @@ -22,7 +22,7 @@ (%38) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"linear_0.tmp_0",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16> (%48) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"input",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16> (%50) = "pd_op.matmul" (%48, %2) {persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:true} : (builtin.tensor<4096x1x28672xbf16>, builtin.tensor<8192x28672xbf16>) -> builtin.tensor<4096x1x8192xbf16> - (%57) = "pd_op.c_allreduce_sum_" (%50) {persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16> + (%57) = "pd_op.c_allreduce_sum_" (%50) {event_to_record:"event_7989",events_to_wait:[],execution_stream:"auto_parallel_mp",force_record_event:false,persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16> (%63) = "pd_op.assign" (%57) {persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16> (%64) = "pd_op.full" () {dtype:(pd_op.DataType)int32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)0} : () -> builtin.tensor<1xi32> (%65) = "pd_op.split_with_num" (%63, %64) {num:(Int32)2,persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>, builtin.tensor<1xi32>) -> vec[builtin.tensor<2048x1x8192xbf16>,builtin.tensor<2048x1x8192xbf16>]