diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index 0b54046cb6e6af..56f9ab3d5ebe72 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -67,10 +67,12 @@ if(NOT CINN_ONLY) op_dialect.cc ${cinn_op_source_file} ${cinn_op_info_file} + generate_shape_util.cc manual_op.cc op_attribute.cc DEPS - op_dialect_vjp) + op_dialect_vjp + pir) target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR}) endif() diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc similarity index 69% rename from paddle/pir/dialect/shape/utils/dim_expr_util.cc rename to paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index 8421f500c23daa..eef663585a4086 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" -namespace symbol { +namespace cinn::dialect { +using namespace symbol; // NOLINT namespace { @@ -58,71 +59,71 @@ std::string GetSerializedTag>() { return "Broadcast"; } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const std::int64_t& dim_expr) { - return builder->int64_attr(dim_expr); + return pir::Int64Attribute::get(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const std::string& dim_expr) { - return builder->str_attr(dim_expr); + return pir::StrAttribute::get(ctx, dim_expr); } template -::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx, const T& dim_expr) { std::vector<::pir::Attribute> attr_vecs{}; - attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag())); const auto& operand = dim_expr->data; - attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); - return builder->array_attr(attr_vecs); + attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand)); + return pir::ArrayAttribute::get(ctx, attr_vecs); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Negative& dim_expr) { - return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); + ::pir::IrContext* ctx, const Negative& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Reciprocal& dim_expr) { - return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); + ::pir::IrContext* ctx, const Reciprocal& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr); } template -::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::Builder* builder, +::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::IrContext* ctx, const T& dim_expr) { std::vector<::pir::Attribute> attr_vecs{}; - attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag())); const auto& operands = *(dim_expr.operands); for (const auto& operand : operands) { - attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); + attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand)); } - return builder->array_attr(attr_vecs); + return pir::ArrayAttribute::get(ctx, attr_vecs); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Add& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Mul& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Max& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Min& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Broadcast& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + ::pir::IrContext* ctx, const Broadcast& dim_expr) { + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } std::optional ConvertInt64AttributeToDimExpr( @@ -211,11 +212,11 @@ std::optional ConvertArrayAttributeToDimExpr( } // namespace -::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, const DimExpr& dim_expr) { return std::visit( [&](const auto& impl) { - return ConvertDimExprToAttributeImpl(builder, impl); + return ConvertDimExprToAttributeImpl(ctx, impl); }, dim_expr.variant()); } @@ -359,4 +360,66 @@ MakeGetterDimExpr4SymbolName( }; } -} // namespace symbol +namespace { + +std::optional GetDimExprBySymbolBindingImpl( + const GenerateShapeOp::DataSymbolBinding& symbol_binding, + const std::function& + DimExpr4InputDim) { + const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = + DimExpr4InputDim(symbol_binding.input_tensor_idx); + if (!shape_or_data_dim_expr.data().has_value()) return std::nullopt; + int dim_idx = symbol_binding.input_tensor_dim_idx; + if (dim_idx >= shape_or_data_dim_expr.data().value().size()) + return std::nullopt; + return shape_or_data_dim_expr.data().value().at(dim_idx); +} + +std::optional GetDimExprBySymbolBindingImpl( + const GenerateShapeOp::ShapeSymbolBinding& symbol_binding, + const std::function& + DimExpr4InputDim) { + const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = + DimExpr4InputDim(symbol_binding.input_tensor_idx); + int dim_idx = symbol_binding.input_tensor_dim_idx; + if (dim_idx >= shape_or_data_dim_expr.shape().size()) return std::nullopt; + return shape_or_data_dim_expr.shape().at(dim_idx); +} + +} // namespace + +std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const GenerateShapeOp::SymbolBindings& symbol_bindings, + const std::function& + DimExpr4InputDim) { + std::unordered_map> + symbol_name2symbol_bindins{}; + const auto& GetDimExpr = + [&](const GenerateShapeOp::SymbolBinding& symbol_binding) { + return std::visit( + [&](const auto& impl) { + return GetDimExprBySymbolBindingImpl(impl, DimExpr4InputDim); + }, + symbol_binding); + }; + return [map = std::move(symbol_name2symbol_bindins), GetDimExpr]( + const std::string& symbol_name) -> std::optional { + const auto& iter = map.find(symbol_name); + if (iter == map.end()) return std::nullopt; + std::optional ret = std::nullopt; + for (const auto& symbol_binding : iter->second) { + const auto& current = GetDimExpr(symbol_binding); + if (!current.has_value()) return std::nullopt; + if (ret.has_value()) { + // Same names, same DimExprs. + if (ret.value() != current.value()) return std::nullopt; + } else { + ret = current; + } + } + return ret; + }; +} + +} // namespace cinn::dialect diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h similarity index 53% rename from paddle/pir/dialect/shape/utils/dim_expr_util.h rename to paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index 3ed4550c2248d5..ee4ad3c129e6b4 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -15,28 +15,35 @@ #pragma once #include +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" -#include "paddle/pir/core/dll_decl.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" -namespace symbol { +namespace cinn::dialect { -IR_API ::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, - const DimExpr& dim_expr); -IR_API std::optional ConvertAttributeToDimExpr( +::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, + const symbol::DimExpr& dim_expr); + +std::optional ConvertAttributeToDimExpr( ::pir::Attribute attribute); -IR_API std::optional SubstituteDimExpr( - const DimExpr& dim_expr, - const std::function(const std::string& symbol_name)>& - DimExpr4SymbolName); +std::optional SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::function( + const std::string& symbol_name)>& DimExpr4SymbolName); -IR_API std::function(const std::string& symbol_name)> +std::function(const std::string& symbol_name)> MakeGetterDimExpr4SymbolName( const std::vector>& symbol_bindings, - const std::function( + const std::function( int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim); -} // namespace symbol +std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const GenerateShapeOp::SymbolBindings& symbol_bindings, + const std::function& + DimExpr4InputDim); + +} // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 68a09ad7a9868b..7bbcd74025a076 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -16,8 +16,12 @@ #include #include "glog/logging.h" +#include "paddle/common/ddim.h" #include "paddle/common/enforce.h" +#include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/op_base.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" @@ -25,25 +29,25 @@ namespace cinn { namespace dialect { -const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; -const char *ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"}; -const char *SplitOp::attributes_name[SplitOp::attributes_num] = { +const char* GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; +const char* ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"}; +const char* SplitOp::attributes_name[SplitOp::attributes_num] = { "num_or_sections", "axis"}; -void GroupOp::Build(pir::Builder &builder, - pir::OperationArgument &argument, - const std::vector &output_types) { +void GroupOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + const std::vector& output_types) { argument.AddRegion(nullptr); argument.output_types = output_types; } -void GroupOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - std::unique_ptr &&block) { +void GroupOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + std::unique_ptr&& block) { VLOG(4) << "Start build GroupOp"; if (block && !block->empty()) { IR_ENFORCE(block->back().isa()); - auto &op = block->back(); + auto& op = block->back(); for (size_t i = 0; i < op.num_operands(); ++i) { argument.AddOutput(op.operand(i).type()); } @@ -51,15 +55,15 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT argument.AddRegion().push_back(block.release()); } -pir::Block *GroupOp::block() { - pir::Region ®ion = (*this)->region(0); +pir::Block* GroupOp::block() { + pir::Region& region = (*this)->region(0); if (region.empty()) region.emplace_back(); return ®ion.front(); } -std::vector GroupOp::ops() { - std::vector rt_ops; - for (auto &op : *block()) { +std::vector GroupOp::ops() { + std::vector rt_ops; + for (auto& op : *block()) { rt_ops.push_back(&op); } return rt_ops; @@ -67,8 +71,8 @@ std::vector GroupOp::ops() { void GroupOp::VerifySig() {} -void GroupOp::Print(pir::IrPrinter &printer) { - auto &os = printer.os; +void GroupOp::Print(pir::IrPrinter& printer) { + auto& os = printer.os; auto op = operation(); printer.PrintOpResult(op); os << " = " << name(); @@ -76,16 +80,16 @@ void GroupOp::Print(pir::IrPrinter &printer) { os << " -> "; printer.PrintOpReturnType(op); os << " {"; - for (auto &sub_op : ops()) { + for (auto& sub_op : ops()) { os << "\n"; printer.PrintOperation(sub_op); } os << " \n }"; } -void ConcatOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, +void ConcatOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + const std::vector& inputs, int axis) { VLOG(4) << "Start build ConcatOp"; @@ -131,10 +135,10 @@ void ConcatOp::Build(pir::Builder &builder, // NOLINT "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis)); } -void SplitOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT +void SplitOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT pir::Value input, - const std::vector §ions, + const std::vector& sections, int axis) { VLOG(4) << "Start build ConcatOp"; @@ -177,9 +181,174 @@ void SplitOp::Build(pir::Builder &builder, // NOLINT "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis)); } +const char* GenerateShapeOp::attributes_name[attributes_num] = { + "output_dim_exprs", "symbol_bindings"}; + +void GenerateShapeOp::Build( + pir::Builder& builder, + pir::OperationArgument& argument, + const std::vector& inputs, + const std::vector& output_dim_exprs, + const GenerateShapeOp::SymbolBindings& symbol_bindings) { + CHECK(!inputs.empty()); + argument.AddInputs(inputs); + argument.AddAttribute("output_dim_exprs", + builder.array_attr(output_dim_exprs)); + argument.AddAttribute( + "symbol_bindings", + ConvertSymbolBindingsToAttribute(builder, symbol_bindings)); + argument.AddOutputs({[&]() { + auto* ctx = pir::IrContext::Instance(); + auto type = pir::Int64Type::get(ctx); + auto dim = + ::common::make_ddim({static_cast(output_dim_exprs.size())}); + return paddle::dialect::DenseTensorType::get(ctx, type, dim); + }()}); + ::pir::PassStopGradientsDefaultly(argument); +} + +namespace { + +const char* GetSymbolBindingTypeImpl( + const GenerateShapeOp::DataSymbolBinding& binding) { + return "DataSymbolBinding"; +} + +const char* GetSymbolBindingTypeImpl( + const GenerateShapeOp::ShapeSymbolBinding& binding) { + return "ShapeSymbolBinding"; +} + +const char* GetSymbolBindingType( + const GenerateShapeOp::SymbolBinding& binding) { + return std::visit( + [](const auto& impl) { return GetSymbolBindingTypeImpl(impl); }, binding); +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBaseImpl( + const GenerateShapeOp::DataSymbolBinding& binding) { + return &binding; +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBaseImpl( + const GenerateShapeOp::ShapeSymbolBinding& binding) { + return &binding; +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBase( + const GenerateShapeOp::SymbolBinding& binding) { + return std::visit( + [](const auto& impl) { return GetSymbolBindingBaseImpl(impl); }, binding); +} + +typedef GenerateShapeOp::SymbolBinding (*SymbolBindingConstructorT)( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx); + +GenerateShapeOp::SymbolBinding MakeDataSymbolBinding( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + return GenerateShapeOp::DataSymbolBinding{ + symbol_name, input_tensor_idx, input_tensor_dim_idx}; +} + +GenerateShapeOp::SymbolBinding MakeShapeSymbolBinding( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + return GenerateShapeOp::ShapeSymbolBinding{ + symbol_name, input_tensor_idx, input_tensor_dim_idx}; +} + +std::optional GetMakerSymbolBinding( + const std::string& type) { + static std::map map{ + {GetSymbolBindingTypeImpl(GenerateShapeOp::DataSymbolBinding{}), + &MakeDataSymbolBinding}, + {GetSymbolBindingTypeImpl(GenerateShapeOp::ShapeSymbolBinding{}), + &MakeShapeSymbolBinding}, + }; + const auto& iter = map.find(type); + if (iter == map.end()) return std::nullopt; + return iter->second; +} + +std::optional MakeSymbolBinding( + const std::string& type, + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + auto opt_creator = GetMakerSymbolBinding(type); + if (!opt_creator.has_value()) return std::nullopt; + return opt_creator.value()( + symbol_name, input_tensor_idx, input_tensor_dim_idx); +} + +} // namespace + +pir::Attribute GenerateShapeOp::ConvertSymbolBindingsToAttribute( + pir::Builder& builder, + const GenerateShapeOp::SymbolBindings& symbol_bindings) { + const auto& ConvertSymbolBindingToAttr = [&](const SymbolBinding& binding) { + const auto* type = GetSymbolBindingType(binding); + const auto& [symbol_name, input_tensor_idx, input_tensor_dim_idx] = + *GetSymbolBindingBase(binding); + return builder.array_attr({ + builder.str_attr(type), + builder.str_attr(symbol_name), + builder.int64_attr(input_tensor_idx), + builder.int64_attr(input_tensor_dim_idx), + }); + }; + std::vector bindings_attr{}; + for (const auto& symbol_binding : symbol_bindings) { + bindings_attr.push_back(ConvertSymbolBindingToAttr(symbol_binding)); + } + return builder.array_attr(bindings_attr); +} + +std::optional +GenerateShapeOp::ConvertAttributeToSymbolBindings( + const pir::Attribute& symbol_bindings) { + if (!symbol_bindings.isa()) return std::nullopt; + const auto& symbol_bindings_array_attr = + symbol_bindings.dyn_cast(); + GenerateShapeOp::SymbolBindings ret{GenerateShapeOp::SymbolBindings{}}; + for (int i = 0; i < symbol_bindings_array_attr.size(); ++i) { + const auto& symbol_binding = symbol_bindings_array_attr.at(i); + if (!symbol_binding.isa()) return std::nullopt; + const auto& symbol_binding_array_attr = + symbol_binding.dyn_cast(); + if (symbol_binding_array_attr.size() != 4) return std::nullopt; + if (!symbol_binding_array_attr.at(0).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(1).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(2).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(3).isa()) + return std::nullopt; + const auto& opt_symbol_binding = MakeSymbolBinding( + symbol_binding_array_attr.at(0) + .dyn_cast() + .AsString(), + symbol_binding_array_attr.at(1) + .dyn_cast() + .AsString(), + symbol_binding_array_attr.at(2).dyn_cast().data(), + symbol_binding_array_attr.at(3).dyn_cast().data()); + if (!opt_symbol_binding.has_value()) return std::nullopt; + ret.emplace_back(opt_symbol_binding.value()); + } + return std::move(ret); +} + } // namespace dialect } // namespace cinn IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index fbec6e32ee56b7..8a9acef15aa9d7 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/ir_printer.h" @@ -81,9 +82,46 @@ class IR_API SplitOp : public pir::Op { void VerifySig() const {} }; +class GenerateShapeOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "cinn_op.generate_shape"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; + + struct SymbolBindingBase { + std::string symbol_name; + int64_t input_tensor_idx; + int64_t input_tensor_dim_idx; + }; + + struct DataSymbolBinding : public SymbolBindingBase {}; + struct ShapeSymbolBinding : public SymbolBindingBase {}; + + using SymbolBinding = std::variant; + + using SymbolBindings = std::vector; + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_dim_exprs, + const SymbolBindings &symbol_bindings); + + void VerifySig() {} + + pir::OpResult out() { return result(0); } + + static pir::Attribute ConvertSymbolBindingsToAttribute( + pir::Builder &builder, const SymbolBindings &symbol_bindings); // NOLINT + static std::optional ConvertAttributeToSymbolBindings( + const pir::Attribute &symbol_bindings); +}; + } // namespace dialect } // namespace cinn IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 18ce80a92baff4..6d76ccbec8adc1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -19,4 +19,14 @@ if(NOT CINN_ONLY) pir cinn_op_dialect op_dialect_vjp) + + cinn_cc_library( + fuse_shape_ops_into_generate_shape_op_pass + SRCS + fuse_shape_ops_into_generate_shape_op_pass.cc + DEPS + pir + cinn_op_dialect + op_dialect_vjp) + endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc new file mode 100644 index 00000000000000..48c7427b402a14 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -0,0 +1,369 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" +#include +#include +#include "paddle/cinn/common/bfs_walker.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +namespace { + +using ShapeOrDataDimExprs4ValueT = + std::function; + +std::vector FindSourceDenseTensorOfDimTensor( + pir::Value shape, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::vector ret{}; + const auto& Emplace = [&](pir::Value value) { + if (std::find(ret.begin(), ret.end(), value) != ret.end()) return; + ret.emplace_back(value); + }; + const auto& ForEachInputValue = + [&](pir::Value value, const std::function& Visit) { + // find input dimension tensor; + pir::Operation* owner = value.defining_op(); + if (owner == nullptr) return; + for (int i = 0; i < owner->num_operands(); ++i) { + Visit(owner->operand_source(i)); + } + }; + const auto& IsDimTensor = [&](pir::Value value) -> bool { + return ShapeOrDataDimExprs4Value(value).data().has_value(); + }; + const auto& ForEachInputDimTensor = + [&](pir::Value value, const std::function& Visit) { + // find input dimension tensor; + ForEachInputValue(value, [&](pir::Value input) { + if (IsDimTensor(input)) { + Visit(input); + } + }); + }; + common::BfsWalker walker(ForEachInputDimTensor); + walker(shape, [&](pir::Value value) { + size_t input_cnt = 0; + ForEachInputValue(value, [&](pir::Value input) { + ++input_cnt; + if (IsDimTensor(input)) return; + Emplace(input); + }); + if (input_cnt == 0) { + // `value` is a result of a source op. + Emplace(value); + } + }); + return ret; +} + +bool IsConstant(const std::vector& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.isa()) continue; + return false; + } + return true; +} + +bool IsAtomicImpl(int64_t) { return true; } + +bool IsAtomicImpl(const std::string&) { return true; } + +bool IsAtomicImpl(const symbol::Negative&) { return false; } + +bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } + +bool IsAtomicImpl(const symbol::Add&) { return false; } + +bool IsAtomicImpl(const symbol::Mul&) { return false; } + +bool IsAtomicImpl(const symbol::Max&) { return false; } + +bool IsAtomicImpl(const symbol::Min&) { return false; } + +bool IsAtomicImpl(const symbol::Broadcast&) { return false; } + +bool IsAtomic(const symbol::DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, + dim_expr.variant()); +} + +bool InputDimExprsAllSupported( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + const auto& AllSupported = + [](const std::vector& dim_exprs) -> bool { + for (const auto& dim_expr : dim_exprs) { + if (!IsAtomic(dim_expr)) return false; + } + return true; + }; + for (const auto& input_tensor : input_tensors) { + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (!AllSupported(dim_exprs.shape())) return false; + if (dim_exprs.data().has_value()) { + if (!AllSupported(dim_exprs.data().value())) return false; + } + } + return true; +} + +void ConvertDimExprToAttributes(pir::IrContext* ir_context, + const std::vector& dim_exprs, + std::vector* attrs) { + attrs->clear(); + attrs->reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); + } +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret); + +void CollectSymbolNamesImpl(const int64_t& dim_expr, + std::set* ret) { + // do nothing. +} + +void CollectSymbolNamesImpl(const std::string& dim_expr, + std::set* ret) { + ret->insert(dim_expr); +} + +template +void CollectSymbolNamesImplForUnary(const T& dim_expr, + std::set* ret) { + const auto& [operand] = *dim_expr; + CollectSymbolNames(operand, ret); +} + +void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +template +void CollectSymbolNamesImplForVariadic(const T& dim_expr, + std::set* ret) { + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + CollectSymbolNames(operand, ret); + } +} + +void CollectSymbolNamesImpl(const symbol::Add& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Max& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Min& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret) { + return std::visit( + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, + dim_expr.variant()); +} + +void CollectSymbolNames(const std::vector& dim_exprs, + std::set* ret) { + for (const auto& dim_expr : dim_exprs) { + CollectSymbolNames(dim_expr, ret); + } +} + +template +void AppendSymbolBindings(const std::vector& dim_exprs, + const std::set& symbol_names, + int in_tensor_idx, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); + ++in_tensor_dim_idx) { + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); + CHECK(IsAtomic(dim_expr)); + if (!dim_expr.isa()) continue; + const auto& sym_name = dim_expr.dyn_cast(); + if (symbol_names.find(sym_name) == symbol_names.end()) continue; + symbol_bindings->emplace_back(SymbolBindingsT{ + /*.symbol_name=*/sym_name, + /*.input_tensor_idx=*/in_tensor_idx, + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, + }); + } +} + +void GenerateSymbolBindings( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + const std::set& symbol_names, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int i = 0; i < input_tensors.size(); ++i) { + const auto& input_tensor = input_tensors.at(i); + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + if (dim_exprs.data().has_value()) { + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + } + } +} + +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + pir::Value output_shape, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); + CHECK(shape_or_data_dim_exprs.data().has_value()); + const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); + if (IsConstant(out_dim_exprs)) return false; + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, input_tensors)) { + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " + "they are handled by other passes"; + return false; + } + // generate output_dim_expr_attrs + ConvertDimExprToAttributes( + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); + // generate symbol_bindings + std::set symbol_names_in_out_dim_exprs{}; + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, + input_tensors, + symbol_names_in_out_dim_exprs, + /*out*/ symbol_bindings); + return true; +} + +std::optional GetOutOfRewritedGenerateShapeOp( + pir::Value shape, + pir::PatternRewriter* rewriter, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::vector input_tensors = + FindSourceDenseTensorOfDimTensor(shape, ShapeOrDataDimExprs4Value); + if (input_tensors.empty()) return std::nullopt; + std::vector output_dim_expr_attrs{}; + GenerateShapeOp::SymbolBindings symbol_bindings{}; + bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), + ShapeOrDataDimExprs4Value, + input_tensors, + shape, + &output_dim_expr_attrs, + &symbol_bindings); + if (!success) return std::nullopt; + return rewriter + ->Build( + input_tensors, output_dim_expr_attrs, symbol_bindings) + .out(); +} + +bool ProcessOp(paddle::dialect::ExpandOp op, + pir::PatternRewriter* rewriter, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::optional opt_generated_shape = + GetOutOfRewritedGenerateShapeOp( + op.shape(), rewriter, ShapeOrDataDimExprs4Value); + if (!opt_generated_shape.has_value()) return false; + op->operand(1).set_source(opt_generated_shape.value()); + return true; +} + +} // namespace + +template +class FuseShapeOpsIntoGenerateShapeOpPattern + : public pir::OpRewritePattern { + public: + FuseShapeOpsIntoGenerateShapeOpPattern( + pir::IrContext* context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) + : pir::OpRewritePattern(context), + ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} + + bool MatchAndRewrite(OPTYPE op, + pir::PatternRewriter& rewriter) const override { + return ProcessOp(op, &rewriter, ShapeOrDataDimExprs4Value_); + } + + private: + ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; +}; + +FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) + : pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1), + ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} + +pir::RewritePatternSet FuseShapeOpsIntoGenerateShapeOpPass::InitializePatterns( + pir::IrContext* context) { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add>( + context, ShapeOrDataDimExprs4Value_); + + return ps; +} + +bool FuseShapeOpsIntoGenerateShapeOpPass::CanApplyOn(pir::Operation* op) const { + return op->isa() && op->num_regions() > 0; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h new file mode 100644 index 00000000000000..393ae49825182a --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass { + public: + using ShapeOrDataDimExprs4ValueT = + std::function; + explicit FuseShapeOpsIntoGenerateShapeOpPass( + const ShapeOrDataDimExprs4ValueT &ShapeOrDataDimExprs4Value); + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; + + bool CanApplyOn(pir::Operation *op) const override; + + private: + ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index b38edcbb62041d..ccf7c4d8ce686d 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -25,6 +25,9 @@ if(WITH_TESTING AND WITH_CINN) paddle_test(test_compilation_task SRCS compilation_task_test.cc) + paddle_test(test_generate_shape_util_test SRCS generate_shape_util_test.cc + DEPS cinn_op_dialect) + # DO NOT forget add test name here, otherwise it will not be executed in # CINN CI. set(cinn_unit_tests @@ -37,7 +40,8 @@ if(WITH_TESTING AND WITH_CINN) test_pir_all_path test_group_op test_pir_build_cinn_pass - test_compilation_task) + test_compilation_task + test_generate_shape_util_test) foreach(test_name ${cinn_unit_tests}) get_property( diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc b/test/cpp/pir/cinn/generate_shape_util_test.cc similarity index 92% rename from test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc rename to test/cpp/pir/cinn/generate_shape_util_test.cc index 0893a6d5027055..4fc69c877eb5f7 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc +++ b/test/cpp/pir/cinn/generate_shape_util_test.cc @@ -14,13 +14,14 @@ #include "gtest/gtest.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/pir/dialect/shape/utils/dim_expr_builder.h" -#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" #include "test/cpp/pir/tools/test_pir_utils.h" -namespace symbol { +namespace cinn::dialect { +using namespace symbol; // NOLINT namespace { DimExpr CreateExampleDimExpr() { @@ -37,11 +38,9 @@ DimExpr CreateExampleDimExpr() { TEST(DimExprUtil, Convert) { pir::IrContext* ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - pir::Builder builder = pir::Builder(ctx, program.block()); DimExpr dim_expr = CreateExampleDimExpr(); - ::pir::Attribute attr = ConvertDimExprToAttribute(&builder, dim_expr); + ::pir::Attribute attr = ConvertDimExprToAttribute(ctx, dim_expr); std::optional opt_expr = ConvertAttributeToDimExpr(attr); ASSERT_TRUE(opt_expr.has_value()); ASSERT_EQ(opt_expr.value(), dim_expr); @@ -96,4 +95,4 @@ TEST(DimExprUtil, MakeGetterDimExpr4SymbolName) { ASSERT_EQ(opt_dim_expr.value(), dim_expr); } -} // namespace symbol +} // namespace cinn::dialect diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 5c3aa2b9f43449..f508efb56947e9 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -4,9 +4,6 @@ paddle_test(shape_struct_test SRCS shape_struct_test.cc DEPS gtest) paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc DEPS gtest) -paddle_test(symbol_dim_expr_util_test SRCS symbol_dim_expr_util_test.cc DEPS - gtest) - if(WITH_CINN) paddle_test( shape_optimization_test