@@ -170,7 +170,7 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
170170 paddle::drr::ResultPattern res = src.ResultPattern ();
171171
172172 // W reshape.
173- const auto &reshape_w_shape_attr = res.Attr (
173+ const auto &reshape_w_shape_attr = res.ComputeAttr (
174174 [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t > {
175175 auto matmul_1_in_2 =
176176 pir::GetShapeFromValue (match_ctx.Tensor (" matmul_1_in_2" ));
@@ -195,14 +195,12 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
195195 &res.Tensor (" reshape_6_out" ),
196196 &res.Tensor (" reshape_7_out" )},
197197 {&res.Tensor (" combine_1_out" )});
198- const auto &concat_1_axis_attr = res.Attr (
199- [](const paddle::drr::MatchContext &match_ctx) -> int { return 1 ; });
200- const auto &concat_1 =
201- res.Op (" pd_op.concat" , {{" axis" , concat_1_axis_attr}});
198+
199+ const auto &concat_1 = res.Op (" pd_op.concat" , {{" axis" , res.Int32Attr (1 )}});
202200 res.Tensor (" concat_1_out" ) = concat_1 (res.Tensor (" combine_1_out" ));
203201
204202 // Bias reshape.
205- const auto &reshape_b_shape_attr = res.Attr (
203+ const auto &reshape_b_shape_attr = res.ComputeAttr (
206204 [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t > {
207205 auto add_1_in_2 =
208206 pir::GetShapeFromValue (match_ctx.Tensor (" add_1_in_2" ));
@@ -227,38 +225,26 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
227225 &res.Tensor (" reshape_9_out" ),
228226 &res.Tensor (" reshape_10_out" )},
229227 {&res.Tensor (" combine_2_out" )});
230- const auto &concat_2_axis_attr = res.Attr (
231- [](const paddle::drr::MatchContext &match_ctx) -> int { return 0 ; });
232- const auto &concat_2 =
233- res.Op (" pd_op.concat" , {{" axis" , concat_2_axis_attr}});
228+
229+ const auto &concat_2 = res.Op (" pd_op.concat" , {{" axis" , res.Int32Attr (0 )}});
234230 res.Tensor (" concat_2_out" ) = concat_2 (res.Tensor (" combine_2_out" ));
235231
236232 const auto &head_number =
237- res.Attr ([](const paddle::drr::MatchContext &match_ctx) -> int {
233+ res.ComputeAttr ([](const paddle::drr::MatchContext &match_ctx) -> int {
238234 const auto &full_int_array_1_value =
239235 match_ctx.Attr <std::vector<int64_t >>(" full_int_array_1_value" );
240236 return full_int_array_1_value.at (2 );
241237 });
242- const auto &alpha =
243- res. Attr ( [](const paddle::drr::MatchContext &match_ctx) -> float {
238+ const auto &alpha = res. ComputeAttr (
239+ [](const paddle::drr::MatchContext &match_ctx) -> float {
244240 return match_ctx.Attr <float >(" full_1_value" );
245241 });
246- const auto &multihead_matmul =
247- res.Op (" pd_op.multihead_matmul" ,
248- {{" transpose_q" ,
249- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
250- return false ;
251- })},
252- {" transpose_k" ,
253- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
254- return true ;
255- })},
256- {" transpose_v" ,
257- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
258- return false ;
259- })},
260- {" head_number" , head_number},
261- {" alpha" , alpha}});
242+ const auto &multihead_matmul = res.Op (" pd_op.multihead_matmul" ,
243+ {{" transpose_q" , res.BoolAttr (false )},
244+ {" transpose_k" , res.BoolAttr (true )},
245+ {" transpose_v" , res.BoolAttr (false )},
246+ {" head_number" , head_number},
247+ {" alpha" , alpha}});
262248 multihead_matmul ({&res.Tensor (" matmul_1_in_1" ),
263249 &res.Tensor (" concat_1_out" ),
264250 &res.Tensor (" concat_2_out" ),
@@ -423,7 +409,7 @@ class MultiHeadMatmulFuseWithBiasQKPattern
423409 paddle::drr::ResultPattern res = src.ResultPattern ();
424410
425411 // W reshape.
426- const auto &reshape_w_shape_attr = res.Attr (
412+ const auto &reshape_w_shape_attr = res.ComputeAttr (
427413 [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t > {
428414 auto matmul_1_in_2 =
429415 pir::GetShapeFromValue (match_ctx.Tensor (" matmul_1_in_2" ));
@@ -448,14 +434,12 @@ class MultiHeadMatmulFuseWithBiasQKPattern
448434 &res.Tensor (" reshape_6_out" ),
449435 &res.Tensor (" reshape_7_out" )},
450436 {&res.Tensor (" combine_1_out" )});
451- const auto &concat_1_axis_attr = res.Attr (
452- [](const paddle::drr::MatchContext &match_ctx) -> int { return 1 ; });
453- const auto &concat_1 =
454- res.Op (" pd_op.concat" , {{" axis" , concat_1_axis_attr}});
437+
438+ const auto &concat_1 = res.Op (" pd_op.concat" , {{" axis" , res.Int32Attr (1 )}});
455439 res.Tensor (" concat_1_out" ) = concat_1 (res.Tensor (" combine_1_out" ));
456440
457441 // Bias reshape.
458- const auto &reshape_b_shape_attr = res.Attr (
442+ const auto &reshape_b_shape_attr = res.ComputeAttr (
459443 [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t > {
460444 auto add_1_in_2 =
461445 pir::GetShapeFromValue (match_ctx.Tensor (" add_1_in_2" ));
@@ -480,38 +464,26 @@ class MultiHeadMatmulFuseWithBiasQKPattern
480464 &res.Tensor (" reshape_9_out" ),
481465 &res.Tensor (" reshape_10_out" )},
482466 {&res.Tensor (" combine_2_out" )});
483- const auto &concat_2_axis_attr = res.Attr (
484- [](const paddle::drr::MatchContext &match_ctx) -> int { return 0 ; });
485- const auto &concat_2 =
486- res.Op (" pd_op.concat" , {{" axis" , concat_2_axis_attr}});
467+
468+ const auto &concat_2 = res.Op (" pd_op.concat" , {{" axis" , res.Int32Attr (0 )}});
487469 res.Tensor (" concat_2_out" ) = concat_2 (res.Tensor (" combine_2_out" ));
488470
489471 const auto &head_number =
490- res.Attr ([](const paddle::drr::MatchContext &match_ctx) -> int {
472+ res.ComputeAttr ([](const paddle::drr::MatchContext &match_ctx) -> int {
491473 const auto &full_int_array_1_value =
492474 match_ctx.Attr <std::vector<int64_t >>(" full_int_array_1_value" );
493475 return full_int_array_1_value.at (2 );
494476 });
495- const auto &alpha =
496- res. Attr ( [](const paddle::drr::MatchContext &match_ctx) -> float {
477+ const auto &alpha = res. ComputeAttr (
478+ [](const paddle::drr::MatchContext &match_ctx) -> float {
497479 return match_ctx.Attr <float >(" full_1_value" );
498480 });
499- const auto &multihead_matmul =
500- res.Op (" pd_op.multihead_matmul" ,
501- {{" transpose_q" ,
502- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
503- return false ;
504- })},
505- {" transpose_k" ,
506- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
507- return true ;
508- })},
509- {" transpose_v" ,
510- res.Attr ([](const paddle::drr::MatchContext &match_ctx) {
511- return false ;
512- })},
513- {" head_number" , head_number},
514- {" alpha" , alpha}});
481+ const auto &multihead_matmul = res.Op (" pd_op.multihead_matmul" ,
482+ {{" transpose_q" , res.BoolAttr (false )},
483+ {" transpose_k" , res.BoolAttr (true )},
484+ {" transpose_v" , res.BoolAttr (false )},
485+ {" head_number" , head_number},
486+ {" alpha" , alpha}});
515487 multihead_matmul ({&res.Tensor (" matmul_1_in_1" ),
516488 &res.Tensor (" concat_1_out" ),
517489 &res.Tensor (" concat_2_out" ),
0 commit comments