diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 8006551ac33a0a..39ecd66d3e1c76 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -339,6 +339,14 @@ bool GatherNdOpInferSymbolicShape( return true; } +bool IndexSampleOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + return true; +} + bool KronOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index bb349d1f900fc1..2b272992e4cc0a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -26,6 +26,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 0d3bd4700879f2..a1bfce7e8a92ed 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1593,6 +1593,7 @@ backward : index_sample_grad data_transform : skip_transform : index + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : index_select args : (Tensor x, Tensor index, int axis = 0)