diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc index f154cd8ddb5b4e..fb496c898bfb22 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc @@ -15,6 +15,14 @@ #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +bool ShouldUseData(pir::Value val) { + if (!val.defining_op()) return false; + if (val.defining_op()->isa()) { + return true; + } + return false; +} + bool InferSymbolicShapeElementWiseBinary( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { const auto &x_shapeordata = @@ -22,11 +30,8 @@ bool InferSymbolicShapeElementWiseBinary( std::vector shape_0; // For ElementWiseBinary ops, if the input tensor is from full op, the value // of fullop is useless, only the shape need doing broadcast - bool x_from_fullop = - op->operand_source(0).defining_op() - ? op->operand_source(0).defining_op()->isa() - : false; - if (!x_from_fullop && x_shapeordata.data().has_value()) { + if (ShouldUseData(op->operand_source(0)) && + x_shapeordata.data().has_value()) { shape_0 = x_shapeordata.data().value(); } else { shape_0 = x_shapeordata.shape(); @@ -35,11 +40,8 @@ bool InferSymbolicShapeElementWiseBinary( const auto &y_shapeordata = shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); std::vector shape_1; - bool y_from_fullop = - op->operand_source(1).defining_op() - ? op->operand_source(1).defining_op()->isa() - : false; - if (!y_from_fullop && y_shapeordata.data().has_value()) { + if (ShouldUseData(op->operand_source(1)) && + y_shapeordata.data().has_value()) { shape_1 = y_shapeordata.data().value(); } else { shape_1 = y_shapeordata.shape();