Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ void SplitOp::Build(pir::Builder& builder, // NOLINT
pir::Value input,
const std::vector<int>& sections,
int axis) {
VLOG(4) << "Start build ConcatOp";
VLOG(4) << "Start build SplitOp";

argument.inputs.push_back(input);

Expand Down
8 changes: 6 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,25 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reduce_prod
args : (Tensor x, int64_t[] dim, bool keep_dim)
args : (Tensor x, int64_t[] dim, bool keep_dim, bool reduce_all)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reduce_all和dim为空是等价的,这里需要reduce_all参数吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为了对齐pd_op和cinn_op的attr,reduce_all参数不参与cinn op的InferMeta和kernel运算

output : Tensor(out)
infer_meta :
func : ReduceInferMeta
param : [x, dim, keep_dim]
kernel :
func : frobenius_norm
param : [x, dim, keep_dim]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reduce_sum
args : (Tensor x, int64_t[] dim, bool keep_dim)
args : (Tensor x, int64_t[] dim, bool keep_dim, DataType dtype=DataType::UNDEFINED)
output : Tensor(out)
infer_meta :
func : ReduceInferMeta
param : [x, dim, keep_dim]
kernel :
func : frobenius_norm
param : [x, dim, keep_dim]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reshape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ class AddAccuracyCheckPattern
}
}
pir::Operation* pd_op =
cinn::dialect::details::RewriteCinnOpToPdOp(op, builder);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
cinn::dialect::details::RewriteCinnOpToPdOp(op, ir_mapping, builder);
};

const auto& ClonePdOp = [&](::pir::Operation* op) -> void {
Expand Down Expand Up @@ -106,7 +103,7 @@ class AddAccuracyCheckPattern

class AccuarcyCheckPass : public pir::Pass {
public:
AccuarcyCheckPass() : pir::Pass("accuarcy_check_pass", /*opt_level=*/4) {}
AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
Expand Down Expand Up @@ -161,6 +162,10 @@ void ApplyDivideGroupOpToFusionOpPass(
pass_manager->AddPass(
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
}

pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
pass_manager->AddPass(cinn::dialect::ir::CreateShapeOpsFallbackToPhiPass());

pass_manager->Run(program);
}

Expand All @@ -181,7 +186,6 @@ void ApplyCinnLowerPass(
pass_manager->AddPass(std::move(pass.value()));
}

pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
if (FLAGS_enable_cinn_accuracy_check) {
VLOG(0) << "Enable CINN Accuracy Check Pass";
pass_manager->AddPass(cinn::dialect::ir::CreateAccuarcyCheckPass());
Expand Down
166 changes: 147 additions & 19 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/core/ir_mapping.h"
namespace cinn::dialect::details {

pir::Attribute ArrayAttributeToIntArrayAttribute(
Expand All @@ -36,37 +37,147 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
return attr_data;
}

const auto& handler_reduce_sum_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"dtype", attrs["dtype"]});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::SumOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_max_op =
[&](::pir::Operation* op,
const ::pir::Builder& builder) -> ::pir::Operation* {
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto cinn_op = op->dyn_cast<cinn::dialect::ReduceMaxOp>();
auto attr = cinn_op.attributes();
auto attrs = op->attributes();

// TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
// Normalization
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attr.at("dim").dyn_cast<::pir::ArrayAttribute>());
attr.insert({"axis", attr_axis});
attr.insert({"keepdim", attr["keep_dim"]});
attr.erase("dim");
attr.erase("keep_dim");

auto pd_op =
const_cast<::pir::Builder*>(&builder)->Build<paddle::dialect::MaxOp>(
cinn_op.operand_source(0), attr);
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::MaxOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_min_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::MinOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_prod_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"dims", attr_axis});
attrs.erase("dim");

auto pd_op = builder.Build<paddle::dialect::ProdOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

::pir::Operation* ConvertSliceOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
attrs.at("starts").dyn_cast<::pir::ArrayAttribute>());
pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
attrs.at("ends").dyn_cast<::pir::ArrayAttribute>());
attrs["starts"] = starts;
attrs["ends"] = ends;
auto pd_op = builder.Build<paddle::dialect::SliceOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertConcatOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
for (auto item : attrs) {
VLOG(0) << item.first;
}
std::vector<pir::Value> vec_inputs;
for (uint32_t i = 0; i < op->num_operands(); ++i) {
vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i)));
}
auto op_input = builder.Build<pir::CombineOp>(vec_inputs).result(0);

int axis = attrs.at("axis").dyn_cast<::pir::Int32Attribute>().data();

auto pd_op = builder.Build<paddle::dialect::ConcatOp>(op_input, axis);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

bool CanApplyOn(::pir::Operation* op) {
return op->dialect()->name() == "cinn_op";
}

::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation* op,
const ::pir::Builder& builder) {
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(8) << "Rewrite CinnOp to PdOp for op: " << op->name();
auto& op_transformers = TransformContext::Instance();
return op_transformers[op->name()](op, builder);
return op_transformers[op->name()](op, ir_mapping, builder);
}

void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
Expand All @@ -91,20 +202,37 @@ void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
}
::pir::Operation* new_op;
if (CanApplyOn(&op)) {
new_op = RewriteCinnOpToPdOp(&op, builder);
new_op = RewriteCinnOpToPdOp(&op, ir_mapping, builder);
new_op->MoveTo(target_block, target_block->end());
} else {
new_op = op.Clone(ir_mapping, clone_options);
new_op->MoveTo(target_block, target_block->end());
}
for (uint32_t i = 0; i < op.num_results(); ++i) {
ir_mapping.Add(op.result(i), new_op->result(i));
}
}
}

} // namespace cinn::dialect::details

REGISTER_TRANSFORM_RULES(reduce_sum_op,
cinn::dialect::ReduceSumOp::name(),
cinn::dialect::details::handler_reduce_sum_op);

REGISTER_TRANSFORM_RULES(reduce_max_op,
cinn::dialect::ReduceMaxOp::name(),
cinn::dialect::details::handler_reduce_max_op);

REGISTER_TRANSFORM_RULES(reduce_min_op,
cinn::dialect::ReduceMinOp::name(),
cinn::dialect::details::handler_reduce_min_op);

REGISTER_TRANSFORM_RULES(reduce_prod_op,
cinn::dialect::ReduceProdOp::name(),
cinn::dialect::details::handler_reduce_prod_op);

REGISTER_TRANSFORM_RULES(slice_op,
cinn::dialect::SliceOp::name(),
cinn::dialect::details::ConvertSliceOp);

REGISTER_TRANSFORM_RULES(concat_op,
cinn::dialect::ConcatOp::name(),
cinn::dialect::details::ConvertConcatOp);
10 changes: 7 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
#include "paddle/common/enforce.h"

namespace pir {
class IrMapping;
class Block;
class Operation;
class Builder;
class IrMapping;
} // namespace pir

namespace cinn::dialect::details {

using TRule =
std::function<::pir::Operation*(::pir::Operation*, const ::pir::Builder&)>;
using TRule = std::function<::pir::Operation*(
::pir::Operation*, ::pir::IrMapping&, ::pir::Builder&)>;

class TransformContext {
private:
Expand Down Expand Up @@ -86,6 +88,8 @@ class TransformRegistrar {

void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
::pir::Block* target_block);
::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*, const ::pir::Builder&);
::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*,
::pir::IrMapping&, // NOLINT
::pir::Builder&); // NOLINT

} // namespace cinn::dialect::details
10 changes: 7 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class SumOpPattern : public paddle::drr::DrrPatternBase {
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"dim", pattern.Attr("axis_info")},
{"dtype", pattern.Attr("dtype")},
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
}
Expand Down Expand Up @@ -128,16 +129,19 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase {
{"dtype", pattern.Attr("dtype_2")},
{"place", pattern.Attr("place_2")}});

const auto &pd_max = pattern.Op(paddle::dialect::ProdOp::name(),
{{"keep_dim", pattern.Attr("keep_dim")}});
const auto &pd_max =
pattern.Op(paddle::dialect::ProdOp::name(),
{{"keep_dim", pattern.Attr("keep_dim")},
{"reduce_all", pattern.Attr("reduce_all")}});
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceProdOp::name(),
{{"dim", pattern.Attr("axis_info")},
{"keep_dim", pattern.Attr("keep_dim")}});
{"keep_dim", pattern.Attr("keep_dim")},
{"reduce_all", pattern.Attr("reduce_all")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}
};
Expand Down
Loading