1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include " paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass .h"
15+ #include " paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass .h"
1616
1717#include " paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
1818#include " paddle/cinn/hlir/framework/pir/utils.h"
@@ -50,6 +50,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
5050 }
5151 pir::Value x = op->operand_source (0 );
5252 pir::Value y = op->operand_source (1 );
53+ pir::ShapeConstraintIRAnalysis& shape_analysis =
54+ pir::ShapeAnalysisManager::Instance ().Get (op->GetParentProgram ());
55+ const auto & x_shape = shape_analysis.GetShapeOrDataForValue (x);
56+ const auto & y_shape = shape_analysis.GetShapeOrDataForValue (y);
57+ if (x_shape.shape () == y_shape.shape () && x_shape.data () == y_shape.data ()) {
58+ return false ;
59+ }
60+
5361 pir::Value output_dim_tensor = GetOutputDimTensor (rewriter, x, y);
5462 {
5563 pir::Value broadcasted_x =
@@ -67,7 +75,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
6775} // namespace
6876
6977template <typename OPTYPE>
70- class FullyInsertBroadcastPattern : public pir ::OpRewritePattern<OPTYPE> {
78+ class InsertBroadcastPattern : public pir ::OpRewritePattern<OPTYPE> {
7179 public:
7280 using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;
7381
@@ -77,42 +85,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
7785 }
7886};
7987
80- FullyInsertBroadcastPass::FullyInsertBroadcastPass ()
81- : pir::PatternRewritePass(" fully_insert_broadcast_pass" , 1 ) {}
82-
83- pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns (
84- pir::IrContext* context) {
85- pir::RewritePatternSet ps (context);
86- // elementwise ops
87- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
88- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
89- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
90- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
91- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
92- context);
93- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
94- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
95- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
96- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
97-
98- // compare ops
99- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
100- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
101- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
102- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
103- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
104- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
105-
106- // bitwise ops
107- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
108- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
109- ps.Add <FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
110-
111- return ps;
112- }
88+ class InsertBroadcastPass : public pir ::PatternRewritePass {
89+ public:
90+ InsertBroadcastPass () : pir::PatternRewritePass(" insert_broadcast_pass" , 1 ) {}
91+
92+ pir::RewritePatternSet InitializePatterns (pir::IrContext* context) override {
93+ pir::RewritePatternSet ps (context);
94+ // elementwise ops
95+ ps.Add <InsertBroadcastPattern<paddle::dialect::AddOp>>(context);
96+ ps.Add <InsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
97+ ps.Add <InsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
98+ ps.Add <InsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
99+ ps.Add <InsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(context);
100+ ps.Add <InsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
101+ ps.Add <InsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
102+ ps.Add <InsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
103+ ps.Add <InsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);
104+
105+ // compare ops
106+ ps.Add <InsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
107+ ps.Add <InsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
108+ ps.Add <InsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
109+ ps.Add <InsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
110+ ps.Add <InsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
111+ ps.Add <InsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);
112+
113+ // bitwise ops
114+ ps.Add <InsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
115+ ps.Add <InsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
116+ ps.Add <InsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);
117+
118+ return ps;
119+ }
120+
121+ bool CanApplyOn (pir::Operation* op) const override {
122+ return op->isa <pir::ModuleOp>() && op->num_regions () > 0 ;
123+ }
124+ };
113125
114- bool FullyInsertBroadcastPass::CanApplyOn ( pir::Operation* op) const {
115- return op-> isa <pir::ModuleOp >() && op-> num_regions () > 0 ;
126+ std::unique_ptr< pir::Pass> CreateInsertBroadcastPass () {
127+ return std::make_unique<InsertBroadcastPass >();
116128}
117129
118130} // namespace ir
0 commit comments