Skip to content

Commit 1f5dfc9

Browse files
committed
change shape_analysis to infer_context
1 parent 636ccff commit 1f5dfc9

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,20 +505,20 @@ bool MemoryEfficientAttentionOpInferSymbolicShape(
505505
}
506506

507507
bool RoiAlignOpInferSymbolicShape(
508-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
508+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
509509
const auto &x = op->operand_source(0);
510510
const auto &boxes = op->operand_source(1);
511511

512512
const auto &num_boxes =
513-
shape_analysis->GetShapeOrDataForValue(boxes).shape()[0];
513+
infer_context->GetShapeOrDataForValue(boxes).shape()[0];
514514
symbol::DimExpr channel_num =
515-
shape_analysis->GetShapeOrDataForValue(x).shape()[1];
515+
infer_context->GetShapeOrDataForValue(x).shape()[1];
516516

517517
int32_t out_h = op->attribute<pir::Int32Attribute>("pooled_height").data();
518518
int32_t out_w = op->attribute<pir::Int32Attribute>("pooled_width").data();
519519

520520
std::vector<symbol::DimExpr> out_dim = {num_boxes, channel_num, out_h, out_w};
521-
shape_analysis->SetShapeOrDataForValue(
521+
infer_context->SetShapeOrDataForValue(
522522
op->result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
523523
return true;
524524
}

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ bool DiagonalOpInferSymbolicShape(
247247
}
248248

249249
bool DistributeFpnProposalsOpInferSymbolicShape(
250-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
250+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
251251
const auto &attributes = op->attributes();
252252
int32_t min_level =
253253
attributes.at("min_level").dyn_cast<pir::Int32Attribute>().data();
@@ -258,7 +258,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
258258
symbol::DimExpr num_rois = [&]() {
259259
pir::Value rois_num = op->operand_source(1);
260260
const auto &rois_num_shape_or_data =
261-
shape_analysis->GetShapeOrDataForValue(rois_num);
261+
infer_context->GetShapeOrDataForValue(rois_num);
262262

263263
PADDLE_ENFORCE_EQ(
264264
rois_num_shape_or_data.shape()[0],
@@ -290,7 +290,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
290290
} else {
291291
symbol::DimExpr last_dim = num_rois;
292292
for (int i = 0; i < num_levels - 1; i++) {
293-
const auto &next_sym_name = shape_analysis->GetNextSymName();
293+
const auto &next_sym_name = infer_context->GetNextSymName();
294294
std::vector<symbol::DimExpr> level_dim = {next_sym_name, 4};
295295
multi_rois_out_shape.emplace_back(
296296
symbol::TensorShapeOrDataDimExprs(level_dim));
@@ -314,15 +314,15 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
314314
return symbol::TensorShapeOrDataDimExprs({num_rois, 1});
315315
}();
316316

317-
shape_analysis->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape);
318-
shape_analysis->SetShapeOrDataForValue(op->result(1),
319-
rois_num_per_level_out_shape);
320-
shape_analysis->SetShapeOrDataForValue(op->result(2), restore_ind);
317+
infer_context->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape);
318+
infer_context->SetShapeOrDataForValue(op->result(1),
319+
rois_num_per_level_out_shape);
320+
infer_context->SetShapeOrDataForValue(op->result(2), restore_ind);
321321
return true;
322322
}
323323

324-
bool EinsumOpInferSymbolicShape(
325-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
324+
bool EinsumOpInferSymbolicShape(pir::Operation *op,
325+
pir::InferSymbolicShapeContext *infer_context) {
326326
PADDLE_THROW(phi::errors::Unimplemented(
327327
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
328328
return true;

0 commit comments

Comments
 (0)