@@ -202,14 +202,11 @@ void ReplaceExpandWithBroadcast(pir::IrContext* ir_context,
202202
203203std::tuple<pir::Value, pir::Value, pir::Value> BroadcastableToCondValue (
204204 const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
205- pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
205+ const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
206206 const std::vector<pir::Value>& group_inputs,
207207 pir::Builder& builder) { // NOLINT
208208 const auto & lhs_expr = broadcastable_condition->lhs ;
209209 const auto & rhs_expr = broadcastable_condition->rhs ;
210- auto ShapeOrDataDimExprs4Value = [&shape_analysis](pir::Value value) {
211- return shape_analysis.GetShapeOrDataForValue (value);
212- };
213210
214211 std::vector<pir::Value> lhs_minimal_inputs;
215212 std::vector<pir::Attribute> lhs_output_dim_expr_attrs;
@@ -322,7 +319,7 @@ void InsertYieldOpForCondBlock(pir::Operation* cond_op,
322319pir::Operation* CreateConditionBlock (
323320 const cinn::common::BroadcastTree& broadcast_tree,
324321 const OpLoweringGroupPtr& origin_group,
325- pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
322+ const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
326323 const std::unordered_map<pir::Value, size_t >& value_to_dim_expr_idx,
327324 const std::vector<pir::Value>& group_inputs,
328325 const std::vector<pir::Type>& output_types,
@@ -345,7 +342,7 @@ pir::Operation* CreateConditionBlock(
345342 .Get <cinn::common::BroadcastBranch<cinn::common::BroadcastTree>>();
346343 const auto & [lhs_eq_rhs_cond, lhs_eq_one_cond, rhs_eq_one_cond] =
347344 BroadcastableToCondValue (
348- branch.Get <0 >(), shape_analysis , group_inputs, builder);
345+ branch.Get <0 >(), ShapeOrDataDimExprs4Value , group_inputs, builder);
349346
350347 // lhs == rhs
351348 auto lhs_eq_rhs_cond_op = builder.Build <paddle::dialect::IfOp>(
@@ -354,7 +351,7 @@ pir::Operation* CreateConditionBlock(
354351 builder.SetInsertionPointToBlockEnd (&lhs_eq_rhs_block);
355352 auto * lhs_eq_rhs_block_op = CreateConditionBlock (branch.Get <1 >(),
356353 origin_group,
357- shape_analysis ,
354+ ShapeOrDataDimExprs4Value ,
358355 value_to_dim_expr_idx,
359356 group_inputs,
360357 output_types,
@@ -373,7 +370,7 @@ pir::Operation* CreateConditionBlock(
373370 builder.SetInsertionPointToBlockEnd (&lhs_eq_one_block);
374371 auto * lhs_eq_one_block_op = CreateConditionBlock (branch.Get <2 >(),
375372 origin_group,
376- shape_analysis ,
373+ ShapeOrDataDimExprs4Value ,
377374 value_to_dim_expr_idx,
378375 group_inputs,
379376 output_types,
@@ -387,7 +384,7 @@ pir::Operation* CreateConditionBlock(
387384 builder.SetInsertionPointToBlockEnd (&rhs_eq_one_block);
388385 auto * rhs_eq_one_block_op = CreateConditionBlock (branch.Get <3 >(),
389386 origin_group,
390- shape_analysis ,
387+ ShapeOrDataDimExprs4Value ,
391388 value_to_dim_expr_idx,
392389 group_inputs,
393390 output_types,
@@ -487,17 +484,20 @@ bool NeedBroadcastWithCF(const cinn::common::BroadcastLeaf& leaves) {
487484pir::Operation* CompileBroadcastTreeToConditionBlock (
488485 const OpLoweringGroupPtr& group,
489486 const BroadcastTree& broadcast_tree,
490- pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
491487 const std::unordered_map<pir::Value, size_t >& value_to_dim_expr_idx,
492488 const std::vector<pir::Value>& group_inputs,
493489 const std::vector<pir::Type>& output_types,
494490 pir::PatternRewriter& rewriter) { // NOLINT
491+ auto ShapeOrDataDimExprs4Value =
492+ [&group](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
493+ return group->GetShapeOrDataExprs (value);
494+ };
495495 // 1. broadcast tree to condition op
496496 VLOG (4 ) << " broadcast tree to condition op" ;
497497 std::unordered_map<pir::Block*, OpLoweringGroupPtr> group_map;
498498 pir::Operation* cond_op = CreateConditionBlock (broadcast_tree,
499499 group,
500- shape_analysis ,
500+ ShapeOrDataDimExprs4Value ,
501501 value_to_dim_expr_idx,
502502 group_inputs,
503503 output_types,
0 commit comments