@@ -35,7 +35,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
     const auto &c_allreduce_sum_ =
         pat.Op(paddle::dialect::CAllreduceSum_Op::name(),
                {{"ring_id", pat.Attr("ring_id")},
-                {"use_calc_stream", pat.Attr("use_calc_stream")}});
+                {"use_calc_stream", pat.Attr("use_calc_stream")},
+                {"execution_stream", pat.Attr("execution_stream")},
+                {"force_record_event", pat.Attr("force_record_event")},
+                {"event_to_record", pat.Attr("event_to_record")},
+                {"events_to_wait", pat.Attr("events_to_wait")}});
     const auto &assign = pat.Op(paddle::dialect::AssignOp::name());
     const auto &full = pat.Op(paddle::dialect::FullOp::name());
     const auto &split_with_num = pat.Op(paddle::dialect::SplitWithNumOp::name(),
@@ -74,7 +78,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
         res.Op(paddle::dialect::CReducescatterOp::name(),
                {{"ring_id", pat.Attr("ring_id")},
                 {"nranks", pat.Attr("num")},
-                {"use_calc_stream", pat.Attr("use_calc_stream")}});
+                {"use_calc_stream", pat.Attr("use_calc_stream")}},
+               {{"execution_stream", pat.Attr("execution_stream")},
+                {"force_record_event", pat.Attr("force_record_event")},
+                {"event_to_record", pat.Attr("event_to_record")},
+                {"events_to_wait", pat.Attr("events_to_wait")}});
 
     c_reducescatter({&res.Tensor("input_grad_partial")}, {&res.Tensor("out")});
   }