Skip to content

Commit 1fa8e01

Browse files
authored
[PIR][DynamicShape] Add InferSymbolicShapeInterface for prod (#61342)
Add InferSymbolicShapeInterface for prod
1 parent bb5a95c commit 1fa8e01

3 files changed

Lines changed: 29 additions & 0 deletions

File tree

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,30 @@ bool SumOpInferSymbolicShape(pir::Operation *op,
389389
return true;
390390
}
391391

392+
bool ProdOpInferSymbolicShape(pir::Operation *op,
393+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
394+
auto attributes = op->attributes();
395+
bool keepdim = attributes["keep_dim"].dyn_cast<pir::BoolAttribute>().data();
396+
397+
bool reduce_all =
398+
attributes["reduce_all"].dyn_cast<pir::BoolAttribute>().data();
399+
400+
auto axis_gen_op = op->operand_source(1).defining_op();
401+
if (axis_gen_op->isa<paddle::dialect::FullIntArrayOp>()) {
402+
std::vector<int64_t> axis = GetVectorAttr(
403+
axis_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>(), "value");
404+
return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all);
405+
} else {
406+
// TODO(lanxianghit): deal with other source: pir::VectorType,
407+
// paddle::dialect::DenseTensorType
408+
PADDLE_THROW(
409+
phi::errors::Unimplemented("ProdOpInferSymbolicShape: 'axis' only "
410+
"support FullIntArrayOp's result now."));
411+
}
412+
413+
return true;
414+
}
415+
392416
bool ReshapeOpInferSymbolicShape(
393417
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
394418
pir::Value operand_source_shape = op->operand_source(1);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ bool ReluOpInferSymbolicShape(pir::Operation *op,
182182
pir::ShapeConstraintIRAnalysis *shape_analysis);
183183
bool Relu_OpInferSymbolicShape(pir::Operation *op,
184184
pir::ShapeConstraintIRAnalysis *shape_analysis);
185+
186+
bool ProdOpInferSymbolicShape(pir::Operation *op,
187+
pir::ShapeConstraintIRAnalysis *shape_analysis);
188+
185189
} // namespace paddle::dialect
186190

187191
namespace cinn::dialect {

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,7 @@
10411041
kernel :
10421042
func : prod
10431043
backward : prod_grad
1044+
interfaces : paddle::dialect::InferSymbolicShapeInterface
10441045

10451046
- op : push_sparse_v2
10461047
args : (Tensor[] ids, Tensor[] w, Tensor[] out_grad_in, int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, str[] inputnames = {}, bool is_distributed = true)

0 commit comments

Comments
 (0)