@@ -247,7 +247,7 @@ bool DiagonalOpInferSymbolicShape(
247247}
248248
249249bool DistributeFpnProposalsOpInferSymbolicShape (
250- pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis ) {
250+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context ) {
251251 const auto &attributes = op->attributes ();
252252 int32_t min_level =
253253 attributes.at (" min_level" ).dyn_cast <pir::Int32Attribute>().data ();
@@ -258,7 +258,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
258258 symbol::DimExpr num_rois = [&]() {
259259 pir::Value rois_num = op->operand_source (1 );
260260 const auto &rois_num_shape_or_data =
261- shape_analysis ->GetShapeOrDataForValue (rois_num);
261+ infer_context ->GetShapeOrDataForValue (rois_num);
262262
263263 PADDLE_ENFORCE_EQ (
264264 rois_num_shape_or_data.shape ()[0 ],
@@ -290,7 +290,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
290290 } else {
291291 symbol::DimExpr last_dim = num_rois;
292292 for (int i = 0 ; i < num_levels - 1 ; i++) {
293- const auto &next_sym_name = shape_analysis ->GetNextSymName ();
293+ const auto &next_sym_name = infer_context ->GetNextSymName ();
294294 std::vector<symbol::DimExpr> level_dim = {next_sym_name, 4 };
295295 multi_rois_out_shape.emplace_back (
296296 symbol::TensorShapeOrDataDimExprs (level_dim));
@@ -314,15 +314,15 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
314314 return symbol::TensorShapeOrDataDimExprs ({num_rois, 1 });
315315 }();
316316
317- shape_analysis ->SetShapeOrDataForValue (op->result (0 ), multi_rois_out_shape);
318- shape_analysis ->SetShapeOrDataForValue (op->result (1 ),
319- rois_num_per_level_out_shape);
320- shape_analysis ->SetShapeOrDataForValue (op->result (2 ), restore_ind);
317+ infer_context ->SetShapeOrDataForValue (op->result (0 ), multi_rois_out_shape);
318+ infer_context ->SetShapeOrDataForValue (op->result (1 ),
319+ rois_num_per_level_out_shape);
320+ infer_context ->SetShapeOrDataForValue (op->result (2 ), restore_ind);
321321 return true ;
322322}
323323
324- bool EinsumOpInferSymbolicShape (
325- pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis ) {
324+ bool EinsumOpInferSymbolicShape (pir::Operation *op,
325+ pir::InferSymbolicShapeContext *infer_context ) {
326326 PADDLE_THROW (phi::errors::Unimplemented (
327327 op->name () + " 's InferSymbolicShape interface is NOT implemented now." ));
328328 return true ;
0 commit comments