diff --git a/paddle/cinn/common/dim_expr_converter.cc b/paddle/cinn/common/dim_expr_converter.cc index 06c8968d988764..312cf110646375 100644 --- a/paddle/cinn/common/dim_expr_converter.cc +++ b/paddle/cinn/common/dim_expr_converter.cc @@ -11,9 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/cinn/common/dim_expr_converter.h" +#include #include "paddle/cinn/common/ir_util.h" +#include "paddle/cinn/ir/tensor.h" namespace cinn::common { using namespace symbol; // NOLINT @@ -27,7 +28,7 @@ struct DimExprToIrExprVisitor { ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); } - ir::Expr operator()(const std::string& dim_expr) { + virtual ir::Expr operator()(const std::string& dim_expr) { // The dimension must be greater equal than 1, and due to the extensive use // of int32 in CAS, the upper bound here is temporarily INT32_MAX, otherwise // there may be a risk of overflow. @@ -111,8 +112,61 @@ struct DimExprToIrExprVisitor { } // namespace +struct DimExprConverterWithSymbolBindings:: + DimExprToIrExprVisitorWithSymbolBinding : public DimExprToIrExprVisitor { + using SymbolBinding = cinn::dialect::SymbolBinding; + using ShapeSymbolBinding = cinn::dialect::ShapeSymbolBinding; + using DataSymbolBinding = cinn::dialect::DataSymbolBinding; + + const std::vector& inputs_; + std::unordered_map + symbol_binding_map_; + + ir::Expr operator()(const std::string& dim_expr) override { + CHECK(symbol_binding_map_.count(dim_expr)); + auto symbol_binding = symbol_binding_map_[dim_expr]; + auto [input_idx, input_dim_idx] = std::visit( + [](auto&& symbol_binding) -> std::pair { + return {symbol_binding.input_tensor_idx, + symbol_binding.input_tensor_dim_idx}; + }, + symbol_binding); + if (std::holds_alternative(symbol_binding)) { + return inputs_[input_idx]->sym_shape[input_dim_idx]->GetDimExpr(); + } + // for data binding [S0, a, b], inputs[a] is Tensor A, return A(b) + return inputs_[input_idx](cinn::ir::Expr(input_dim_idx)); + } + + DimExprToIrExprVisitorWithSymbolBinding( + const std::vector& inputs, + const std::vector& symbol_bindings) + : inputs_(inputs) { + for (const auto& symbol_binding : symbol_bindings) { + const auto& symbol_name = std::visit( + [](auto&& symbol_binding) -> std::string { + return symbol_binding.symbol_name; + }, + symbol_binding); + symbol_binding_map_[symbol_name] = symbol_binding; + } + } +}; + ir::Expr DimExprConverter::ConvertToIrExpr(const DimExpr& dim_expr) const { return DimExprToIrExprVisitor().ConvertToIrExpr(dim_expr); } +ir::Expr DimExprConverterWithSymbolBindings::ConvertToIrExpr( + const DimExpr& dim_expr) const { + return visitor_->ConvertToIrExpr(dim_expr); +} + +DimExprConverterWithSymbolBindings::DimExprConverterWithSymbolBindings( + const std::vector& inputs, + const cinn::dialect::SymbolBindings& symbol_bindings) { + visitor_ = std::make_shared( + inputs, symbol_bindings); +} + } // namespace cinn::common diff --git a/paddle/cinn/common/dim_expr_converter.h b/paddle/cinn/common/dim_expr_converter.h index 374059eb182957..cefe09046f8c93 100644 --- a/paddle/cinn/common/dim_expr_converter.h +++ b/paddle/cinn/common/dim_expr_converter.h @@ -14,6 +14,8 @@ #pragma once +#include +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/ir/ir.h" #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" @@ -23,4 +25,15 @@ struct DimExprConverter final { ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const; }; +struct DimExprConverterWithSymbolBindings final { + ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const; + DimExprConverterWithSymbolBindings( + const std::vector& inputs, + const cinn::dialect::SymbolBindings& symbol_bindings); + + private: + struct DimExprToIrExprVisitorWithSymbolBinding; + std::shared_ptr visitor_; +}; + } // namespace cinn::common diff --git a/paddle/cinn/frontend/paddle_model_convertor.cc b/paddle/cinn/frontend/paddle_model_convertor.cc index af6784e69b8429..e1ed58b7babdf7 100644 --- a/paddle/cinn/frontend/paddle_model_convertor.cc +++ b/paddle/cinn/frontend/paddle_model_convertor.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/cinn/frontend/paddle_model_convertor.h" - #include #include @@ -25,6 +24,7 @@ #include "paddle/cinn/frontend/paddle/cpp/program_desc.h" #include "paddle/cinn/frontend/paddle/model_parser.h" #include "paddle/cinn/frontend/var_type_utils.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/common/enforce.h" @@ -202,6 +202,8 @@ void SetOpDescAttr(const std::string& attr_name, VISITOR_EXPAND(std::vector) VISITOR_EXPAND(std::vector) #undef VISITOR_EXPAND + void operator()(const std::vector& v) {} + void operator()(const cinn::dialect::SymbolBindings& v) {} private: paddle::cpp::OpDesc* op_desc_; diff --git a/paddle/cinn/frontend/syntax.cc b/paddle/cinn/frontend/syntax.cc index 66f9c3c0c54961..97bc310b0be3f4 100644 --- a/paddle/cinn/frontend/syntax.cc +++ b/paddle/cinn/frontend/syntax.cc @@ -25,6 +25,7 @@ #include "paddle/cinn/frontend/paddle/model_parser.h" #include "paddle/cinn/frontend/paddle_model_to_program.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/utils/string.h" @@ -559,6 +560,12 @@ std::string _Instruction_::debug_string() const { void operator()(const std::vector& x) { s_ << "[" + utils::Join(x, ",") + "]"; } + void operator()(const std::vector& x) { + s_ << "[" + utils::Join(x, ",") + "]"; + } + void operator()(const cinn::dialect::SymbolBindings& x) { + s_ << "[" + utils::Join(x, ",") + "]"; + } }; std::stringstream ss; diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index e9b19b57b081d7..261794897847a6 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -238,6 +238,19 @@ std::optional ConvertAttributeToDimExpr(::pir::Attribute attribute) { return std::nullopt; } +std::optional> ConvertAttributeToDimExprs( + ::pir::Attribute attribute) { + if (!attribute.isa()) return std::nullopt; + auto array = attribute.dyn_cast(); + std::vector dim_exprs; + for (int i = 0; i < array.size(); ++i) { + const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i)); + if (!dim_expr.has_value()) return std::nullopt; + dim_exprs.push_back(dim_expr.value()); + } + return dim_exprs; +} + class SubstituteDimExprHelper final { public: using DimExpr4SymbolNameT = @@ -327,7 +340,7 @@ DimExpr SubstituteDimExpr( namespace { std::optional GetDimExprBySymbolBindingImpl( - const GenerateShapeOp::DataSymbolBinding& symbol_binding, + const DataSymbolBinding& symbol_binding, const std::function& DimExpr4InputDim) { const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = @@ -340,7 +353,7 @@ std::optional GetDimExprBySymbolBindingImpl( } std::optional GetDimExprBySymbolBindingImpl( - const GenerateShapeOp::ShapeSymbolBinding& symbol_binding, + const ShapeSymbolBinding& symbol_binding, const std::function& DimExpr4InputDim) { const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = @@ -350,8 +363,7 @@ std::optional GetDimExprBySymbolBindingImpl( return shape_or_data_dim_expr.shape().at(dim_idx); } -std::string GetSymbolNameBySymbolBinding( - const GenerateShapeOp::SymbolBinding& symbol_binding) { +std::string GetSymbolNameBySymbolBinding(const SymbolBinding& symbol_binding) { return std::visit([](const auto& impl) { return impl.symbol_name; }, symbol_binding); } @@ -360,10 +372,10 @@ std::string GetSymbolNameBySymbolBinding( std::function(const std::string& symbol_name)> MakeGetterDimExpr4SymbolName( - const GenerateShapeOp::SymbolBindings& symbol_bindings, + const SymbolBindings& symbol_bindings, const std::function& DimExpr4InputDim) { - std::unordered_map> + std::unordered_map> symbol_name2symbol_bindins{}; for (const auto& symbol_binding : symbol_bindings) { symbol_name2symbol_bindins[GetSymbolNameBySymbolBinding(symbol_binding)] @@ -529,7 +541,7 @@ template void AppendSymbolBindings(const std::vector& dim_exprs, const std::set& symbol_names, int in_tensor_idx, - GenerateShapeOp::SymbolBindings* symbol_bindings) { + SymbolBindings* symbol_bindings) { for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); ++in_tensor_dim_idx) { const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); @@ -549,14 +561,14 @@ void GenerateSymbolBindings( const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, const std::vector& input_tensors, const std::set& symbol_names, - GenerateShapeOp::SymbolBindings* symbol_bindings) { + SymbolBindings* symbol_bindings) { for (int i = 0; i < input_tensors.size(); ++i) { const auto& input_tensor = input_tensors.at(i); const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - AppendSymbolBindings( + AppendSymbolBindings( dim_exprs.shape(), symbol_names, i, symbol_bindings); if (dim_exprs.data().has_value()) { - AppendSymbolBindings( + AppendSymbolBindings( dim_exprs.data().value(), symbol_names, i, symbol_bindings); } } @@ -606,7 +618,7 @@ bool MakeGenerateShapeOpAttribute( const std::vector& origin_inputs, std::vector* minimal_inputs, std::vector* output_dim_expr_attrs, - GenerateShapeOp::SymbolBindings* symbol_bindings) { + SymbolBindings* symbol_bindings) { *minimal_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minimal_inputs)) { VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index 58256f83b607df..3284f53aa04dd7 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -18,6 +18,7 @@ #include #include #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" @@ -29,6 +30,9 @@ ::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, std::optional ConvertAttributeToDimExpr( ::pir::Attribute attribute); +std::optional> ConvertAttributeToDimExprs( + ::pir::Attribute attribute); + symbol::DimExpr SubstituteDimExpr( const symbol::DimExpr& dim_expr, const std::function( @@ -36,7 +40,7 @@ symbol::DimExpr SubstituteDimExpr( std::function(const std::string& symbol_name)> MakeGetterDimExpr4SymbolName( - const GenerateShapeOp::SymbolBindings& symbol_bindings, + const SymbolBindings& symbol_bindings, const std::function& DimExpr4InputDim); @@ -51,6 +55,6 @@ bool MakeGenerateShapeOpAttribute( const std::vector& origin_inputs, std::vector* minimal_inputs, std::vector* output_dim_expr_attrs, - GenerateShapeOp::SymbolBindings* symbol_bindings); + SymbolBindings* symbol_bindings); } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index bfd9348e39aed4..3e5baa1291b34e 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -478,24 +478,19 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings( } return std::move(ret); } - bool GenerateShapeOp::InferSymbolicShape( pir::InferSymbolicShapeContext* infer_context) { const auto attr_dim_exprs = [&] { - std::vector dim_exprs{}; pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs"); - PADDLE_ENFORCE(dim_expr_attr.isa(), - ::common::errors::PreconditionNotMet( - "Required dim_expr_attr is ArrayAttribute.")); - auto array = dim_expr_attr.dyn_cast(); - for (int i = 0; i < array.size(); ++i) { - const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i)); - PADDLE_ENFORCE(dim_expr.has_value(), - ::common::errors::PreconditionNotMet( - "Required dim_expr.has_value()==true.")); - dim_exprs.push_back(dim_expr.value()); - } - return dim_exprs; + auto dim_exprs = ConvertAttributeToDimExprs(dim_expr_attr); + + PADDLE_ENFORCE_EQ( + dim_exprs.has_value(), + true, + ::common::errors::PreconditionNotMet( + "Required success to execute convert attribute to dim exprs.")); + + return dim_exprs.value(); }(); const auto symbol_bindings = [&] { pir::Attribute symbol_bindings_attr = diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 396f9929ecb35d..8ddff590ddef3a 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -15,6 +15,7 @@ #pragma once #include #include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h" #include "paddle/phi/core/infermeta_utils.h" @@ -154,18 +155,11 @@ class IR_API GenerateShapeOp static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; - struct SymbolBindingBase { - std::string symbol_name; - int64_t input_tensor_idx; - int64_t input_tensor_dim_idx; - }; - - struct DataSymbolBinding : public SymbolBindingBase {}; - struct ShapeSymbolBinding : public SymbolBindingBase {}; - - using SymbolBinding = std::variant; - - using SymbolBindings = std::vector; + using SymbolBindingBase = cinn::dialect::SymbolBindingBase; + using SymbolBinding = cinn::dialect::SymbolBinding; + using ShapeSymbolBinding = cinn::dialect::ShapeSymbolBinding; + using DataSymbolBinding = cinn::dialect::DataSymbolBinding; + using SymbolBindings = cinn::dialect::SymbolBindings; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT diff --git a/paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h b/paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h new file mode 100644 index 00000000000000..bd6d70b58bf5fd --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h @@ -0,0 +1,59 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +namespace cinn { +namespace dialect { +struct SymbolBindingBase { + std::string symbol_name; + int64_t input_tensor_idx; + int64_t input_tensor_dim_idx; + bool operator==(const SymbolBindingBase& other) const { + return symbol_name == other.symbol_name && + input_tensor_idx == other.input_tensor_idx && + input_tensor_dim_idx == other.input_tensor_dim_idx; + } +}; + +constexpr char* kDataSymbolBinding = "DataSymbolBinding"; +constexpr char* kShapeSymbolBinding = "ShapeSymbolBinding"; + +struct DataSymbolBinding : public SymbolBindingBase { + const char* binding_type() const { return kDataSymbolBinding; } +}; +struct ShapeSymbolBinding : public SymbolBindingBase { + const char* binding_type() const { return kShapeSymbolBinding; } +}; + +using SymbolBinding = std::variant; + +using SymbolBindings = std::vector; + +inline std::ostream& operator<<(std::ostream& os, + const SymbolBinding& symbol_binding) { + std::visit( + [&](auto&& binding) { + os << binding.binding_type() << "[" << binding.symbol_name << "," + << binding.input_tensor_idx << "," << binding.input_tensor_dim_idx + << "]"; + }, + symbol_binding); +} +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/node.cc b/paddle/cinn/hlir/framework/node.cc index 4f50d930f4c7e2..e68a8d7ce1b2da 100644 --- a/paddle/cinn/hlir/framework/node.cc +++ b/paddle/cinn/hlir/framework/node.cc @@ -11,12 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#include "paddle/cinn/hlir/framework/node.h" - #include #include "paddle/cinn/common/context.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" +#include "paddle/cinn/hlir/framework/node.h" namespace cinn { namespace hlir { @@ -68,6 +67,8 @@ struct PyBindNodeAttrVisitor { VISIT_ELEMENTS(double) VISIT_ELEMENTS(bool) VISIT_ELEMENTS(std::string) + VISIT_ELEMENTS(symbol::DimExpr) + VISIT_ELEMENTS(cinn::dialect::SymbolBinding) }; } // namespace diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc index b68f0f5f8ebe08..e3a15bb8dabc21 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -590,21 +590,14 @@ std::vector GetAllForIters(const ir::Expr& expr) { } // namespace trivial_fusion_detail std::vector OperationFusion( - const std::vector<::pir::Operation*>& original_ops, + const std::vector<::pir::Operation*>& ops, const std::vector& op_compute_bodies, const std::vector<::pir::Value>& outputs) { - PADDLE_ENFORCE(FLAGS_group_schedule_tiling_first, - ::common::errors::PreconditionNotMet( - "TrivialFusion must be used with tiling first, set " - "FLAGS_group_schedule_tiling_first=1")); - const auto& ops = trivial_fusion_detail::FilterVector( - original_ops, [](const ::pir::Operation* op) { - if (op->name() == "cinn_op.generate_shape") { - return false; - } - return true; - }); - + PADDLE_ENFORCE_EQ(FLAGS_group_schedule_tiling_first, + true, + ::common::errors::PreconditionNotMet( + "TrivialFusion must be used with tiling first, set " + "FLAGS_group_schedule_tiling_first=1")); std::vector contents; for (int i = 0; i < ops.size(); i++) { contents.emplace_back(ops[i], op_compute_bodies[i]); diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index a493f2c6a22d45..38f0bd007d7606 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -18,7 +18,7 @@ #include #include #include "glog/logging.h" - +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/framework/op.h" @@ -489,10 +489,24 @@ utils::AttributeMap CompatibleInfo::ConvertAttributes( utils::AttributeMap dst_attrs; for (auto& item : src_attrs) { VLOG(4) << "deal with " << item.first; - if (item.first == ::pir::kStopGradientAttrName || - item.first == ::pir::kOutputDimExprs || - item.first == ::pir::kSymbolBindings) { + if (item.first == ::pir::kStopGradientAttrName) { continue; + } else if (item.first == ::pir::kSymbolBindings) { + auto symbol_bindings = + cinn::dialect::GenerateShapeOp::ConvertAttributeToSymbolBindings( + item.second); + PADDLE_ENFORCE(symbol_bindings.has_value(), + ::common::errors::PreconditionNotMet( + "Required success to execute convert attribute to " + "symbol bindings.")); + dst_attrs[::pir::kSymbolBindings] = symbol_bindings.value(); + } else if (item.first == ::pir::kOutputDimExprs) { + auto dim_exprs = cinn::dialect::ConvertAttributeToDimExprs(item.second); + PADDLE_ENFORCE( + dim_exprs.has_value(), + ::common::errors::PreconditionNotMet( + "Required success to execute convert attribute to dim exprs.")); + dst_attrs[::pir::kOutputDimExprs] = dim_exprs.value(); } else if (item.second.isa()) { auto is_cpu = item.second.dyn_cast().data() == diff --git a/paddle/cinn/hlir/op/custom_call.cc b/paddle/cinn/hlir/op/custom_call.cc index c090e165066600..5ce460f5156327 100644 --- a/paddle/cinn/hlir/op/custom_call.cc +++ b/paddle/cinn/hlir/op/custom_call.cc @@ -14,6 +14,7 @@ #include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/common/cas.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/op_strategy.h" @@ -25,6 +26,7 @@ #include "paddle/cinn/hlir/pe/transform.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/utils/string.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" #ifdef CINN_WITH_CUDNN #include @@ -893,6 +895,8 @@ std::vector CustomCallArgsForMemset( EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) + EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) + EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) #undef EXPAND_MEMSET_TYPE_UNSUPPORT }; diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index f206fe7af7cc97..814b191459fa11 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -19,6 +19,7 @@ #include "absl/types/optional.h" #include "paddle/cinn/adt/op_equation_context.h" #include "paddle/cinn/common/type.h" +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/op_strategy.h" @@ -359,6 +360,14 @@ Expr GetScalarExpr(const framework::NodeAttr::attr_t &attr) { PADDLE_THROW( phi::errors::InvalidArgument("wrong type std::vector")); } + void operator()(const std::vector &) { + PADDLE_THROW(phi::errors::InvalidArgument( + "wrong type std::vector")); + } + void operator()(const std::vector &) { + PADDLE_THROW(phi::errors::InvalidArgument( + "wrong type std::vector")); + } }; absl::visit(Visitor{scalar}, attr); return scalar; @@ -1271,6 +1280,19 @@ std::shared_ptr StrategyForGenerateShapeSymbolic( const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { + PADDLE_ENFORCE( + attrs.attr_store.count("output_dim_exprs"), + ::common::errors::InvalidArgument("Expected attribute output_dim_exprs " + "in strategy for generate shape op")); + PADDLE_ENFORCE( + attrs.attr_store.count("symbol_bindings"), + ::common::errors::InvalidArgument("Expected attribute symbol_bindings " + "in strategy for generate shape op")); + auto output_dim_exprs = absl::get>( + attrs.attr_store.at("output_dim_exprs")); + auto symbol_bindings = absl::get( + attrs.attr_store.at("symbol_bindings")); + framework::CINNCompute generate_shape_compute( [=](lang::Args args, lang::RetValue *ret) { PADDLE_ENFORCE(!args.empty(), @@ -1287,16 +1309,8 @@ std::shared_ptr StrategyForGenerateShapeSymbolic( auto stages = CreateStages({}); std::string tensor_name = pack_args.back().operator std::string(); - ir::Tensor out(ir::_Tensor_::Make(/*name=*/tensor_name, - /*dtype=*/common::type_of(), - /*shape=*/ - { - Expr(1), - }, - /*domain=*/ - { - Expr(1), - })); + ir::Tensor out = pe::GenerateShape( + inputs, symbol_bindings, output_dim_exprs, tensor_name); std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); diff --git a/paddle/cinn/hlir/pe/elementwise.cc b/paddle/cinn/hlir/pe/elementwise.cc index 906b8c2154378e..41eb7f2fd2c10b 100644 --- a/paddle/cinn/hlir/pe/elementwise.cc +++ b/paddle/cinn/hlir/pe/elementwise.cc @@ -18,6 +18,7 @@ #include #include "paddle/cinn/common/cas.h" +#include "paddle/cinn/common/dim_expr_converter.h" #include "paddle/cinn/hlir/op/op_util.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/lang/builtin.h" @@ -354,6 +355,29 @@ ir::Tensor Tril(const ir::Tensor& A, return res; } +ir::Tensor GenerateShape(const std::vector& inputs, + const cinn::dialect::SymbolBindings& symbol_bindings, + const std::vector& output_dim_exprs, + const std::string& name) { + if (output_dim_exprs.size() != 1) { + LOG(WARNING) << "pe::GenerateShape will return a meaningless tensor when " + "output_dim_exprs.size() != 1"; + return Compute( + {Expr(1)}, + [=](const std::vector& indice) { return Expr(1); }, + name); + } + cinn::common::DimExprConverterWithSymbolBindings converter(inputs, + symbol_bindings); + auto res = Compute( + {Expr(1)}, + [=, &converter](const std::vector& indice) { + return converter.ConvertToIrExpr(output_dim_exprs[0]); + }, + name); + return res; +} + ir::Tensor IsClose(const ir::Tensor& x, const ir::Tensor& y, int axis, diff --git a/paddle/cinn/hlir/pe/elementwise.h b/paddle/cinn/hlir/pe/elementwise.h index 64be14a23d05e0..6afd814ce0620c 100644 --- a/paddle/cinn/hlir/pe/elementwise.h +++ b/paddle/cinn/hlir/pe/elementwise.h @@ -17,9 +17,11 @@ #include #include +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" namespace cinn { namespace hlir { @@ -154,6 +156,12 @@ ir::Tensor Tril(const ir::Tensor& A, const std::vector& out_shape, const std::string& name = UniqName("T_Elementwise_Tril_out")); +ir::Tensor GenerateShape( + const std::vector& inputs, + const cinn::dialect::SymbolBindings& symbol_bindings, + const std::vector& output_dim_exprs, + const std::string& name = UniqName("T_Generate_Shape_out")); + // This operator checks if all x and y satisfy the condition: |x - y| <= atol + // rtol * |y| ir::Tensor IsClose( diff --git a/paddle/cinn/utils/type_defs.h b/paddle/cinn/utils/type_defs.h index 7167c85effa597..22af77dff5c1b3 100644 --- a/paddle/cinn/utils/type_defs.h +++ b/paddle/cinn/utils/type_defs.h @@ -13,12 +13,12 @@ // limitations under the License. #pragma once - #include #include - #include #include +#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" namespace cinn { namespace utils { @@ -35,7 +35,10 @@ using Attribute = absl::variant, - std::vector>; + std::vector, + // the followings are only for generate shape op + std::vector, + cinn::dialect::SymbolBindings>; using AttributeMap = absl::flat_hash_map; // shape type defs diff --git a/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc index 63a3d4daad17cf..74f7a793e26e1d 100644 --- a/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc @@ -390,4 +390,4 @@ void AddShapeOptimizationPass( } // namespace pir::shape -REGISTER_IR_PASS(shape_optimization_pass, pir::ShapeOptimizationPass); +// REGISTER_IR_PASS(shape_optimization_pass, pir::ShapeOptimizationPass);