2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -268,7 +268,7 @@ void SplitOp::Build(pir::Builder& builder, // NOLINT
pir::Value input,
const std::vector<int>& sections,
int axis) {
-  VLOG(4) << "Start build ConcatOp";
+  VLOG(4) << "Start build SplitOp";

argument.inputs.push_back(input);

8 changes: 6 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -58,21 +58,25 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reduce_prod
-  args : (Tensor x, int64_t[] dim, bool keep_dim)
+  args : (Tensor x, int64_t[] dim, bool keep_dim, bool reduce_all)
Contributor:
reduce_all and an empty dim are equivalent; is the reduce_all parameter needed here?

Contributor Author:
This keeps the attrs of pd_op and cinn_op aligned; the reduce_all parameter does not take part in the cinn op's InferMeta or kernel computation.

output : Tensor(out)
infer_meta :
func : ReduceInferMeta
param : [x, dim, keep_dim]
kernel :
func : frobenius_norm
param : [x, dim, keep_dim]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reduce_sum
-  args : (Tensor x, int64_t[] dim, bool keep_dim)
+  args : (Tensor x, int64_t[] dim, bool keep_dim, DataType dtype=DataType::UNDEFINED)
output : Tensor(out)
infer_meta :
func : ReduceInferMeta
param : [x, dim, keep_dim]
kernel :
func : frobenius_norm
param : [x, dim, keep_dim]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reshape
@@ -75,10 +75,7 @@ class AddAccuracyCheckPattern
}
}
pir::Operation* pd_op =
-          cinn::dialect::details::RewriteCinnOpToPdOp(op, builder);
-      for (uint32_t i = 0; i < op->num_results(); ++i) {
-        ir_mapping.Add(op->result(i), pd_op->result(i));
-      }
+          cinn::dialect::details::RewriteCinnOpToPdOp(op, ir_mapping, builder);
};

const auto& ClonePdOp = [&](::pir::Operation* op) -> void {
@@ -106,7 +103,7 @@ class AddAccuracyCheckPattern

class AccuarcyCheckPass : public pir::Pass {
public:
-  AccuarcyCheckPass() : pir::Pass("accuarcy_check_pass", /*opt_level=*/4) {}
+  AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
@@ -45,6 +45,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -156,6 +157,10 @@ void ApplyDivideGroupOpToFusionOpPass(
pass_manager->AddPass(
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
}

+  pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
+  pass_manager->AddPass(cinn::dialect::ir::CreateShapeOpsFallbackToPhiPass());

pass_manager->Run(program);
}

@@ -176,7 +181,6 @@ void ApplyCinnLowerPass(
pass_manager->AddPass(std::move(pass.value()));
}

-  pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
if (FLAGS_enable_cinn_accuracy_check) {
VLOG(0) << "Enable CINN Accuracy Check Pass";
pass_manager->AddPass(cinn::dialect::ir::CreateAccuarcyCheckPass());
166 changes: 147 additions & 19 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
@@ -23,6 +23,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/core/ir_mapping.h"
namespace cinn::dialect::details {

pir::Attribute ArrayAttributeToIntArrayAttribute(
@@ -36,37 +37,147 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
return attr_data;
}

const auto& handler_reduce_sum_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"dtype", attrs["dtype"]});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::SumOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_max_op =
[&](::pir::Operation* op,
const ::pir::Builder& builder) -> ::pir::Operation* {
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
-  auto cinn_op = op->dyn_cast<cinn::dialect::ReduceMaxOp>();
-  auto attr = cinn_op.attributes();
+  auto attrs = op->attributes();

// TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
// Normalization
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attr.at("dim").dyn_cast<::pir::ArrayAttribute>());
attr.insert({"axis", attr_axis});
attr.insert({"keepdim", attr["keep_dim"]});
attr.erase("dim");
attr.erase("keep_dim");

auto pd_op =
const_cast<::pir::Builder*>(&builder)->Build<paddle::dialect::MaxOp>(
cinn_op.operand_source(0), attr);
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::MaxOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_min_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
attrs.erase("dim");
attrs.erase("keep_dim");

auto pd_op = builder.Build<paddle::dialect::MinOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

const auto& handler_reduce_prod_op =
[](::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"dims", attr_axis});
attrs.erase("dim");

auto pd_op = builder.Build<paddle::dialect::ProdOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
};

::pir::Operation* ConvertSliceOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
attrs.at("starts").dyn_cast<::pir::ArrayAttribute>());
pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
attrs.at("ends").dyn_cast<::pir::ArrayAttribute>());
attrs["starts"] = starts;
attrs["ends"] = ends;
auto pd_op = builder.Build<paddle::dialect::SliceOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertConcatOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
for (auto item : attrs) {
VLOG(0) << item.first;
}
std::vector<pir::Value> vec_inputs;
for (uint32_t i = 0; i < op->num_operands(); ++i) {
vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i)));
}
auto op_input = builder.Build<pir::CombineOp>(vec_inputs).result(0);

int axis = attrs.at("axis").dyn_cast<::pir::Int32Attribute>().data();

auto pd_op = builder.Build<paddle::dialect::ConcatOp>(op_input, axis);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

bool CanApplyOn(::pir::Operation* op) {
return op->dialect()->name() == "cinn_op";
}

::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation* op,
const ::pir::Builder& builder) {
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(8) << "Rewrite CinnOp to PdOp for op: " << op->name();
auto& op_transformers = TransformContext::Instance();
-  return op_transformers[op->name()](op, builder);
+  return op_transformers[op->name()](op, ir_mapping, builder);
}

void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
@@ -91,20 +202,37 @@ void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
}
::pir::Operation* new_op;
if (CanApplyOn(&op)) {
new_op = RewriteCinnOpToPdOp(&op, builder);
new_op = RewriteCinnOpToPdOp(&op, ir_mapping, builder);
new_op->MoveTo(target_block, target_block->end());
} else {
new_op = op.Clone(ir_mapping, clone_options);
new_op->MoveTo(target_block, target_block->end());
}
for (uint32_t i = 0; i < op.num_results(); ++i) {
ir_mapping.Add(op.result(i), new_op->result(i));
}
}
}

} // namespace cinn::dialect::details

REGISTER_TRANSFORM_RULES(reduce_sum_op,
cinn::dialect::ReduceSumOp::name(),
cinn::dialect::details::handler_reduce_sum_op);

REGISTER_TRANSFORM_RULES(reduce_max_op,
cinn::dialect::ReduceMaxOp::name(),
cinn::dialect::details::handler_reduce_max_op);

REGISTER_TRANSFORM_RULES(reduce_min_op,
cinn::dialect::ReduceMinOp::name(),
cinn::dialect::details::handler_reduce_min_op);

REGISTER_TRANSFORM_RULES(reduce_prod_op,
cinn::dialect::ReduceProdOp::name(),
cinn::dialect::details::handler_reduce_prod_op);

REGISTER_TRANSFORM_RULES(slice_op,
cinn::dialect::SliceOp::name(),
cinn::dialect::details::ConvertSliceOp);

REGISTER_TRANSFORM_RULES(concat_op,
cinn::dialect::ConcatOp::name(),
cinn::dialect::details::ConvertConcatOp);
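
Under the revised TRule signature, every handler above follows the same three-step pattern: resolve inputs through the IrMapping, build the pd_op counterpart, then record the result mapping so later conversions can resolve their operands. A minimal sketch of what a new rule would look like (relu is chosen arbitrarily; the cinn_op/pd_op names here are illustrative placeholders, not part of this PR):

// Sketch only: a hypothetical transform rule under the new TRule signature.
// cinn::dialect::ReluOp and paddle::dialect::ReluOp are assumed names.
::pir::Operation* ConvertReluOp(::pir::Operation* op,
                                ::pir::IrMapping& ir_mapping,  // NOLINT
                                ::pir::Builder& builder) {  // NOLINT
  // 1. Resolve the operand through the mapping built by earlier conversions.
  auto pd_op = builder.Build<paddle::dialect::ReluOp>(
      ir_mapping.Lookup(op->operand_source(0)));
  // 2. Record the result mapping so downstream ops can find these outputs.
  for (uint32_t i = 0; i < op->num_results(); ++i) {
    ir_mapping.Add(op->result(i), pd_op->result(i));
  }
  return pd_op;
}

REGISTER_TRANSFORM_RULES(relu_op,
                         cinn::dialect::ReluOp::name(),
                         cinn::dialect::details::ConvertReluOp);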
10 changes: 7 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h
@@ -20,15 +20,17 @@
#include "paddle/common/enforce.h"

namespace pir {
+class IrMapping;
class Block;
class Operation;
class Builder;
-class IrMapping;
} // namespace pir

namespace cinn::dialect::details {

-using TRule =
-    std::function<::pir::Operation*(::pir::Operation*, const ::pir::Builder&)>;
+using TRule = std::function<::pir::Operation*(
+    ::pir::Operation*, ::pir::IrMapping&, ::pir::Builder&)>;

class TransformContext {
private:
@@ -86,6 +88,8 @@ class TransformRegistrar {

void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
::pir::Block* target_block);
-::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*, const ::pir::Builder&);
+::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*,
+                                      ::pir::IrMapping&,  // NOLINT
+                                      ::pir::Builder&);  // NOLINT

} // namespace cinn::dialect::details
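
The header change above is the whole contract for callers: instead of receiving the pd_op and copying the result mapping by hand, a call site now passes its IrMapping in and the rule updates it internally. A before/after sketch of a caller, modeled on the AddAccuracyCheckPattern change earlier in this PR (the enclosing lambda is abbreviated; only the final call is code from this PR):

// Assumed context: op is a cinn_op operation, and ir_mapping and builder are
// the caller's conversion state, as in AddAccuracyCheckPattern above.
//
// Old interface: the caller copied the result mapping itself:
//   pir::Operation* pd_op =
//       cinn::dialect::details::RewriteCinnOpToPdOp(op, builder);
//   for (uint32_t i = 0; i < op->num_results(); ++i)
//     ir_mapping.Add(op->result(i), pd_op->result(i));
//
// New interface: the rule records the mapping internally, which stays correct
// even when a conversion builds auxiliary ops (e.g. the CombineOp created by
// ConvertConcatOp).
cinn::dialect::details::RewriteCinnOpToPdOp(op, ir_mapping, builder);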
10 changes: 7 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -56,6 +56,7 @@ class SumOpPattern : public paddle::drr::DrrPatternBase {
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"dim", pattern.Attr("axis_info")},
{"dtype", pattern.Attr("dtype")},
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
}
@@ -128,16 +129,19 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase {
{"dtype", pattern.Attr("dtype_2")},
{"place", pattern.Attr("place_2")}});

const auto &pd_max = pattern.Op(paddle::dialect::ProdOp::name(),
{{"keep_dim", pattern.Attr("keep_dim")}});
const auto &pd_max =
pattern.Op(paddle::dialect::ProdOp::name(),
{{"keep_dim", pattern.Attr("keep_dim")},
{"reduce_all", pattern.Attr("reduce_all")}});
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceProdOp::name(),
{{"dim", pattern.Attr("axis_info")},
{"keep_dim", pattern.Attr("keep_dim")}});
{"keep_dim", pattern.Attr("keep_dim")},
{"reduce_all", pattern.Attr("reduce_all")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}
};