diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 3c5a086864c267..7b999f9ac3140e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -370,6 +370,30 @@ bool SumOpInferSymbolicShape(pir::Operation *op, return true; } +bool ProdOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + auto attributes = op->attributes(); + bool keepdim = attributes["keep_dim"].dyn_cast().data(); + + bool reduce_all = + attributes["reduce_all"].dyn_cast().data(); + + auto axis_gen_op = op->operand_source(1).defining_op(); + if (axis_gen_op->isa()) { + std::vector axis = GetVectorAttr( + axis_gen_op->dyn_cast(), "value"); + return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); + } else { + // TODO(lanxianghit): deal with other source: pir::VectorType, + // paddle::dialect::DenseTensorType + PADDLE_THROW( + phi::errors::Unimplemented("ProdOpInferSymbolicShape: 'axis' only " + "support FullIntArrayOp's result now.")); + } + + return true; +} + bool ReshapeOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source_shape = op->operand_source(1); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index d85a2485ac505d..994e5f048408e1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -182,6 +182,10 @@ bool ReluOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); bool Relu_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ProdOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 4593a7e7d7c427..8b47531b67767b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1041,6 +1041,7 @@ kernel : func : prod backward : prod_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : push_sparse_v2 args : (Tensor[] ids, Tensor[] w, Tensor[] out_grad_in, int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, str[] inputnames = {}, bool is_distributed = true)