4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -103,7 +103,7 @@ void GroupOp::Print(pir::IrPrinter& printer) {
auto& os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = " << name();
os << " = " << name() << " [id:" << op->id() << "]";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
@@ -181,7 +181,7 @@ void FusionOp::Print(pir::IrPrinter& printer) {
auto& os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = " << name();
os << " = " << name() << " [id:" << op->id() << "]";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
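Printing the op id after the name makes a group traceable across the IR dumps that different passes emit. An illustrative printed line (the exact layout comes from pir::IrPrinter; this rendering is an assumption, not copied from Paddle output):

    (%0) = cinn_op.group [id:7] () -> tensor<4x3x16xf32>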
@@ -262,7 +262,7 @@ class AddBroadcastToElementwisePass : public pir::PatternRewritePass {
}

bool CanApplyOn(pir::Operation* op) const override {
-     return op->num_regions() > 0;
+     return op->num_regions() > 0 && op->isa<cinn::dialect::GroupOp>();
}
};

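With the extra isa check, AddBroadcastToElementwisePass no longer fires on every op that owns a region (for example the whole module); the pattern-rewrite driver is applied only to cinn_op.group ops, so broadcast insertion happens inside CINN groups. The test changes below follow from this: a program must wrap its ops in a GroupOp before the pass can match anything.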
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_group.cc
@@ -183,7 +183,8 @@ std::ostream& operator<<(std::ostream& os, const OpLoweringGroup& group) {
os << "}";
};
::pir::IrPrinter printer(os);
os << "Group " << group.group_id() << " :\n";
os << "Group id: " << group.group_id() << ", func_name: " << group.FuncName()
<< "\n";
for (auto* op : group.ops()) {
printer.PrintOperation(op);
PrintSymbolDims(*op);
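The group header now carries both the group id and the lowered function name, reading illustratively as "Group id: group_0, func_name: fn_exp_sum_divide" (names here are made up), so a group in the dump can be matched with the kernel function it lowers to.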
20 changes: 9 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -3994,7 +3994,7 @@ void ShapeBroadcastOp::Build(pir::Builder &builder,
}

void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
-   auto fn = PD_INFER_META(phi::ElementwiseInferMeta);
+   auto fn = PD_INFER_META(phi::ShapeBroadcastInferMeta);
fn(infer_meta);
}

@@ -4051,7 +4051,7 @@ std::vector<pir::Type> ShapeBroadcastOp::InferMeta(
paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

-   phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);
+   phi::ShapeBroadcastInferMeta(meta_x, meta_y, &meta_out);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
@@ -4069,17 +4069,15 @@ namespace {

symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
const symbol::DimExpr &rhs) {
-   if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
-     return std::max(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
-   } else if (lhs.isa<std::int64_t>()) {
-     return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
-   } else if (rhs.isa<std::int64_t>()) {
-     return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
-   } else if (lhs == rhs) {
+   if (lhs == rhs) {
      return lhs;
+   } else if (lhs == 1) {
+     return rhs;
+   } else if (rhs == 1) {
+     return lhs;
    } else {
-     return symbol::Broadcast<symbol::DimExpr>{
-         symbol::List<symbol::DimExpr>{lhs, rhs}};
+     return symbol::SimplifyDimExpr(symbol::Broadcast<symbol::DimExpr>{
+         symbol::List<symbol::DimExpr>{lhs, rhs}});
    }
}
PADDLE_THROW(phi::errors::Fatal("Dead code"));
}
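The rewritten function reads directly as the standard broadcast decision table: equal extents pass through, an extent of 1 defers to the other side, and anything else is kept as a symbolic Broadcast node (presumably folded to a static value by SimplifyDimExpr when both sides are constants, which is why the old int-only max branch could be dropped). A minimal self-contained C++ sketch of the same table, using plain strings as a hypothetical stand-in for symbol::DimExpr rather than Paddle's API:

#include <iostream>
#include <string>

// Stand-in for symbol::DimExpr: "1", "4", "S0", ... (hypothetical encoding).
std::string BroadcastDim(const std::string& lhs, const std::string& rhs) {
  if (lhs == rhs) return lhs;   // equal extents, static or symbolic
  if (lhs == "1") return rhs;   // a unit extent broadcasts away
  if (rhs == "1") return lhs;
  // Otherwise the result is only known symbolically.
  return "Broadcast(" + lhs + ", " + rhs + ")";
}

int main() {
  std::cout << BroadcastDim("4", "4") << "\n";    // 4
  std::cout << BroadcastDim("1", "S0") << "\n";   // S0
  std::cout << BroadcastDim("S0", "S1") << "\n";  // Broadcast(S0, S1)
}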
24 changes: 24 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -3072,6 +3072,30 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
}
}

+ void ShapeBroadcastInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              MetaTensor* out) {
+   const auto& x_dims = x.dims();
+   const auto& y_dims = y.dims();
+   PADDLE_ENFORCE_EQ(
+       x_dims.size(),
+       1,
+       phi::errors::InvalidArgument("The rank of x must be 1. But received: %d",
+                                    x_dims.size()));
+   PADDLE_ENFORCE_EQ(
+       y_dims.size(),
+       1,
+       phi::errors::InvalidArgument("The rank of y must be 1. But received: %d",
+                                    y_dims.size()));
+
+   if (x_dims[0] <= y_dims[0]) {
+     out->set_dims(y_dims);
+   } else {
+     out->set_dims(x_dims);
+   }
+   out->set_dtype(x.dtype());
+ }
+
void ShuffleBatchInferMeta(const MetaTensor& x,
const MetaTensor& seed,
int startup_seed,
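The operands of ShapeBroadcastOp are rank-1 shape tensors, so the meta function only needs the length of the result, not its contents: broadcasting shapes of lengths m and n yields a shape of length max(m, n), which is why picking the longer dims is sufficient. A self-contained sketch of the full shape-broadcast rule those lengths come from (plain vectors standing in for shape tensors; BroadcastShapes is a hypothetical helper, not Paddle API):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Broadcast two shape vectors, right-aligned, NumPy-style.
// The result length is max(x.size(), y.size()), the fact
// ShapeBroadcastInferMeta relies on to choose the output dims.
std::vector<int64_t> BroadcastShapes(std::vector<int64_t> x,
                                     std::vector<int64_t> y) {
  if (x.size() < y.size()) std::swap(x, y);  // make x the longer shape
  size_t offset = x.size() - y.size();
  for (size_t i = 0; i < y.size(); ++i) {
    int64_t a = x[offset + i], b = y[i];
    x[offset + i] = (a == 1) ? b : a;  // assumes the shapes are compatible
  }
  return x;  // length == max of the two input lengths
}

int main() {
  for (int64_t d : BroadcastShapes({4, 1, 16}, {3, 1}))
    std::cout << d << " ";  // prints: 4 3 16
}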
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -526,6 +526,10 @@ void SequenceMaskInferMeta(const MetaTensor& x,
DataType out_dtype,
MetaTensor* y);

+ void ShapeBroadcastInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              MetaTensor* out);
+
void ShuffleBatchInferMeta(const MetaTensor& x,
const MetaTensor& seed,
int startup_seed,
43 changes: 40 additions & 3 deletions test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc
@@ -17,16 +17,32 @@
#include <memory>

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

+ std::vector<pir::Type> CreateDenseTensorTypes(const phi::DDim &dims) {
+   pir::IrContext *ctx = ::pir::IrContext::Instance();
+   pir::Type fp32_dtype = ::pir::Float32Type::get(ctx);
+   phi::DataLayout data_layout = phi::DataLayout::NCHW;
+   phi::LoD lod = {};
+   size_t offset = 0;
+   std::vector<::pir::Type> op_output_types = {::pir::DenseTensorType::get(
+       ctx, fp32_dtype, dims, data_layout, lod, offset)};
+   return op_output_types;
+ }
+
void BuildProgram(pir::Builder &builder) { // NOLINT
+   auto group_op = builder.Build<cinn::dialect::GroupOp>(
+       CreateDenseTensorTypes(common::make_ddim({4, 3, 16})));
+   builder.SetInsertionPointToBlockEnd(group_op.block());
paddle::dialect::FullOp full_input_x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16},
1.5,
@@ -39,9 +55,13 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
full_input_y.result(0));

auto relu_op = builder.Build<paddle::dialect::ReluOp>(add_op.result(0));
+   builder.Build<pir::YieldOp>(std::vector<pir::Value>{relu_op.out()});
}

void BuildProgramBoth(pir::Builder &builder) { // NOLINT
+   auto group_op = builder.Build<cinn::dialect::GroupOp>(
+       CreateDenseTensorTypes(common::make_ddim({10, 10})));
+   builder.SetInsertionPointToBlockEnd(group_op.block());
paddle::dialect::FullOp full_input_x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{10, 1},
1.5,
@@ -57,9 +77,13 @@ void BuildProgramBoth(pir::Builder &builder) { // NOLINT
full_input_y.result(0));

auto relu_op = builder.Build<paddle::dialect::ReluOp>(add_op.result(0));
+   builder.Build<pir::YieldOp>(std::vector<pir::Value>{relu_op.out()});
}

void BuildProgramSubBoth(pir::Builder &builder) { // NOLINT
+   auto group_op = builder.Build<cinn::dialect::GroupOp>(
+       CreateDenseTensorTypes(common::make_ddim({10, 10})));
+   builder.SetInsertionPointToBlockEnd(group_op.block());
paddle::dialect::FullOp full_input_x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{10, 1},
1.5,
@@ -75,6 +99,7 @@ void BuildProgramSubBoth(pir::Builder &builder) { // NOLINT
full_input_x.result(0), full_input_y.result(0));

auto relu_op = builder.Build<paddle::dialect::ReluOp>(sub_op.result(0));
+   builder.Build<pir::YieldOp>(std::vector<pir::Value>{relu_op.out()});
}

TEST(PatternRewrite, broadcast_elementwise) {
@@ -91,7 +116,11 @@ TEST(PatternRewrite, broadcast_elementwise) {

pm.Run(&program);

-   auto it = program.block()->begin();
+   auto it = program.block()
+                 ->begin()
+                 ->dyn_cast<cinn::dialect::GroupOp>()
+                 .block()
+                 ->begin();

CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
it++;
@@ -116,7 +145,11 @@ TEST(PatternRewrite, broadcast_elementwise_both) {

pm.Run(&program);

-   auto it = program.block()->begin();
+   auto it = program.block()
+                 ->begin()
+                 ->dyn_cast<cinn::dialect::GroupOp>()
+                 .block()
+                 ->begin();

CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
it++;
@@ -143,7 +176,11 @@ TEST(PatternRewrite, broadcast_elementwise_sub_both) {

pm.Run(&program);

-   auto it = program.block()->begin();
+   auto it = program.block()
+                 ->begin()
+                 ->dyn_cast<cinn::dialect::GroupOp>()
+                 .block()
+                 ->begin();

CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
it++;
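Because the pass now only applies inside a GroupOp, each builder wraps its ops in a cinn_op.group terminated by pir::YieldOp, and the checks first step into the group's block via begin()->dyn_cast<cinn::dialect::GroupOp>().block() before iterating the rewritten ops.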
28 changes: 16 additions & 12 deletions test/cpp/pir/cinn/pir_all_path_test.cc
@@ -62,20 +62,24 @@ static void RunAndCheckResult(::pir::Program* program,
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

-   pir::PassManager pm(ctx);
-   pm.AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
-   pm.AddPass(cinn::dialect::ir::CreateAddBroadcastToElementwisePass());
-   pm.AddPass(
+   pir::PassManager stage_1_pm(ctx);
+   stage_1_pm.AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
+   stage_1_pm.AddPass(
std::make_unique<cinn::dialect::ir::MergeReshapeWithBroadcastPass>());

-   pm.AddPass(pir::CreateDeadCodeEliminationPass());
-   pm.AddPass(pir::CreateBuildCinnPass());
-   pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
-   pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
-   pm.AddPass(pir::CreateDeadCodeEliminationPass());
-   pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass());
-   pm.EnableIRPrinting();
-   CHECK_EQ(pm.Run(program), true);
+   stage_1_pm.AddPass(pir::CreateDeadCodeEliminationPass());
+   stage_1_pm.AddPass(pir::CreateBuildCinnPass());
+   stage_1_pm.AddPass(cinn::dialect::ir::CreateAddBroadcastToElementwisePass());
+   stage_1_pm.EnableIRPrinting();
+   CHECK_EQ(stage_1_pm.Run(program), true);
+
+   pir::PassManager stage_2_pm(ctx);
+   stage_2_pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
+   stage_2_pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
+   stage_2_pm.AddPass(pir::CreateDeadCodeEliminationPass());
+   stage_2_pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass());
+   stage_2_pm.EnableIRPrinting();
+   CHECK_EQ(stage_2_pm.Run(program), true);

paddle::platform::Place place = paddle::platform::CUDAPlace(0);

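Splitting the pipeline into two PassManagers encodes the new ordering constraint: AddBroadcastToElementwisePass can only match once BuildCinnPass has formed cinn_op.group regions, so it moves out of the pre-clustering stage to run right after BuildCinnPass, and clustering plus lowering then run as a second stage over the already-normalized groups.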