Skip to content

Commit e328cf7

Browse files
authored
[CINN] Refine fully_insert_broadcast_pass (#60676)
* refine fully_insert_broadcast_pass * fix complie bug * fix complie * fix conflict
1 parent d604bcd commit e328cf7

File tree

7 files changed

+66
-56
lines changed

7 files changed

+66
-56
lines changed

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ if(NOT CINN_ONLY)
1111
cinn_runtime_dialect
1212
pir_compiler)
1313

14-
cc_library(
15-
cinn_transforms
16-
SRCS ${cinn_transforms_srcs}
17-
DEPS ${cinn_transforms_deps})
14+
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
15+
${cinn_transforms_deps})
1816

1917
endif()

paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc renamed to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
15+
#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"
1616

1717
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
1818
#include "paddle/cinn/hlir/framework/pir/utils.h"
@@ -50,6 +50,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
5050
}
5151
pir::Value x = op->operand_source(0);
5252
pir::Value y = op->operand_source(1);
53+
pir::ShapeConstraintIRAnalysis& shape_analysis =
54+
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
55+
const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x);
56+
const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
57+
if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) {
58+
return false;
59+
}
60+
5361
pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
5462
{
5563
pir::Value broadcasted_x =
@@ -67,7 +75,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
6775
} // namespace
6876

6977
template <typename OPTYPE>
70-
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
78+
class InsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
7179
public:
7280
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;
7381

@@ -77,42 +85,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
7785
}
7886
};
7987

80-
FullyInsertBroadcastPass::FullyInsertBroadcastPass()
81-
: pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}
82-
83-
pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
84-
pir::IrContext* context) {
85-
pir::RewritePatternSet ps(context);
86-
// elementwise ops
87-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
88-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
89-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
90-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
91-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
92-
context);
93-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
94-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
95-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
96-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
97-
98-
// compare ops
99-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
100-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
101-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
102-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
103-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
104-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
105-
106-
// bitwise ops
107-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
108-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
109-
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
110-
111-
return ps;
112-
}
88+
class InsertBroadcastPass : public pir::PatternRewritePass {
89+
public:
90+
InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {}
91+
92+
pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
93+
pir::RewritePatternSet ps(context);
94+
// elementwise ops
95+
ps.Add<InsertBroadcastPattern<paddle::dialect::AddOp>>(context);
96+
ps.Add<InsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
97+
ps.Add<InsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
98+
ps.Add<InsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
99+
ps.Add<InsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(context);
100+
ps.Add<InsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
101+
ps.Add<InsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
102+
ps.Add<InsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
103+
ps.Add<InsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
104+
105+
// compare ops
106+
ps.Add<InsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
107+
ps.Add<InsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
108+
ps.Add<InsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
109+
ps.Add<InsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
110+
ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
111+
ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
112+
113+
// bitwise ops
114+
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
115+
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
116+
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
117+
118+
return ps;
119+
}
120+
121+
bool CanApplyOn(pir::Operation* op) const override {
122+
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
123+
}
124+
};
113125

114-
bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
115-
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
126+
std::unique_ptr<pir::Pass> CreateInsertBroadcastPass() {
127+
return std::make_unique<InsertBroadcastPass>();
116128
}
117129

118130
} // namespace ir

paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h renamed to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,12 @@
1515
#pragma once
1616

1717
#include "paddle/pir/pass/pass.h"
18-
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
1918

2019
namespace cinn {
2120
namespace dialect {
2221
namespace ir {
2322

24-
class FullyInsertBroadcastPass : public pir::PatternRewritePass {
25-
public:
26-
FullyInsertBroadcastPass();
27-
28-
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;
29-
30-
bool CanApplyOn(pir::Operation *op) const override;
31-
};
23+
IR_API std::unique_ptr<pir::Pass> CreateInsertBroadcastPass();
3224

3325
} // namespace ir
3426
} // namespace dialect

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ bool InferSymbolicShapeElementWiseBinary(
6767
std::vector<symbol::DimExpr> shapes;
6868
symbol::DimExprBuilder builder{nullptr};
6969
for (size_t i = 0; i < shape_0.size(); i++) {
70-
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
70+
if (shape_0[i] == shape_1[i]) {
71+
shapes.emplace_back(shape_0[i]);
72+
} else {
73+
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
74+
}
7175
}
7276

7377
// TODO(lanxianghit): fill data when the operation is on shape computation

paddle/fluid/pir/transforms/shape_optimization_pass.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,8 @@ void DebugPrintOpInfo(
8383
}
8484

8585
void InferSymExprForAllValues(ModuleOp module_op) {
86-
auto shape_analysis_mgr = ShapeAnalysisManager::Instance();
8786
ShapeConstraintIRAnalysis& shape_analysis =
88-
shape_analysis_mgr.Get(module_op.program());
87+
ShapeAnalysisManager::Instance().Get(module_op.program());
8988
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
9089
for (auto& block : module_op->region(i)) {
9190
for (auto& op : block) {

paddle/pir/dialect/shape/utils/shape_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ class IR_API ShapeAnalysisManager {
120120
static ShapeAnalysisManager& Instance();
121121
ShapeConstraintIRAnalysis& Get(pir::Program* program);
122122

123+
ShapeAnalysisManager(const ShapeAnalysisManager&) = delete;
124+
ShapeAnalysisManager(ShapeAnalysisManager&&) = delete;
125+
ShapeAnalysisManager& operator=(const ShapeAnalysisManager&) = delete;
126+
123127
private:
124128
ShapeAnalysisManager() {}
125129
std::unordered_map<uint64_t, ShapeConstraintIRAnalysis> tables_;

paddle/pir/pass/utils.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ namespace pir {
1818
namespace detail {
1919

2020
void PrintHeader(const std::string &header, std::ostream &os) {
21-
unsigned padding = (80 - header.size()) / 2;
22-
os << "===" << std::string(73, '-') << "===\n";
21+
const size_t padding = 8;
22+
size_t line_len = header.size() + ((padding - 3) * 2);
23+
os << "===" << std::string(line_len, '-') << "===\n";
2324
os << std::string(padding, ' ') << header << "\n";
24-
os << "===" << std::string(73, '-') << "===\n";
25+
os << "===" << std::string(line_len, '-') << "===\n";
2526
}
2627

2728
} // namespace detail

0 commit comments

Comments
 (0)