diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 9d05ba421cb681..2cecc5bd052bc5 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -103,7 +103,7 @@ void GroupOp::Print(pir::IrPrinter& printer) { auto& os = printer.os; auto op = operation(); printer.PrintOpResult(op); - os << " = " << name(); + os << " = " << name() << " [id:" << op->id() << "]"; printer.PrintOpOperands(op); os << " -> "; printer.PrintOpReturnType(op); @@ -181,7 +181,7 @@ void FusionOp::Print(pir::IrPrinter& printer) { auto& os = printer.os; auto op = operation(); printer.PrintOpResult(op); - os << " = " << name(); + os << " = " << name() << " [id:" << op->id() << "]"; printer.PrintOpOperands(op); os << " -> "; printer.PrintOpReturnType(op); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index e7b577aad5c269..10cc7ae94f80b6 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -262,7 +262,7 @@ class AddBroadcastToElementwisePass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { - return op->num_regions() > 0; + return op->num_regions() > 0 && op->isa(); } }; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_group.cc b/paddle/cinn/hlir/framework/pir/op_lowering_group.cc index 5deb5c01d020da..e5187f47ab471b 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_group.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_group.cc @@ -183,7 +183,8 @@ std::ostream& operator<<(std::ostream& os, const OpLoweringGroup& group) { os << "}"; }; ::pir::IrPrinter printer(os); - os << "Group " << group.group_id() << " :\n"; + os << "Group id: " << group.group_id() << ", func_name: " << group.FuncName() + << "\n"; for (auto* op : group.ops()) { printer.PrintOperation(op); PrintSymbolDims(*op); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 359a858560b87a..e46fb04279c682 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -3994,7 +3994,7 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, } void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) { - auto fn = PD_INFER_META(phi::ElementwiseInferMeta); + auto fn = PD_INFER_META(phi::ShapeBroadcastInferMeta); fn(infer_meta); } @@ -4051,7 +4051,7 @@ std::vector ShapeBroadcastOp::InferMeta( paddle::dialect::IrTensor dense_out; paddle::dialect::IrMetaTensor meta_out(&dense_out); - phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out); + phi::ShapeBroadcastInferMeta(meta_x, meta_y, &meta_out); std::vector argument_outputs; pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( @@ -4069,17 +4069,15 @@ namespace { symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs, const symbol::DimExpr &rhs) { - if (lhs.isa() && rhs.isa()) { - return std::max(lhs.dyn_cast(), rhs.dyn_cast()); - } else if (lhs.isa()) { - return lhs.dyn_cast() == 1 ? rhs : lhs; - } else if (rhs.isa()) { - return rhs.dyn_cast() == 1 ? lhs : rhs; - } else if (lhs == rhs) { + if (lhs == rhs) { + return lhs; + } else if (lhs == 1) { + return rhs; + } else if (rhs == 1) { return lhs; } else { - return symbol::Broadcast{ - symbol::List{lhs, rhs}}; + return symbol::SimplifyDimExpr(symbol::Broadcast{ + symbol::List{lhs, rhs}}); } PADDLE_THROW(phi::errors::Fatal("Dead code")); } diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index c5bb04afdfe12a..84cf559c9cd4ae 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3072,6 +3072,30 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, } } +void ShapeBroadcastInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + const auto& x_dims = x.dims(); + const auto& y_dims = y.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 1, + phi::errors::InvalidArgument("The rank of x must be 1. But received: %d", + x_dims.size())); + PADDLE_ENFORCE_EQ( + y_dims.size(), + 1, + phi::errors::InvalidArgument("The rank of y must be 1. But received: %d", + y_dims.size())); + + if (x_dims[0] <= y_dims[0]) { + out->set_dims(y_dims); + } else { + out->set_dims(x_dims); + } + out->set_dtype(x.dtype()); +} + void ShuffleBatchInferMeta(const MetaTensor& x, const MetaTensor& seed, int startup_seed, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index ed30a40e39730d..1516a9e5b310fa 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -526,6 +526,10 @@ void SequenceMaskInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* y); +void ShapeBroadcastInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + void ShuffleBatchInferMeta(const MetaTensor& x, const MetaTensor& seed, int startup_seed, diff --git a/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc b/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc index 2d285cf281d6ad..86b0d04a9340c1 100644 --- a/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc +++ b/test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc @@ -17,16 +17,32 @@ #include #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_dialect.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_manager.h" #include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" +std::vector CreateDenseTensorTypes(const phi::DDim &dims) { + pir::IrContext *ctx = ::pir::IrContext::Instance(); + pir::Type fp32_dtype = ::pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {}; + size_t offset = 0; + std::vector<::pir::Type> op_output_types = {::pir::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + return op_output_types; +} + void BuildProgram(pir::Builder &builder) { // NOLINT + auto group_op = builder.Build( + CreateDenseTensorTypes(common::make_ddim({4, 3, 16}))); + builder.SetInsertionPointToBlockEnd(group_op.block()); paddle::dialect::FullOp full_input_x = builder.Build(std::vector{4, 3, 16}, 1.5, @@ -39,9 +55,13 @@ void BuildProgram(pir::Builder &builder) { // NOLINT full_input_y.result(0)); auto relu_op = builder.Build(add_op.result(0)); + builder.Build(std::vector{relu_op.out()}); } void BuildProgramBoth(pir::Builder &builder) { // NOLINT + auto group_op = builder.Build( + CreateDenseTensorTypes(common::make_ddim({10, 10}))); + builder.SetInsertionPointToBlockEnd(group_op.block()); paddle::dialect::FullOp full_input_x = builder.Build(std::vector{10, 1}, 1.5, @@ -57,9 +77,13 @@ void BuildProgramBoth(pir::Builder &builder) { // NOLINT full_input_y.result(0)); auto relu_op = builder.Build(add_op.result(0)); + builder.Build(std::vector{relu_op.out()}); } void BuildProgramSubBoth(pir::Builder &builder) { // NOLINT + auto group_op = builder.Build( + CreateDenseTensorTypes(common::make_ddim({10, 10}))); + builder.SetInsertionPointToBlockEnd(group_op.block()); paddle::dialect::FullOp full_input_x = builder.Build(std::vector{10, 1}, 1.5, @@ -75,6 +99,7 @@ void BuildProgramSubBoth(pir::Builder &builder) { // NOLINT full_input_x.result(0), full_input_y.result(0)); auto relu_op = builder.Build(sub_op.result(0)); + builder.Build(std::vector{relu_op.out()}); } TEST(PatternRewrite, broadcast_elementwise) { @@ -91,7 +116,11 @@ TEST(PatternRewrite, broadcast_elementwise) { pm.Run(&program); - auto it = program.block()->begin(); + auto it = program.block() + ->begin() + ->dyn_cast() + .block() + ->begin(); CHECK_EQ(it->isa(), true); it++; @@ -116,7 +145,11 @@ TEST(PatternRewrite, broadcast_elementwise_both) { pm.Run(&program); - auto it = program.block()->begin(); + auto it = program.block() + ->begin() + ->dyn_cast() + .block() + ->begin(); CHECK_EQ(it->isa(), true); it++; @@ -143,7 +176,11 @@ TEST(PatternRewrite, broadcast_elementwise_sub_both) { pm.Run(&program); - auto it = program.block()->begin(); + auto it = program.block() + ->begin() + ->dyn_cast() + .block() + ->begin(); CHECK_EQ(it->isa(), true); it++; diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc index 0c660c228a5dea..7568855e1f71ca 100644 --- a/test/cpp/pir/cinn/pir_all_path_test.cc +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -62,20 +62,24 @@ static void RunAndCheckResult(::pir::Program* program, ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - pir::PassManager pm(ctx); - pm.AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); - pm.AddPass(cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); - pm.AddPass( + pir::PassManager stage_1_pm(ctx); + stage_1_pm.AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); + stage_1_pm.AddPass( std::make_unique()); - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass()); - pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); - pm.EnableIRPrinting(); - CHECK_EQ(pm.Run(program), true); + stage_1_pm.AddPass(pir::CreateDeadCodeEliminationPass()); + stage_1_pm.AddPass(pir::CreateBuildCinnPass()); + stage_1_pm.AddPass(cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); + stage_1_pm.EnableIRPrinting(); + CHECK_EQ(stage_1_pm.Run(program), true); + + pir::PassManager stage_2_pm(ctx); + stage_2_pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass()); + stage_2_pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); + stage_2_pm.AddPass(pir::CreateDeadCodeEliminationPass()); + stage_2_pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); + stage_2_pm.EnableIRPrinting(); + CHECK_EQ(stage_2_pm.Run(program), true); paddle::platform::Place place = paddle::platform::CUDAPlace(0);