diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc index d52270e5b3b667..d5da282de676b5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc @@ -76,7 +76,7 @@ bool ConcatOpInferSymbolicShape( out_dims[axis] = out_dims[axis] + operand_shape_or_data.shape()[axis]; } - for (size_t i = 1; i < rank; ++i) { + for (size_t i = 0; i < rank; ++i) { if (i == static_cast(axis)) continue; paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( shape_analysis, input_values, i); @@ -85,6 +85,9 @@ bool ConcatOpInferSymbolicShape( return out_dims; }; + VLOG(3) << "constraints size:" + << shape_analysis->CreateDimExprBuilder().constraints().size(); + symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())}; diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h index dc2794ac6f90be..b3cc2232a1f91c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h @@ -13,7 +13,6 @@ // limitations under the License. 
#pragma once -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h index e392023aa0c339..65fa20c8e63e7a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace paddle::dialect { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h index 515eaaca1b3484..c44f6c70fe33b2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h @@ -16,6 +16,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h" diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc new file mode 100644 index 00000000000000..d3e4b38b57a5b5 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" + +namespace paddle::dialect { + +bool EmptyOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &shape_gen_op = op->operand_source(0).defining_op(); + if (shape_gen_op->isa()) { + std::vector shape = details::GetVectorAttr( + shape_gen_op->dyn_cast(), "value"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int64_t &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(dim)); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + + } else { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + shape_analysis->SetShapeOrDataForValue(op->result(0), + 
operand_shape_or_data); + return true; + } +} + +bool GaussianOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &shape_gen_op = op->operand_source(0).defining_op(); + + if (shape_gen_op->isa()) { + std::vector shape = details::GetVectorAttr( + shape_gen_op->dyn_cast(), "value"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int64_t &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(dim)); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + + } else { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; + } +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h new file mode 100644 index 00000000000000..7e706bf942f83c --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h @@ -0,0 +1,22 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Empty) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian) +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 9003b88c18fd34..9192478548d51d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -97,7 +97,6 @@ bool StackOpInferSymbolicShape(pir::Operation *op, static_cast(shape_data_list.size())); } else { for (int i = 0; i < rank; ++i) { - if (i == axis) continue; details::BuildCstrEqForTensorListAlongAxis( shape_analysis, shape_data_list, i); } @@ -931,26 +930,6 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, } // Not Implemented Ops. 
- -bool DiagEmbedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool DiagonalOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool DirichletOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - bool GatherOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { const auto &input_shape_or_data = @@ -1020,17 +999,33 @@ bool GatherOpInferSymbolicShape( bool KronOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool KthvalueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &y_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)).shape(); + const int rank_x = x_shape_or_data.size(); + const int rank_y = y_shape_or_data.size(); + const int rank = (rank_x > rank_y) ? rank_x : rank_y; + + std::vector dim_out; + dim_out.reserve(rank); + const auto one = symbol::DimExpr{1}; + const auto minus_one = symbol::DimExpr{-1}; + for (int i = 0; i < rank; i++) { + symbol::DimExpr dim_xi = + (i < rank - rank_x) ? 
one : x_shape_or_data.at(i - (rank - rank_x)); + symbol::DimExpr dim_yi = + (i < rank - rank_y) ? one : y_shape_or_data.at(i - (rank - rank_y)); + dim_out.push_back(dim_xi * dim_yi); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } +// Not Implemented Ops. bool LogcumsumexpOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( @@ -1095,32 +1090,6 @@ bool UniqueConsecutiveOpInferSymbolicShape( op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } - -bool EinsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool EmptyOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Exponential_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool GaussianOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - bool LinspaceOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h index 9ad13dd02933e5..a84d71815549b6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace paddle::dialect { @@ -51,12 +50,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Split) // Not Impelmented Ops. -OP_DECLARE_INFER_SYMBOLIC_SHAPE(DiagEmbed) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diagonal) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dirichlet) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson) @@ -67,10 +62,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Topk) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unbind) OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniqueConsecutive) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Einsum) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Empty) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exponential_) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc index bb540647d0219e..f6d45dad1956a2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc @@ -154,6 +154,10 @@ bool Digamma_OpInferSymbolicShape( pir::Operation *op, 
pir::ShapeConstraintIRAnalysis *shape_analysis) { return SameOperandsAndResultShape(op, shape_analysis); } +bool DirichletOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return SameOperandsAndResultShape(op, shape_analysis); +} bool EqualOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { return SameOperandsAndResultShape(op, shape_analysis); @@ -194,6 +198,10 @@ bool Expm1_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { return SameOperandsAndResultShape(op, shape_analysis); } +bool Exponential_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return SameOperandsAndResultShape(op, shape_analysis); +} bool FetchOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { return SameOperandsAndResultShape(op, shape_analysis); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h index dc77d9cd70bb4d..6afe08d753a55a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace paddle::dialect { @@ -50,6 +49,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cosh) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cosh_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Digamma) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Digamma_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dirichlet) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Erf) @@ -60,6 +60,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exp_) 
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Expm1) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Expm1_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exponential_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fetch) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flip) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Floor) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index c2e17f1f8f8c67..42067e28e310af 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -165,6 +165,121 @@ bool Cumsum_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { return CumsumOpInferSymbolicShape(op, shape_analysis); } +bool DiagEmbedOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int dim1 = attributes.at("dim1").dyn_cast().data(); + int dim2 = attributes.at("dim2").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &x_dims = operand_shape_or_data.shape(); + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? 
x_dims.size() + dim2 + 1 : dim2; + int64_t offset_ = static_cast(std::abs(offset)); + symbol::DimExpr new_dim_len = + symbol::DimExpr(offset_) + x_dims[x_dims.size() - 1]; + + const auto &out_dims = [&] { + std::vector out_dims = x_dims; + out_dims.pop_back(); + out_dims.insert(out_dims.begin() + std::min(dim1_, dim2_), new_dim_len); + out_dims.insert(out_dims.begin() + std::max(dim1_, dim2_), new_dim_len); + return out_dims; + }(); + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} +bool DiagonalOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int axis1 = attributes.at("axis1").dyn_cast().data(); + int axis2 = attributes.at("axis2").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &x_dims = operand_shape_or_data.shape(); + int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; + int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; + + auto out_dims = x_dims; + auto axis1_size = out_dims[axis1_]; + auto axis2_size = out_dims[axis2_]; + out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_)); + out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_)); + + symbol::DimExprBuilder builder{nullptr}; + symbol::DimExpr zero{0}; + symbol::DimExpr res_shape; + symbol::DimExpr offset_sym{offset}; + if (offset == 0) { + res_shape = builder.Min(axis1_size, axis2_size); + } else if (offset > 0) { + if (axis2_size.isa()) { + res_shape = (axis2_size.dyn_cast() - offset) > 0 + ? 
builder.Min(axis1_size, axis2_size - offset_sym) + : zero; + } else { + res_shape = shape_analysis->GetNextSymName(); + } + } else { + if (axis1_size.isa()) { + res_shape = (axis1_size.dyn_cast() + offset) > 0 + ? builder.Min(axis1_size + offset_sym, axis2_size) + : zero; + } else { + res_shape = shape_analysis->GetNextSymName(); + } + } + out_dims.push_back(res_shape); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool EinsumOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool KthvalueOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int axis = attributes.at("axis").dyn_cast().data(); + bool keepdim = GetBoolAttr(op, "keepdim"); + + const auto &input_dims = operand_shape_or_data.shape(); + const int &dim_size = input_dims.size(); + if (axis < 0) axis += dim_size; + std::vector out_dims; + for (int i = 0; i < axis; i++) { + out_dims.emplace_back(input_dims[i]); + } + if (keepdim && dim_size > 0) { + out_dims.emplace_back(symbol::DimExpr(1)); + } + for (int i = axis + 1; i < dim_size; i++) { + out_dims.emplace_back(input_dims[i]); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data); + return true; +} bool ReshapeOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value 
operand_source = op->operand_source(0); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 8d47e5a5fd91ea..aeeb03713f4814 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace paddle::dialect { @@ -29,6 +28,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(DiagEmbed) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diagonal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Einsum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index a0248993caaafe..fd8ec68401b086 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -28,10 +28,6 @@ namespace dialect { using VariantType = phi::Attribute; -#define OP_DECLARE_INFER_SYMBOLIC_SHAPE(name) \ - bool name##OpInferSymbolicShape( \ - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis); - // TODO(zhangbo): The builtin type needs to cover all data types of // phi::DataType. 
static inline phi::DataType TransToPhiDataType(pir::Type dtype) { diff --git a/paddle/pir/include/dialect/shape/utils/shape_analysis.h b/paddle/pir/include/dialect/shape/utils/shape_analysis.h index 284487b7210c5a..04625f3047e401 100644 --- a/paddle/pir/include/dialect/shape/utils/shape_analysis.h +++ b/paddle/pir/include/dialect/shape/utils/shape_analysis.h @@ -100,4 +100,8 @@ class IR_API ShapeAnalysisManager { std::unordered_map tables_; }; +#define OP_DECLARE_INFER_SYMBOLIC_SHAPE(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis); + } // namespace pir diff --git a/test/ir/pir/cinn/symbolic/test_binary_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_binary_op_infer_sym_shape.py new file mode 100644 index 00000000000000..ab190bf57476e6 --- /dev/null +++ b/test/ir/pir/cinn/symbolic/test_binary_op_infer_sym_shape.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.static import InputSpec + + +def get_sym_shape_str_for_op(net, input_spec, op_name='builtin.shadow_output'): + forward_program = net.forward.get_concrete_program(*input_spec)[ + 1 + ].infer_program.forward_program + all_sym_shape_str = [] + for op in forward_program.global_block().ops: + if op.name() == op_name: + all_sym_shape_str.append(op.attrs()['sym_shape_str']) + + return all_sym_shape_str + + +def apply_to_static(net, use_cinn, input_spec=None): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static( + net, + input_spec=input_spec, + build_strategy=build_strategy, + full_graph=True, + ) + + +class TestBase(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + self.prepare_data() + + def prepare_data(self): + pass + + def test_eval_symbolic(self): + pass + + +class KronNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + y = paddle.empty(shape=[2, 2]) + z = paddle.empty(shape=[3, 3]) + out = paddle.kron(x, y) + out = paddle.kron(y, z) + return out + + +class TestKronOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + + self.expected = [ + [ + 'shape[Mul(S0, 1), Mul(S1, 2), Mul(S2, 2)], data[NULL]', + 'shape[6, 6], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = KronNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.kron' + ) + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 
0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + +if __name__ == '__main__': + unittest.main() diff --git a/test/ir/pir/cinn/symbolic/test_nullary_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_nullary_op_infer_sym_shape.py new file mode 100644 index 00000000000000..1df40d9bcb4af6 --- /dev/null +++ b/test/ir/pir/cinn/symbolic/test_nullary_op_infer_sym_shape.py @@ -0,0 +1,156 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.static import InputSpec + + +def get_sym_shape_str_for_op(net, input_spec, op_name='builtin.shadow_output'): + forward_program = net.forward.get_concrete_program(*input_spec)[ + 1 + ].infer_program.forward_program + all_sym_shape_str = [] + for op in forward_program.global_block().ops: + if op.name() == op_name: + all_sym_shape_str.append(op.attrs()['sym_shape_str']) + + return all_sym_shape_str + + +def apply_to_static(net, use_cinn, input_spec=None): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static( + net, + input_spec=input_spec, + build_strategy=build_strategy, + full_graph=True, + ) + + +class TestBase(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + self.prepare_data() + + def prepare_data(self): + pass + + def test_eval_symbolic(self): + pass + + +class EmptyNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + out = paddle.empty(shape=[128, 32]) + out = paddle.empty(shape=x) + return out + + +class TestEmptyOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + self.expected = [ + [ + 'shape[128, 32], data[NULL]', + 'shape[S0, S1, S2], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = EmptyNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.empty' + ) + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) 
is not expected {(self.expected[i][j])}', + ) + + return True + + +class GaussianNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + out = paddle.tensor.random.gaussian(shape=[12, 32], mean=1.0, std=2.0) + return out + + +class TestGaussianOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + self.expected = [ + [ + 'shape[12, 32], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = GaussianNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec(shape=[None, None, 2], dtype='float32') + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.gaussian' + ) + + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + +if __name__ == '__main__': + unittest.main() diff --git a/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py index be6741661295ad..f12bad12ea2f15 100644 --- a/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py @@ -265,5 +265,171 @@ def test_eval_symbolic(self): return True +class DiagEmbedNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + data = paddle.empty([6]) + out = paddle.diag_embed(data) + out = paddle.diag_embed(data, offset=-1, dim1=0, dim2=1) + out = paddle.diag_embed(x) + out = paddle.diag_embed(x, offset=-1, dim1=0, dim2=1) + return out + + +class TestDiagEmbedOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = 
[np.random.rand(4, 5, 6)] + self.expected = [ + [ + 'shape[6, 6], data[NULL]', + 'shape[7, 7], data[NULL]', + 'shape[S0, S1, Add(0, S2), Add(0, S2)], data[NULL]', + 'shape[Add(1, S2), Add(1, S2), S0, S1], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = DiagEmbedNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.diag_embed' + ) + + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + +class DiagonalNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + data = paddle.empty([2, 2, 3], 'float32') + out = paddle.diagonal(data) + out = paddle.diagonal(data, offset=0, axis1=2, axis2=1) + out = paddle.diagonal(x) + out = paddle.diagonal(x, offset=0, axis1=2, axis2=1) + out = paddle.diagonal(x, offset=1, axis1=2, axis2=1) + out = paddle.diagonal(x, offset=-1, axis1=2, axis2=1) + return out + + +class TestDiagonalOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + self.expected = [ + [ + 'shape[3, Min(2, 2)], data[NULL]', + 'shape[2, Min(3, 2)], data[NULL]', + 'shape[S2, Min(S0, S1)], data[NULL]', + 'shape[S0, Min(S2, S1)], data[NULL]', + 'shape[S0, S3], data[NULL]', + 'shape[S0, S4], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = DiagonalNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + 
input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.diagonal' + ) + + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + +class KthvalueNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + data = paddle.empty([2, 3, 3], 'float32') + out = paddle.kthvalue(data, 2, 1) + return out + + +class TestKthvalueOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + self.expected = [ + [ + 'shape[2, 3], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = KthvalueNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.kthvalue' + ) + + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + if __name__ == '__main__': unittest.main()