Skip to content

Commit c1ce849

Browse files
[CINN]fix CompileBroadcastTreeToConditionBlock (#64659)
* [CINN]fix CompileBroadcastTreeToConditionBlock by using local ShapeOrDataDimExprs * fix lambda return type
1 parent 154b007 commit c1ce849

3 files changed

Lines changed: 17 additions & 23 deletions

File tree

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,11 @@ void ReplaceExpandWithBroadcast(pir::IrContext* ir_context,
202202

203203
std::tuple<pir::Value, pir::Value, pir::Value> BroadcastableToCondValue(
204204
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
205-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
205+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
206206
const std::vector<pir::Value>& group_inputs,
207207
pir::Builder& builder) { // NOLINT
208208
const auto& lhs_expr = broadcastable_condition->lhs;
209209
const auto& rhs_expr = broadcastable_condition->rhs;
210-
auto ShapeOrDataDimExprs4Value = [&shape_analysis](pir::Value value) {
211-
return shape_analysis.GetShapeOrDataForValue(value);
212-
};
213210

214211
std::vector<pir::Value> lhs_minimal_inputs;
215212
std::vector<pir::Attribute> lhs_output_dim_expr_attrs;
@@ -322,7 +319,7 @@ void InsertYieldOpForCondBlock(pir::Operation* cond_op,
322319
pir::Operation* CreateConditionBlock(
323320
const cinn::common::BroadcastTree& broadcast_tree,
324321
const OpLoweringGroupPtr& origin_group,
325-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
322+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
326323
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx,
327324
const std::vector<pir::Value>& group_inputs,
328325
const std::vector<pir::Type>& output_types,
@@ -345,7 +342,7 @@ pir::Operation* CreateConditionBlock(
345342
.Get<cinn::common::BroadcastBranch<cinn::common::BroadcastTree>>();
346343
const auto& [lhs_eq_rhs_cond, lhs_eq_one_cond, rhs_eq_one_cond] =
347344
BroadcastableToCondValue(
348-
branch.Get<0>(), shape_analysis, group_inputs, builder);
345+
branch.Get<0>(), ShapeOrDataDimExprs4Value, group_inputs, builder);
349346

350347
// lhs == rhs
351348
auto lhs_eq_rhs_cond_op = builder.Build<paddle::dialect::IfOp>(
@@ -354,7 +351,7 @@ pir::Operation* CreateConditionBlock(
354351
builder.SetInsertionPointToBlockEnd(&lhs_eq_rhs_block);
355352
auto* lhs_eq_rhs_block_op = CreateConditionBlock(branch.Get<1>(),
356353
origin_group,
357-
shape_analysis,
354+
ShapeOrDataDimExprs4Value,
358355
value_to_dim_expr_idx,
359356
group_inputs,
360357
output_types,
@@ -373,7 +370,7 @@ pir::Operation* CreateConditionBlock(
373370
builder.SetInsertionPointToBlockEnd(&lhs_eq_one_block);
374371
auto* lhs_eq_one_block_op = CreateConditionBlock(branch.Get<2>(),
375372
origin_group,
376-
shape_analysis,
373+
ShapeOrDataDimExprs4Value,
377374
value_to_dim_expr_idx,
378375
group_inputs,
379376
output_types,
@@ -387,7 +384,7 @@ pir::Operation* CreateConditionBlock(
387384
builder.SetInsertionPointToBlockEnd(&rhs_eq_one_block);
388385
auto* rhs_eq_one_block_op = CreateConditionBlock(branch.Get<3>(),
389386
origin_group,
390-
shape_analysis,
387+
ShapeOrDataDimExprs4Value,
391388
value_to_dim_expr_idx,
392389
group_inputs,
393390
output_types,
@@ -487,17 +484,20 @@ bool NeedBroadcastWithCF(const cinn::common::BroadcastLeaf& leaves) {
487484
pir::Operation* CompileBroadcastTreeToConditionBlock(
488485
const OpLoweringGroupPtr& group,
489486
const BroadcastTree& broadcast_tree,
490-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
491487
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx,
492488
const std::vector<pir::Value>& group_inputs,
493489
const std::vector<pir::Type>& output_types,
494490
pir::PatternRewriter& rewriter) { // NOLINT
491+
auto ShapeOrDataDimExprs4Value =
492+
[&group](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
493+
return group->GetShapeOrDataExprs(value);
494+
};
495495
// 1. broadcast tree to condition op
496496
VLOG(4) << "broadcast tree to condition op";
497497
std::unordered_map<pir::Block*, OpLoweringGroupPtr> group_map;
498498
pir::Operation* cond_op = CreateConditionBlock(broadcast_tree,
499499
group,
500-
shape_analysis,
500+
ShapeOrDataDimExprs4Value,
501501
value_to_dim_expr_idx,
502502
group_inputs,
503503
output_types,

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group);
3737
pir::Operation* CompileBroadcastTreeToConditionBlock(
3838
const OpLoweringGroupPtr& group,
3939
const BroadcastTree& broadcast_tree,
40-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
4140
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx,
4241
const std::vector<pir::Value>& group_inputs,
4342
const std::vector<pir::Type>& output_types,

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@
3030

3131
namespace cinn::dialect::ir::details {
3232

33-
pir::Operation* ProcessDyShapeGroup(
34-
const OpLoweringGroupPtr& group,
35-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
36-
pir::PatternRewriter& rewriter) { // NOLINT
33+
pir::Operation* ProcessDyShapeGroup(const OpLoweringGroupPtr& group,
34+
pir::PatternRewriter& rewriter) { // NOLINT
3735
// NOTE(dev): Need UpdateShapeOrDataExprs firstly and the logic
3836
// will be migated into BucketLower later.
3937
UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
@@ -53,7 +51,6 @@ pir::Operation* ProcessDyShapeGroup(
5351
}
5452
return CompileBroadcastTreeToConditionBlock(group,
5553
*broadcast_tree,
56-
shape_analysis,
5754
value_to_dim_expr_idx,
5855
group_inputs,
5956
output_types,
@@ -85,7 +82,7 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
8582

8683
// TODO(zhangyuqin1998): Replace pir::Group with a new structure
8784
OpLoweringGroupPtr group = GetGroup(fusion_op);
88-
pir::Operation* compiled_op = ProcessGroup(group, shape_analysis, rewriter);
85+
pir::Operation* compiled_op = ProcessGroup(group, rewriter);
8986

9087
for (size_t i = 0; i < fusion_op.num_results(); ++i) {
9188
rewriter.ReplaceAllUsesWith(fusion_op.result(i), compiled_op->result(i));
@@ -104,8 +101,7 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
104101

105102
virtual pir::Operation* ProcessGroup(
106103
const OpLoweringGroupPtr& group,
107-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
108-
pir::PatternRewriter& rewriter) const { // NOLINT
104+
pir::PatternRewriter& rewriter) const { // NOLINT
109105
auto group_inputs = GetBlockOutsideInput(group->ops());
110106
// compile group to jit_kernel_op
111107
std::vector<pir::Type> output_types;
@@ -156,9 +152,8 @@ class DyShapeFusionOpPattern : public FusionOpPattern {
156152
protected:
157153
virtual pir::Operation* ProcessGroup(
158154
const OpLoweringGroupPtr& group,
159-
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
160-
pir::PatternRewriter& rewriter) const { // NOLINT
161-
return ProcessDyShapeGroup(group, shape_analysis, rewriter);
155+
pir::PatternRewriter& rewriter) const { // NOLINT
156+
return ProcessDyShapeGroup(group, rewriter);
162157
}
163158
};
164159

0 commit comments

Comments
 (0)