Merged · Changes from 17 commits
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -80,6 +80,7 @@

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
#include "paddle/fluid/pir/transforms/onednn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.h"
#endif

@@ -979,6 +980,9 @@ bool AnalysisPredictor::PrepareExecutor() {
::pir::PassManager mkldnn_pm(::pir::IrContext::Instance(), 2);

mkldnn_pm.AddPass(::pir::CreateConv2dBiasFusePass());
mkldnn_pm.AddPass(::pir::CreateConv2dTransposeBiasFusePass());
mkldnn_pm.AddPass(::pir::CreateConv3dBiasFusePass());
mkldnn_pm.AddPass(::pir::CreateBatchNormActFusePass());

auto constant_folding_pass = ::pir::CreateConstantFoldingPass();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_);
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -351,6 +351,16 @@
data_type : x
backward : conv2d_transpose_grad

- op : conv2d_transpose_bias
args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : Conv2dTransposeInferMeta
param: [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
kernel :
func : conv2d_transpose_bias
data_type : x

- op : copy_to
args : (Tensor x, Place place, bool blocking)
output : Tensor(out)
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -52,6 +52,14 @@
extra_args : bool is_test=false
data_format_tensors : input, out_grad

- op : conv2d_transpose
extra_args : bool is_test=false
data_format_tensors : x

- op : conv2d_transpose_bias
extra_args : bool is_test=false, bool force_fp32_output = false, str mkldnn_data_type = "float32", bool fuse_relu = false, str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f
data_format_tensors : x

- op : conv3d
extra_args : bool is_test=false
data_format_tensors : input
111 changes: 111 additions & 0 deletions paddle/fluid/pir/drr/src/ir_operation_factory.cc
@@ -23,6 +23,9 @@
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/core/value.h"
#ifdef PADDLE_WITH_DNNL
#include "build/paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
Contributor:
The "build" prefix needs to be removed.

Contributor (Author):
Got it, thanks. Every file path I copy from icoding currently comes with the "build" prefix, oops.

#endif

namespace paddle {
namespace drr {
@@ -61,6 +64,114 @@ void OperationFactory::RegisterManualOpCreator() {
attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
});

#ifdef PADDLE_WITH_DNNL
op_creator_map["onednn_op.conv2d_transpose_bias"] =
[](const std::vector<pir::Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {
if (inputs.size() == 4) {
IR_ENFORCE(
attrs.find("strides") != attrs.end(),
"'strides' Attribute is expected for Conv2dTransposeBiasOp. ");
std::vector<int> strides;
for (size_t i = 0;
i < attrs.at("strides").dyn_cast<pir::ArrayAttribute>().size();
i++) {
strides.push_back(attrs.at("strides")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<pir::Int32Attribute>()
.data());
}
Comment on lines +77 to +86

Contributor:
Is there a utility function for converting an ArrayAttribute to a plain vector? A shared helper would save quite a few lines.

Contributor:
I think this block could go in the pattern implementation.

Contributor (Author):
This was copied from the generated code. There is no such utility function at the moment, so let's keep it like this for now.

Contributor (Author):
> I think this block could go in the pattern implementation.

The pattern is responsible for matching. During the matching phase there only seems to be "descriptive" information, with no instantiated Operation. Can the pattern actually host this logic?
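For reference, a minimal sketch of the kind of shared helper the reviewer is asking about; the name GetInt32ArrayAttr and its placement are hypothetical, it simply factors out the ArrayAttribute-to-vector loop repeated throughout this creator:

namespace {
// Hypothetical helper (not part of this PR): collect a pir::ArrayAttribute of
// Int32Attribute elements into a std::vector<int>.
std::vector<int> GetInt32ArrayAttr(const pir::AttributeMap& attrs,
                                   const std::string& name) {
  std::vector<int> result;
  auto array = attrs.at(name).dyn_cast<pir::ArrayAttribute>();
  for (size_t i = 0; i < array.size(); ++i) {
    result.push_back(array.at(i).dyn_cast<pir::Int32Attribute>().data());
  }
  return result;
}
}  // namespace

// The creator body would then reduce to calls such as:
//   std::vector<int> strides = GetInt32ArrayAttr(attrs, "strides");
//   std::vector<int> paddings = GetInt32ArrayAttr(attrs, "paddings");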


IR_ENFORCE(
attrs.find("paddings") != attrs.end(),
"'paddings' Attribute is expected for Conv2dTransposeBiasOp. ");
std::vector<int> paddings;
for (size_t i = 0;
i < attrs.at("paddings").dyn_cast<pir::ArrayAttribute>().size();
i++) {
paddings.push_back(attrs.at("paddings")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<pir::Int32Attribute>()
.data());
}

IR_ENFORCE(attrs.find("output_padding") != attrs.end(),
"'output_padding' Attribute is expected for "
"Conv2dTransposeBiasOp. ");
std::vector<int> output_padding;
for (size_t i = 0; i < attrs.at("output_padding")
.dyn_cast<pir::ArrayAttribute>()
.size();
i++) {
output_padding.push_back(attrs.at("output_padding")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<pir::Int32Attribute>()
.data());
}

IR_ENFORCE(attrs.find("padding_algorithm") != attrs.end(),
Contributor @winter-wang (Mar 7, 2024):
IR_ENFORCE must not be used; this needs to be changed to PADDLE_ENFORCE_EQ. The same applies to the other occurrences. See https://github.com/PaddlePaddle/Paddle/wiki/PADDLE_ENFORCE-Rewriting-Specification for details.
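For illustration, a sketch of the requested rewrite for the check above; the error-builder namespace phi::errors::InvalidArgument is an assumption based on common Paddle usage, not taken from this PR:

// Hypothetical PADDLE_ENFORCE_EQ form of the attribute-presence check.
PADDLE_ENFORCE_EQ(
    attrs.find("padding_algorithm") != attrs.end(),
    true,
    phi::errors::InvalidArgument(
        "'padding_algorithm' Attribute is expected for "
        "Conv2dTransposeBiasOp."));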

"'padding_algorithm' Attribute is expected for "
"Conv2dTransposeBiasOp. ");
std::string padding_algorithm = attrs.at("padding_algorithm")
.dyn_cast<pir::StrAttribute>()
.AsString();

IR_ENFORCE(
attrs.find("groups") != attrs.end(),
"'groups' Attribute is expected for Conv2dTransposeBiasOp. ");
int groups =
attrs.at("groups").dyn_cast<pir::Int32Attribute>().data();

IR_ENFORCE(
attrs.find("dilations") != attrs.end(),
"'dilations' Attribute is expected for Conv2dTransposeBiasOp. ");
std::vector<int> dilations;
for (size_t i = 0;
i < attrs.at("dilations").dyn_cast<pir::ArrayAttribute>().size();
i++) {
dilations.push_back(attrs.at("dilations")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<pir::Int32Attribute>()
.data());
}

IR_ENFORCE(attrs.find("data_format") != attrs.end(),
"'data_format' Attribute is expected for "
"Conv2dTransposeBiasOp. ");
std::string data_format =
attrs.at("data_format").dyn_cast<pir::StrAttribute>().AsString();

IR_ENFORCE(
attrs.find("is_test") != attrs.end(),
"'is_test' Attribute is expected for Conv2dTransposeBiasOp. ");
bool is_test =
attrs.at("is_test").dyn_cast<pir::BoolAttribute>().data();

return rewriter.Build<paddle::onednn::dialect::Conv2dTransposeBiasOp>(
inputs[0],
inputs[1],
inputs[2],
inputs[3],
strides,
paddings,
output_padding,
padding_algorithm,
groups,
dilations,
data_format,
is_test);
}

return rewriter.Build<paddle::onednn::dialect::Conv2dTransposeBiasOp>(
inputs[0], inputs[1], inputs[2], attrs);
};
#endif
}

pir::Attribute CreateIrAttribute(const std::any& obj) {
186 changes: 169 additions & 17 deletions paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc
@@ -222,6 +222,157 @@ class FusedConvAddFusePattern : public paddle::drr::DrrPatternBase {
}
};

class ConvTransposeBiasFusePattern : public paddle::drr::DrrPatternBase {
std::string name() const override { return "ConvTransposeBiasFusePattern"; }

uint32_t benefit() const override { return 2; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();

const auto &conv =
pat.Op(paddle::dialect::Conv2dTransposeOp::name(),
{{"strides", pat.Attr("strides")},
{"paddings", pat.Attr("paddings")},
{"output_padding", pat.Attr("output_padding")},
{"padding_algorithm", pat.Attr("padding_algorithm")},
{"dilations", pat.Attr("dilations")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")}});

const auto &add = pat.Op(paddle::dialect::AddOp::name());
conv({&pat.Tensor("input"),
&pat.Tensor("filter"),
&pat.Tensor("output_size")},
{&pat.Tensor("conv_out")});
const auto &parameter_bias = pat.Op(
pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}});
pat.Tensor("bias") = parameter_bias();
pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias"));

pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
std::set<std::string> padding_algorithm = {"EXPLICIT", "SAME", "VALID"};
std::set<std::string> data_format = {"NCHW", "NHWC", "AnyLayout"};
if (padding_algorithm.count(
match_ctx.Attr<std::string>("padding_algorithm")) == 0 ||
data_format.count(match_ctx.Attr<std::string>("data_format")) == 0 ||
match_ctx.Attr<int>("groups") < 1) {
return false;
}
return true;
});

paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &fused_conv =
res.Op(paddle::onednn::dialect::Conv2dTransposeBiasOp::name(),
{{
{"strides", pat.Attr("strides")},
{"paddings", pat.Attr("paddings")},
{"output_padding", pat.Attr("output_padding")},
{"padding_algorithm", pat.Attr("padding_algorithm")},
{"dilations", pat.Attr("dilations")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"force_fp32_output", res.BoolAttr(false)},
{"mkldnn_data_type", res.StrAttr("float32")},
{"fuse_relu", res.BoolAttr(false)},
{"fuse_activation", res.StrAttr("")},
{"fuse_alpha", res.Float32Attr(0.0f)},
{"fuse_beta", res.Float32Attr(0.0f)},
{"is_test", res.BoolAttr(true)},
}});

fused_conv({&res.Tensor("input"),
&res.Tensor("filter"),
&res.Tensor("bias"),
&res.Tensor("output_size")},
{&res.Tensor("add_out")});
}
};

class FusedConvTransposeAddFusePattern : public paddle::drr::DrrPatternBase {
std::string name() const override {
return "FusedConvTransposeAddFusePattern";
}

uint32_t benefit() const override { return 3; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &conv =
pat.Op(paddle::dialect::Conv2dTransposeOp::name(),
{{"strides", pat.Attr("strides")},
{"paddings", pat.Attr("paddings")},
{"output_padding", pat.Attr("output_padding")},
{"padding_algorithm", pat.Attr("padding_algorithm")},
{"dilations", pat.Attr("dilations")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")}});

const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &add2 = pat.Op(paddle::dialect::AddOp::name());
conv({&pat.Tensor("input"),
&pat.Tensor("filter"),
&pat.Tensor("output_size")},
{&pat.Tensor("conv_out")});
const auto &parameter_bias = pat.Op(
pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}});
pat.Tensor("bias") = parameter_bias();

pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias"));

const auto &parameter = pat.Op(
pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}});
pat.Tensor("other_param") = parameter();
pat.Tensor("result") =
add2(pat.Tensor("add_out"), pat.Tensor("other_param"));

pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
std::set<std::string> padding_algorithm = {"EXPLICIT", "SAME", "VALID"};
std::set<std::string> data_format = {"NCHW", "NHWC", "AnyLayout"};
if (padding_algorithm.count(
match_ctx.Attr<std::string>("padding_algorithm")) == 0 ||
data_format.count(match_ctx.Attr<std::string>("data_format")) == 0 ||
match_ctx.Attr<int>("groups") < 1) {
return false;
}
return true;
});

paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &fused_add = res.Op(paddle::dialect::AddOp::name());
res.Tensor("bias2") =
fused_add(res.Tensor("bias"), res.Tensor("other_param"));

const auto &fused_conv =
res.Op(paddle::onednn::dialect::Conv2dTransposeBiasOp::name(),
{{
{"strides", pat.Attr("strides")},
{"paddings", pat.Attr("paddings")},
{"output_padding", pat.Attr("output_padding")},
{"padding_algorithm", pat.Attr("padding_algorithm")},
{"dilations", pat.Attr("dilations")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"force_fp32_output", res.BoolAttr(false)},
{"mkldnn_data_type", res.StrAttr("float32")},
{"fuse_relu", res.BoolAttr(false)},
{"fuse_activation", res.StrAttr("")},
{"fuse_alpha", res.Float32Attr(0.0f)},
{"fuse_beta", res.Float32Attr(0.0f)},
{"is_test", res.BoolAttr(true)},
}});

fused_conv({&res.Tensor("input"),
&res.Tensor("filter"),
&res.Tensor("bias2"),
&res.Tensor("output_size")},
{&res.Tensor("result")});
}
};

class Conv2dBiasFusePass : public pir::PatternRewritePass {
public:
Conv2dBiasFusePass() : pir::PatternRewritePass("conv2d_bias_fuse_pass", 2) {}
@@ -240,18 +391,18 @@ class Conv2dBiasFusePass : public pir::PatternRewritePass {
}
};

// class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass {
// public:
// Conv2dTransposeBiasFusePass()
// : pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 2) {}
class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass {
public:
Conv2dTransposeBiasFusePass()
: pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 2) {}

// pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override
// {
// pir::RewritePatternSet ps(context);
// ps.Add(paddle::drr::Create<Conv2dBiasFusePattern>(context));
// return ps;
// }
// };
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(paddle::drr::Create<ConvTransposeBiasFusePattern>(context));
ps.Add(paddle::drr::Create<FusedConvTransposeAddFusePattern>(context));
return ps;
}
};

class Conv3dBiasFusePass : public pir::PatternRewritePass {
public:
@@ -281,10 +432,12 @@ std::unique_ptr<Pass> CreateConv2dBiasFusePass() {
return std::make_unique<Conv2dBiasFusePass>();
}

// std::unique_ptr<Pass> CreateConv2dTransposeBiasFusePass() {
// // pd_op.conv2d_transpose + pd_op.add -> onednn_op.fused_conv2d
// return std::make_unique<Conv2dTransposeBiasFusePass>();
// }
std::unique_ptr<Pass> CreateConv2dTransposeBiasFusePass() {
// pd_op.conv2d_transpose + pd_op.add -> onednn_op.conv2d_transpose_bias
// pd_op.conv2d_transpose + pd_op.add + pd_op.add ->
// pd_op.add (folding the two bias params) + onednn_op.conv2d_transpose_bias
return std::make_unique<Conv2dTransposeBiasFusePass>();
}
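As shown in the analysis_predictor.cc hunk at the top of this diff, the re-enabled pass is registered into the oneDNN PIR pipeline alongside the other conv bias fusions:

// From paddle/fluid/inference/api/analysis_predictor.cc in this PR:
mkldnn_pm.AddPass(::pir::CreateConv2dTransposeBiasFusePass());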

std::unique_ptr<Pass> CreateConv3dBiasFusePass() {
// pd_op.conv3d + pd_op.add -> onednn_op.fused_conv3d
@@ -294,6 +447,5 @@ std::unique_ptr<Pass> CreateConv3dBiasFusePass() {
} // namespace pir

REGISTER_IR_PASS(conv2d_bias_fuse_pass, Conv2dBiasFusePass);
// REGISTER_IR_PASS(conv2d_transpose_bias_fuse_pass,
// Conv2dTransposeBiasFusePass);
REGISTER_IR_PASS(conv2d_transpose_bias_fuse_pass, Conv2dTransposeBiasFusePass);
REGISTER_IR_PASS(conv3d_bias_fuse_pass, Conv3dBiasFusePass);