-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-26218][SQL] Overflow on arithmetic operations returns incorrect result #21599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 26 commits
5c662f6
fad75fa
8591417
9c3df7d
ebdaf61
a0b862e
7bba22f
77f26f2
74cd0a4
2cfd946
25c853c
ff02dca
00fae1d
8e9715c
1dff779
38fc1f4
0d5e510
37e19ce
eb37ee7
98bbf83
650ea79
1d20f73
538e332
3de4bfb
3baecbc
a247f9f
582d148
b809a3f
ce3ed2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,7 +56,10 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast | |
| case _ => DoubleType | ||
| } | ||
|
|
||
| private lazy val sumDataType = resultType | ||
| private lazy val sumDataType = child.dataType match { | ||
| case LongType => DecimalType.BigIntDecimal | ||
| case _ => resultType | ||
| } | ||
|
|
||
| private lazy val sum = AttributeReference("sum", sumDataType)() | ||
|
|
||
|
|
@@ -89,5 +92,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast | |
| ) | ||
| } | ||
|
|
||
| override lazy val evaluateExpression: Expression = sum | ||
| override lazy val evaluateExpression: Expression = { | ||
| if (sumDataType == resultType) { | ||
| sum | ||
| } else { | ||
| Cast(sum, resultType) | ||
|
||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,17 +35,36 @@ import org.apache.spark.unsafe.types.CalendarInterval | |
| """) | ||
| case class UnaryMinus(child: Expression) extends UnaryExpression | ||
| with ExpectsInputTypes with NullIntolerant { | ||
| private val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) | ||
|
|
||
| override def dataType: DataType = child.dataType | ||
|
|
||
| override def toString: String = s"-$child" | ||
|
|
||
| private lazy val numeric = TypeUtils.getNumeric(dataType) | ||
| private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { | ||
| case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") | ||
| case ByteType | ShortType if checkOverflow => | ||
| nullSafeCodeGen(ctx, ev, eval => { | ||
| val javaBoxedType = CodeGenerator.boxedType(dataType) | ||
| val javaType = CodeGenerator.javaType(dataType) | ||
| val originValue = ctx.freshName("origin") | ||
| s""" | ||
| |$javaType $originValue = ($javaType)($eval); | ||
| |if ($originValue == $javaBoxedType.MIN_VALUE) { | ||
| | throw new ArithmeticException("- " + $originValue + " caused overflow."); | ||
| |} | ||
| |${ev.value} = ($javaType)(-($originValue)); | ||
| """.stripMargin | ||
| }) | ||
| case IntegerType | LongType if checkOverflow => | ||
| nullSafeCodeGen(ctx, ev, eval => { | ||
| val mathClass = classOf[Math].getName | ||
| s"${ev.value} = $mathClass.negateExact(-($eval));" | ||
|
||
| }) | ||
| case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { | ||
| val originValue = ctx.freshName("origin") | ||
| // codegen would fail to compile if we just write (-($c)) | ||
|
|
@@ -117,6 +136,8 @@ case class Abs(child: Expression) | |
|
|
||
| abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { | ||
|
|
||
| protected val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck | ||
|
|
||
| override def dataType: DataType = left.dataType | ||
|
|
||
| override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess | ||
|
|
@@ -129,17 +150,57 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { | |
| def calendarIntervalMethod: String = | ||
| sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") | ||
|
|
||
| /** Name of the function for the exact version of this expression in [[Math]]. */ | ||
| def exactMathMethod: String = | ||
| sys.error("BinaryArithmetics must override either exactMathMethod or genCode") | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { | ||
| case _: DecimalType => | ||
| // Overflow is handled in the CheckOverflow operator | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") | ||
| case CalendarIntervalType => | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") | ||
| // byte and short are casted into int when add, minus, times or divide | ||
| case ByteType | ShortType => | ||
| defineCodeGen(ctx, ev, | ||
| (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") | ||
| case _ => | ||
| defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| val tmpResult = ctx.freshName("tmpResult") | ||
| val overflowCheck = if (checkOverflow) { | ||
| val javaType = CodeGenerator.boxedType(dataType) | ||
| s""" | ||
| |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) { | ||
| | throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow."); | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| "" | ||
| } | ||
| s""" | ||
| |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2; | ||
| |$overflowCheck | ||
| |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult); | ||
| """.stripMargin | ||
| }) | ||
| case IntegerType | LongType => | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| val operation = if (checkOverflow) { | ||
| val mathClass = classOf[Math].getName | ||
| s"$mathClass.$exactMathMethod($eval1, $eval2)" | ||
| } else { | ||
| s"$eval1 $symbol $eval2" | ||
| } | ||
| s""" | ||
| |${ev.value} = $operation; | ||
| """.stripMargin | ||
| }) | ||
| case DoubleType | FloatType => | ||
| // When Double/Float overflows, there can be 2 cases: | ||
| // - precision loss: according to SQL standard, the number is truncated; | ||
| // - returns (+/-)Infinite: same behavior also other DBs have (eg. Postgres) | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| s""" | ||
| |${ev.value} = $eval1 $symbol $eval2; | ||
mgaido91 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """.stripMargin | ||
| }) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -164,7 +225,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { | |
|
|
||
| override def calendarIntervalMethod: String = "add" | ||
|
|
||
| private lazy val numeric = TypeUtils.getNumeric(dataType) | ||
| private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) | ||
|
|
||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = { | ||
| if (dataType.isInstanceOf[CalendarIntervalType]) { | ||
|
|
@@ -173,6 +234,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { | |
| numeric.plus(input1, input2) | ||
| } | ||
| } | ||
|
|
||
| override def exactMathMethod: String = "addExact" | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
|
|
@@ -192,7 +255,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti | |
|
|
||
| override def calendarIntervalMethod: String = "subtract" | ||
|
|
||
| private lazy val numeric = TypeUtils.getNumeric(dataType) | ||
| private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) | ||
|
|
||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = { | ||
| if (dataType.isInstanceOf[CalendarIntervalType]) { | ||
|
|
@@ -201,6 +264,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti | |
| numeric.minus(input1, input2) | ||
| } | ||
| } | ||
|
|
||
| override def exactMathMethod: String = "subtractExact" | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
|
|
@@ -217,9 +282,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti | |
| override def symbol: String = "*" | ||
| override def decimalMethod: String = "$times" | ||
|
|
||
| private lazy val numeric = TypeUtils.getNumeric(dataType) | ||
| private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) | ||
|
|
||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) | ||
|
|
||
| override def exactMathMethod: String = "multiplyExact" | ||
| } | ||
|
|
||
| // Common base trait for Divide and Remainder, since these two classes are almost identical | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1780,6 +1780,15 @@ object SQLConf { | |
| .booleanConf | ||
| .createWithDefault(false) | ||
|
|
||
| val ARITHMETIC_OPERATION_OVERFLOW_CHECK = | ||
|
||
| buildConf("spark.sql.arithmeticOperations.failOnOverFlow") | ||
| .doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " + | ||
| "exception if an overflow occurs. If it is false (default), in case of overflow a wrong " + | ||
| "result is returned.") | ||
| .internal() | ||
| .booleanConf | ||
| .createWithDefault(false) | ||
|
|
||
| val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = | ||
| buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") | ||
| .internal() | ||
|
|
@@ -2287,6 +2296,8 @@ class SQLConf extends Serializable with Logging { | |
|
|
||
| def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW) | ||
|
|
||
| def arithmeticOperationOverflowCheck: Boolean = getConf(ARITHMETIC_OPERATION_OVERFLOW_CHECK) | ||
|
|
||
| def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) | ||
|
|
||
| def continuousStreamingEpochBacklogQueueSize: Int = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode} | |
|
|
||
| import org.apache.spark.annotation.Unstable | ||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| /** | ||
| * A mutable implementation of BigDecimal that can hold a Long if values are small enough. | ||
|
|
@@ -228,6 +229,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { | |
| if (decimalVal.eq(null)) { | ||
| longVal / POW_10(_scale) | ||
| } else { | ||
| if (SQLConf.get.arithmeticOperationOverflowCheck) { | ||
|
||
| // This will throw an exception if overflow occurs | ||
| if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { | ||
| throw new ArithmeticException("Overflow") | ||
| } | ||
| } | ||
| decimalVal.longValue() | ||
| } | ||
| } | ||
|
|
@@ -456,6 +463,9 @@ object Decimal { | |
| private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE) | ||
| private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE) | ||
|
|
||
| private val LONG_MAX_BIG_DEC = BigDecimal.valueOf(JLong.MAX_VALUE) | ||
| private val LONG_MIN_BIG_DEC = BigDecimal.valueOf(JLong.MIN_VALUE) | ||
|
|
||
| def apply(value: Double): Decimal = new Decimal().set(value) | ||
|
|
||
| def apply(value: Long): Decimal = new Decimal().set(value) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this reduces the possibilities of overflow, but do other databases have the same behavior? Changing the data type of an expression is a breaking change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not changing the result data type of the expression, this is changing only the internal buffer type in roder to let "temporary" overflows to happen without any exception. Please consider the case when you have:
The result should be
Long.MaxValue - 900. With this buffer type larger than the returned type, we can overflow temporarily when we addLong.MaxValueand100and then get back to a valid value when we add-1000. So with this change we return the correct value. Other DBs behave in this way too.