diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc index 1adc4788b096f8..31d3bc87aa4a55 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc @@ -133,13 +133,25 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op, shape_analysis->GetShapeOrDataForValue(operand_source); std::vector shape(operand_shape_or_data.shape()); - std::vector data; if (operand_shape_or_data.data()) { - for (auto &val : *(operand_shape_or_data.data())) { - int scale = op->attribute("scale").dyn_cast().data(); + const std::vector data = [&] { + const symbol::DimExpr scale = [&]() -> symbol::DimExpr { + if (op->num_operands() == 2) { + return shape_analysis->GetShapeOrDataForValue(op->operand_source(1)) + .data() + ->at(0); + } + return static_cast( + op->attribute("scale").dyn_cast().data()); + }(); int bias = op->attribute("bias").dyn_cast().data(); - data.push_back(val * scale + bias); - } + + std::vector data; + for (auto &val : *(operand_shape_or_data.data())) { + data.push_back(val * scale + bias); + } + return data; + }(); shape_analysis->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(shape, data));