Skip to content

Commit 9c4f344

Browse files
[PIR] pir onednn bn act fuse pass (#61307)
* pir onednn add conv_bias_pass
1 parent b657279 commit 9c4f344

9 files changed

Lines changed: 711 additions & 72 deletions

File tree

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/transforms/onednn/batch_norm_act_fuse_pass.h"
16+
17+
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
18+
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
19+
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
20+
21+
#include "paddle/pir/pass/pass.h"
22+
#include "paddle/pir/pass/pass_registry.h"
23+
24+
namespace {
25+
class BatchNormActFusePattern : public paddle::drr::DrrPatternBase {
26+
public:
27+
BatchNormActFusePattern(const std::string &bn_name,
28+
const std::string &fused_bn_name)
29+
: bn_name_(bn_name), fused_bn_name_(fused_bn_name) {}
30+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
31+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
32+
33+
const auto &bn =
34+
pat.Op(bn_name_,
35+
{{"momentum", pat.Attr("momentum")},
36+
{"epsilon", pat.Attr("epsilon")},
37+
{"data_format", pat.Attr("data_format")},
38+
{"use_global_stats", pat.Attr("use_global_stats")},
39+
{"trainable_statistics", pat.Attr("trainable_statistics")},
40+
{"is_test", pat.Attr("is_test")}});
41+
const auto &relu = pat.Op(paddle::dialect::ReluOp::name());
42+
bn({&pat.Tensor("x"),
43+
&pat.Tensor("mean"),
44+
&pat.Tensor("variance"),
45+
&pat.Tensor("scale"),
46+
&pat.Tensor("bias")},
47+
{&pat.Tensor("bn_out"),
48+
&pat.Tensor("mean_out"),
49+
&pat.Tensor("variance_out"),
50+
&pat.Tensor("saved_mean"),
51+
&pat.Tensor("saved_variance"),
52+
&pat.Tensor("reserve_space")});
53+
pat.Tensor("relu_out") = relu(pat.Tensor("bn_out"));
54+
55+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
56+
float epsilon = match_ctx.Attr<float>("epsilon");
57+
if (epsilon < 0.0 || epsilon > 0.001 ||
58+
match_ctx.Attr<bool>("trainable_statistics") == true ||
59+
match_ctx.Attr<bool>("is_test") == false) {
60+
return false;
61+
}
62+
return true;
63+
});
64+
65+
paddle::drr::ResultPattern res = pat.ResultPattern();
66+
67+
const auto &fused_bn =
68+
res.Op(fused_bn_name_,
69+
{{
70+
{"is_test", res.BoolAttr(true)},
71+
{"momentum", pat.Attr("momentum")},
72+
{"epsilon", pat.Attr("epsilon")},
73+
{"data_format", pat.Attr("data_format")},
74+
{"use_global_stats", pat.Attr("use_global_stats")},
75+
{"trainable_statistics", res.BoolAttr(false)},
76+
{"fuse_with_relu", res.BoolAttr(true)},
77+
}});
78+
79+
fused_bn({&res.Tensor("x"),
80+
&res.Tensor("mean"),
81+
&res.Tensor("variance"),
82+
&res.Tensor("scale"),
83+
&res.Tensor("bias")},
84+
{&res.Tensor("relu_out"),
85+
&res.Tensor("mean_out"),
86+
&res.Tensor("variance_out"),
87+
&res.Tensor("saved_mean"),
88+
&res.Tensor("saved_variance"),
89+
&res.Tensor("reserve_space")});
90+
}
91+
92+
std::string name() const override { return "BatchNormActFusePattern"; }
93+
94+
uint32_t benefit() const override { return 2; }
95+
96+
private:
97+
std::string bn_name_;
98+
std::string fused_bn_name_;
99+
};
100+
101+
class BatchNormActFusePass : public pir::PatternRewritePass {
102+
public:
103+
BatchNormActFusePass()
104+
: pir::PatternRewritePass("batch_norm_act_fuse_pass", 2) {}
105+
106+
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
107+
pir::RewritePatternSet ps(context);
108+
ps.Add(BatchNormActFusePattern(paddle::dialect::BatchNormOp::name(),
109+
paddle::onednn::dialect::BatchNormOp::name())
110+
.Build(context));
111+
ps.Add(
112+
BatchNormActFusePattern(paddle::dialect::BatchNorm_Op::name(),
113+
paddle::onednn::dialect::BatchNorm_Op::name())
114+
.Build(context));
115+
return ps;
116+
}
117+
};
118+
119+
} // namespace
120+
121+
namespace pir {
122+
123+
std::unique_ptr<Pass> CreateBatchNormActFusePass() {
124+
// pd_op.batch_norm + pd_op.relu -> onednn_op.batch_norm
125+
return std::make_unique<BatchNormActFusePass>();
126+
}
127+
128+
} // namespace pir
129+
130+
REGISTER_IR_PASS(batch_norm_act_fuse_pass, BatchNormActFusePass);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include "paddle/pir/core/dll_decl.h"
19+
20+
namespace pir {
21+
22+
class Pass;
23+
24+
IR_API std::unique_ptr<Pass> CreateBatchNormActFusePass();
25+
26+
} // namespace pir

0 commit comments

Comments
 (0)