-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-26218][SQL] Overflow on arithmetic operations returns incorrect result #21599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
5c662f6
fad75fa
8591417
9c3df7d
ebdaf61
a0b862e
7bba22f
77f26f2
74cd0a4
2cfd946
25c853c
ff02dca
00fae1d
8e9715c
1dff779
38fc1f4
0d5e510
37e19ce
eb37ee7
98bbf83
650ea79
1d20f73
538e332
3de4bfb
3baecbc
a247f9f
582d148
b809a3f
ce3ed2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,7 +56,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast | |
| case _ => DoubleType | ||
| } | ||
|
|
||
| private lazy val sumDataType = resultType | ||
| private lazy val sumDataType = child.dataType match { | ||
| case LongType => DecimalType.BigIntDecimal | ||
| case _ => resultType | ||
| } | ||
|
|
||
| private lazy val castToResultType: (Expression) => Expression = | ||
| if (sumDataType == resultType) (e: Expression) => e else (e: Expression) => Cast(e, resultType) | ||
|
|
||
| private lazy val sum = AttributeReference("sum", sumDataType)() | ||
|
|
||
|
|
@@ -89,5 +95,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast | |
| ) | ||
| } | ||
|
|
||
| override lazy val evaluateExpression: Expression = sum | ||
| override lazy val evaluateExpression: Expression = castToResultType(sum) | ||
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,6 +117,8 @@ case class Abs(child: Expression) | |
|
|
||
| abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { | ||
|
|
||
| protected val checkOverflow = SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK) | ||
|
|
||
| override def dataType: DataType = left.dataType | ||
|
|
||
| override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess | ||
|
|
@@ -129,17 +131,41 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { | |
| def calendarIntervalMethod: String = | ||
| sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") | ||
|
|
||
| def checkOverflowCode(result: String, op1: String, op2: String): String = | ||
| sys.error("BinaryArithmetics must override either checkOverflowCode or genCode") | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { | ||
| case _: DecimalType => | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") | ||
| case CalendarIntervalType => | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") | ||
| // In the following cases, overflow can happen, so we need to check the result is valid. | ||
| // Otherwise we throw an ArithmeticException | ||
|
||
| // byte and short are casted into int when add, minus, times or divide | ||
| case ByteType | ShortType => | ||
| defineCodeGen(ctx, ev, | ||
| (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| val overflowCheck = if (checkOverflow) { | ||
| checkOverflowCode(ev.value, eval1, eval2) | ||
| } else { | ||
| "" | ||
| } | ||
| s""" | ||
| |${ev.value} = (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2); | ||
| |$overflowCheck | ||
| """.stripMargin | ||
| }) | ||
| case _ => | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| val overflowCheck = if (checkOverflow) { | ||
| checkOverflowCode(ev.value, eval1, eval2) | ||
| } else { | ||
| "" | ||
| } | ||
| s""" | ||
| |${ev.value} = $eval1 $symbol $eval2; | ||
mgaido91 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| |$overflowCheck | ||
| """.stripMargin | ||
| }) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -170,9 +196,27 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { | |
| if (dataType.isInstanceOf[CalendarIntervalType]) { | ||
| input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) | ||
| } else { | ||
| numeric.plus(input1, input2) | ||
| val result = numeric.plus(input1, input2) | ||
| if (checkOverflow) { | ||
| val resSignum = numeric.signum(result) | ||
| val input1Signum = numeric.signum(input1) | ||
| val input2Signum = numeric.signum(input2) | ||
| if (resSignum != -1 && input1Signum == -1 && input2Signum == -1 | ||
| || resSignum != 1 && input1Signum == 1 && input2Signum == 1) { | ||
|
||
| throw new ArithmeticException(s"$input1 + $input2 caused overflow.") | ||
| } | ||
| } | ||
| result | ||
| } | ||
| } | ||
|
|
||
| override def checkOverflowCode(result: String, op1: String, op2: String): String = { | ||
| s""" | ||
| |if ($result >= 0 && $op1 < 0 && $op2 < 0 || $result <= 0 && $op1 > 0 && $op2 > 0) { | ||
| | throw new ArithmeticException($op1 + " + " + $op2 + " caused overflow."); | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
|
|
@@ -198,9 +242,27 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti | |
| if (dataType.isInstanceOf[CalendarIntervalType]) { | ||
| input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) | ||
| } else { | ||
| numeric.minus(input1, input2) | ||
| val result = numeric.minus(input1, input2) | ||
| if (checkOverflow) { | ||
| val resSignum = numeric.signum(result) | ||
| val input1Signum = numeric.signum(input1) | ||
| val input2Signum = numeric.signum(input2) | ||
| if (resSignum != 1 && input1Signum == 1 && input2Signum == -1 | ||
| || resSignum != -1 && input1Signum == -1 && input2Signum == 1) { | ||
| throw new ArithmeticException(s"$input1 - $input2 caused overflow.") | ||
| } | ||
| } | ||
| result | ||
| } | ||
| } | ||
|
|
||
| override def checkOverflowCode(result: String, op1: String, op2: String): String = { | ||
| s""" | ||
| |if ($result <= 0 && $op1 > 0 && $op2 < 0 || $result >= 0 && $op1 < 0 && $op2 > 0) { | ||
| | throw new ArithmeticException($op1 + " - " + $op2 + " caused overflow."); | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
|
|
@@ -219,7 +281,29 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti | |
|
|
||
| private lazy val numeric = TypeUtils.getNumeric(dataType) | ||
|
|
||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) | ||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = { | ||
| val result = numeric.times(input1, input2) | ||
| if (checkOverflow) { | ||
| if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) && | ||
| !(result.isInstanceOf[Double] && !result.asInstanceOf[Double].isNaN) && | ||
| !(result.isInstanceOf[Float] && !result.asInstanceOf[Float].isNaN)) { | ||
| throw new ArithmeticException(s"$input1 * $input2 caused overflow.") | ||
| } | ||
| } | ||
| result | ||
| } | ||
|
|
||
| override def checkOverflowCode(result: String, op1: String, op2: String): String = { | ||
| val isNaNCheck = dataType match { | ||
| case DoubleType | FloatType => s" && !java.lang.Double.isNaN($result)" | ||
| case _ => "" | ||
| } | ||
| s""" | ||
| |if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)$isNaNCheck) { | ||
| | throw new ArithmeticException($op1 + " * " + $op2 + " caused overflow."); | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| } | ||
|
|
||
| // Common base trait for Divide and Remainder, since these two classes are almost identical | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1764,6 +1764,14 @@ object SQLConf { | |
| .booleanConf | ||
| .createWithDefault(false) | ||
|
|
||
| val ARITHMETIC_OPERATION_OVERFLOW_CHECK = buildConf("spark.sql.arithmetic.checkOverflow") | ||
|
||
| .doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " + | ||
| "exception if an overflow occurs. If it is false (default), in case of overflow a wrong " + | ||
| "result is returned.") | ||
| .internal() | ||
| .booleanConf | ||
| .createWithDefault(false) | ||
|
|
||
| val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = | ||
| buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") | ||
| .internal() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,8 +59,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) | ||
| checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) | ||
|
|
||
| DataTypeTestUtils.numericAndInterval.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) | ||
| Seq("true", "false").foreach { checkOverflow => | ||
| withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { | ||
| DataTypeTestUtils.numericAndInterval.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -100,8 +104,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) | ||
| checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) | ||
|
|
||
| DataTypeTestUtils.numericAndInterval.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) | ||
| Seq("true", "false").foreach { checkOverflow => | ||
| withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { | ||
| DataTypeTestUtils.numericAndInterval.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -118,8 +126,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) | ||
| checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) | ||
|
|
||
| DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) | ||
| Seq("true", "false").foreach { checkOverflow => | ||
| withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { | ||
| DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => | ||
| checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -376,4 +388,28 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper | |
| Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) | ||
| assert(ctx2.inlinedMutableStates.size == 1) | ||
| } | ||
|
|
||
| test("SPARK-24598: overflow on BigInt returns wrong result") { | ||
| val maxLongLiteral = Literal(Long.MaxValue) | ||
| val minLongLiteral = Literal(Long.MinValue) | ||
| val e1 = Add(maxLongLiteral, Literal(1L)) | ||
| val e2 = Subtract(maxLongLiteral, Literal(-1L)) | ||
| val e3 = Multiply(maxLongLiteral, Literal(2L)) | ||
| val e4 = Add(minLongLiteral, minLongLiteral) | ||
| val e5 = Subtract(minLongLiteral, maxLongLiteral) | ||
| val e6 = Multiply(minLongLiteral, minLongLiteral) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. shall we also test
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes, it was my fault, I was not testing "normal" cases with the flag turned on. I fixed it.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. makes sense, thanks! |
||
| withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { | ||
| Seq(e1, e2, e3, e4, e5, e6).foreach { e => | ||
| checkExceptionInExpression[ArithmeticException](e, "caused overflow") | ||
| } | ||
| } | ||
| withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { | ||
| checkEvaluation(e1, Long.MinValue) | ||
| checkEvaluation(e2, Long.MinValue) | ||
| checkEvaluation(e3, -2L) | ||
| checkEvaluation(e4, 0L) | ||
| checkEvaluation(e5, 1L) | ||
| checkEvaluation(e6, 0L) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this reduces the possibilities of overflow, but do other databases have the same behavior? Changing the data type of an expression is a breaking change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not changing the result data type of the expression; it is changing only the internal buffer type in order to allow "temporary" overflows to happen without any exception. Please consider the case when you have:
The result should be
Long.MaxValue - 900. With this buffer type larger than the returned type, we can overflow temporarily when we add Long.MaxValue and 100 and then get back to a valid value when we add -1000. So with this change we return the correct value. Other DBs behave in this way too.