Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,30 @@ bool SumOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool ProdOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
auto attributes = op->attributes();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以使用 const auto&

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok,后续统一修改一下吧,免得这个pr再重新过ci了

bool keepdim = attributes["keep_dim"].dyn_cast<pir::BoolAttribute>().data();

bool reduce_all =
attributes["reduce_all"].dyn_cast<pir::BoolAttribute>().data();

auto axis_gen_op = op->operand_source(1).defining_op();
if (axis_gen_op->isa<paddle::dialect::FullIntArrayOp>()) {
std::vector<int64_t> axis = GetVectorAttr(
axis_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>(), "value");
return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all);
} else {
// TODO(lanxianghit): deal with other source: pir::VectorType,
// paddle::dialect::DenseTensorType
PADDLE_THROW(
phi::errors::Unimplemented("ProdOpInferSymbolicShape: 'axis' only "
"support FullIntArrayOp's result now."));
}

return true;
}

bool ReshapeOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source_shape = op->operand_source(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ bool ReluOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);
bool Relu_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

bool ProdOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,7 @@
kernel :
func : prod
backward : prod_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : push_sparse_v2
args : (Tensor[] ids, Tensor[] w, Tensor[] out_grad_in, int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, str[] inputnames = {}, bool is_distributed = true)
Expand Down