Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,145 @@ bool Where_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return WhereOpInferSymbolicShape(op, shape_analysis);
}

bool AssignOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool Assign_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return AssignOpInferSymbolicShape(op, shape_analysis);
}

bool BitwiseAndOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool BitwiseAnd_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return BitwiseAndOpInferSymbolicShape(op, shape_analysis);
}

bool FeedOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool GreaterThanOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool GreaterThan_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return GreaterThanOpInferSymbolicShape(op, shape_analysis);
}

bool IncrementOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool Increment_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return IncrementOpInferSymbolicShape(op, shape_analysis);
}

bool LessThanOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool LessThan_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return LessThanOpInferSymbolicShape(op, shape_analysis);
}

bool LogicalAndOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool LogicalAnd_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return LogicalAndOpInferSymbolicShape(op, shape_analysis);
}

bool LogicalNotOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool LogicalNot_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return LogicalNotOpInferSymbolicShape(op, shape_analysis);
}

bool NotEqualOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool NotEqual_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return NotEqualOpInferSymbolicShape(op, shape_analysis);
}

bool TopPSamplingOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool LogOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool Log_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return LogOpInferSymbolicShape(op, shape_analysis);
}

bool ExpandAsOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

bool SplitOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
return true;
}

} // namespace paddle::dialect

namespace cinn::dialect {

bool SliceOpInferSymbolicShape(pir::Operation *op,
Expand Down
66 changes: 66 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,72 @@ bool WhereOpInferSymbolicShape(pir::Operation *op,
bool Where_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool AssignOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool Assign_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool BitwiseAndOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool BitwiseAnd_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool FeedOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool GreaterThanOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool GreaterThan_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool IncrementOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool Increment_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LessThanOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LessThan_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LogicalAndOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LogicalAnd_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LogicalNotOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LogicalNot_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool NotEqualOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool NotEqual_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool TopPSamplingOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool LogOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool Log_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool ExpandAsOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool SplitOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
func : assign
backward : assign_grad
inplace : (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : assign_out_
args : (Tensor x, Tensor output)
Expand Down Expand Up @@ -532,6 +533,7 @@
- op : feed
args : (str name, int col)
output : Tensor(out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fetch
args : (Tensor x, str name, int col)
Expand Down Expand Up @@ -722,6 +724,7 @@
data_transform :
support_trans_dtype : x, y
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : hardswish
args : (Tensor x)
Expand Down Expand Up @@ -752,6 +755,7 @@
kernel :
func : increment
inplace : (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : less_equal
args : (Tensor x, Tensor y)
Expand All @@ -774,6 +778,7 @@
data_transform :
support_trans_dtype : x, y
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place)
Expand Down Expand Up @@ -986,6 +991,7 @@
data_transform :
support_trans_dtype : x, y
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : one_hot
args : (Tensor x, Scalar(int) num_classes)
Expand Down Expand Up @@ -1310,6 +1316,7 @@
kernel :
func : split
backward : split_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : split_with_num
args : (Tensor x, int num, Scalar(int) axis)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@
func : bitwise_and
backend : x
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : bitwise_left_shift
args : (Tensor x, Tensor y, bool is_arithmetic = true)
Expand Down Expand Up @@ -921,6 +922,7 @@
data_type : x
optional : y
backward : expand_as_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : expm1
args : (Tensor x)
Expand Down Expand Up @@ -1618,6 +1620,7 @@
func : log
inplace: (x -> out)
backward: log_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : log10
args : (Tensor x)
Expand Down Expand Up @@ -1687,6 +1690,7 @@
data_type : x
backend : x
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : logical_not
args : (Tensor x)
Expand All @@ -1698,6 +1702,7 @@
data_type : x
backend : x
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : logical_or
args : (Tensor x, Tensor y)
Expand Down Expand Up @@ -2784,6 +2789,7 @@
func : top_p_sampling
data_type : x
optional : threshold
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : topk
args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
Expand Down