Skip to content
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5c662f6
[SPARK-24598][SQL] Overflow on airthmetic operation returns incorrect…
mgaido91 Jun 20, 2018
fad75fa
fix scalastyle
mgaido91 Jun 20, 2018
8591417
fix ut failures
mgaido91 Jun 20, 2018
9c3df7d
use larger intermediate buffer for sum
mgaido91 Jun 21, 2018
ebdaf61
fix UT error
mgaido91 Jun 22, 2018
a0b862e
allow precision loss when converting decimal to long
mgaido91 Jun 22, 2018
7bba22f
Merge branch 'master' into SPARK-24598
mgaido91 Jul 16, 2018
77f26f2
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 21, 2019
74cd0a4
Handle NaN
mgaido91 Jun 22, 2019
2cfd946
Add conf flag for checking overflow
mgaido91 Jun 26, 2019
25c853c
fix
mgaido91 Jun 26, 2019
ff02dca
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 27, 2019
00fae1d
fix tests
mgaido91 Jun 27, 2019
8e9715c
change default value and fix tests
mgaido91 Jun 28, 2019
1dff779
Merge branch 'master' into SPARK-24598
mgaido91 Jul 14, 2019
38fc1f4
fix typo
mgaido91 Jul 15, 2019
0d5e510
Merge branch 'SPARK-24598' of github.com:mgaido91/spark into SPARK-24598
mgaido91 Jul 15, 2019
37e19ce
fix
mgaido91 Jul 15, 2019
eb37ee7
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jul 20, 2019
98bbf83
address comments
mgaido91 Jul 20, 2019
650ea79
fix
mgaido91 Jul 20, 2019
1d20f73
address comments
mgaido91 Jul 26, 2019
538e332
address comments
mgaido91 Jul 26, 2019
3de4bfb
fix
mgaido91 Jul 27, 2019
3baecbc
fixes
mgaido91 Jul 27, 2019
a247f9f
fix unaryminus
mgaido91 Jul 27, 2019
582d148
address comments
mgaido91 Jul 30, 2019
b809a3f
fix
mgaido91 Jul 30, 2019
ce3ed2b
address comments
mgaido91 Jul 31, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}

private lazy val sumDataType = resultType
private lazy val sumDataType = child.dataType match {
case LongType => DecimalType.BigIntDecimal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this reduces the possibilities of overflow, but do other databases have the same behavior? Changing the data type of an expression is a breaking change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not changing the result data type of the expression, this is changing only the internal buffer type in roder to let "temporary" overflows to happen without any exception. Please consider the case when you have:

Long.MaxValue
100
-1000

The result should be Long.MaxValue - 900. With this buffer type larger than the returned type, we can overflow temporarily when we add Long.MaxValue and 100 and then get back to a valid value when we add -1000. So with this change we return the correct value. Other DBs behave in this way too.

case _ => resultType
}

private lazy val castToResultType: (Expression) => Expression =
if (sumDataType == resultType) (e: Expression) => e else (e: Expression) => Cast(e, resultType)

private lazy val sum = AttributeReference("sum", sumDataType)()

Expand Down Expand Up @@ -89,5 +95,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
)
}

override lazy val evaluateExpression: Expression = sum
override lazy val evaluateExpression: Expression = castToResultType(sum)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can inline the logic of castToResultType here, as evaluateExpression is a lazy val.

}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ case class Abs(child: Expression)

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

protected val checkOverflow = SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK)

override def dataType: DataType = left.dataType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
Expand All @@ -129,17 +131,41 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

def checkOverflowCode(result: String, op1: String, op2: String): String =
sys.error("BinaryArithmetics must override either checkOverflowCode or genCode")

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// In the following cases, overflow can happen, so we need to check the result is valid.
// Otherwise we throw an ArithmeticException
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In current Spark we are very conservative about runtime error, as it may break the data pipeline middle away, and returning null is a commonly used strategy. Shall we follow it here? We can throw exception when we have a strict mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, I am quite against returning null. It is not something a user expects, so he/she is likely not to check for it (when I see a NULL myself, I think that one of the 2 operands was NULL, not that an overflow occurred), so he/she won't realize the issue and would find corrupted data. Moreover, this is not how RDBMS behaves and it is against SQL standard. So I think that the behavior which was chosen for DECIMAL was wrong and I'd prefer not to introduce the same behavior also in other places.

Anyway I see your point about consistency over the codebase and it makes sense.

I'd love to know @gatorsmile and @hvanhovell's opinions too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gatorsmile @hvanhovell do you have time to check this and give your opinion here? Thanks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we fix it to return null and throw an exception when the configuration is on? Overflowed value is already quite pointless and changing the behaviour to return null by default might not be so harmful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment @HyukjinKwon . The issue with returning null is described in #21599 (comment) (moreover this behavior would be against SQL standard, but that's a minor point).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I am working on #25300, +1 for returning null on overflow by default with @cloud-fan @HyukjinKwon . This makes arithmetic operations and casting consistent.
I think null is better than a non-sense number on overflow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't agree on this proposal for the reasons explained in the comment I mentioned earlier. Making all arithmetic operations nullable is a too broad change and I think it is not worth for the little benefit.

// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val overflowCheck = if (checkOverflow) {
checkOverflowCode(ev.value, eval1, eval2)
} else {
""
}
s"""
|${ev.value} = (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2);
|$overflowCheck
""".stripMargin
})
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val overflowCheck = if (checkOverflow) {
checkOverflowCode(ev.value, eval1, eval2)
} else {
""
}
s"""
|${ev.value} = $eval1 $symbol $eval2;
|$overflowCheck
""".stripMargin
})
}
}

Expand Down Expand Up @@ -170,9 +196,27 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
if (dataType.isInstanceOf[CalendarIntervalType]) {
input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval])
} else {
numeric.plus(input1, input2)
val result = numeric.plus(input1, input2)
if (checkOverflow) {
val resSignum = numeric.signum(result)
val input1Signum = numeric.signum(input1)
val input2Signum = numeric.signum(input2)
if (resSignum != -1 && input1Signum == -1 && input2Signum == -1
|| resSignum != 1 && input1Signum == 1 && input2Signum == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the checking is different from the codegen version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because since here we are using numeric in order for the code to be generic for alla data types, numeric exposes signum but not comparison operators, which are instead available in codegen since the datatypes are specified

throw new ArithmeticException(s"$input1 + $input2 caused overflow.")
}
}
result
}
}

override def checkOverflowCode(result: String, op1: String, op2: String): String = {
s"""
|if ($result >= 0 && $op1 < 0 && $op2 < 0 || $result <= 0 && $op1 > 0 && $op2 > 0) {
| throw new ArithmeticException($op1 + " + " + $op2 + " caused overflow.");
|}
""".stripMargin
}
}

@ExpressionDescription(
Expand All @@ -198,9 +242,27 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
if (dataType.isInstanceOf[CalendarIntervalType]) {
input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval])
} else {
numeric.minus(input1, input2)
val result = numeric.minus(input1, input2)
if (checkOverflow) {
val resSignum = numeric.signum(result)
val input1Signum = numeric.signum(input1)
val input2Signum = numeric.signum(input2)
if (resSignum != 1 && input1Signum == 1 && input2Signum == -1
|| resSignum != -1 && input1Signum == -1 && input2Signum == 1) {
throw new ArithmeticException(s"$input1 - $input2 caused overflow.")
}
}
result
}
}

override def checkOverflowCode(result: String, op1: String, op2: String): String = {
s"""
|if ($result <= 0 && $op1 > 0 && $op2 < 0 || $result >= 0 && $op1 < 0 && $op2 > 0) {
| throw new ArithmeticException($op1 + " - " + $op2 + " caused overflow.");
|}
""".stripMargin
}
}

@ExpressionDescription(
Expand All @@ -219,7 +281,29 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val result = numeric.times(input1, input2)
if (checkOverflow) {
if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) &&
!(result.isInstanceOf[Double] && !result.asInstanceOf[Double].isNaN) &&
!(result.isInstanceOf[Float] && !result.asInstanceOf[Float].isNaN)) {
throw new ArithmeticException(s"$input1 * $input2 caused overflow.")
}
}
result
}

override def checkOverflowCode(result: String, op1: String, op2: String): String = {
val isNaNCheck = dataType match {
case DoubleType | FloatType => s" && !java.lang.Double.isNaN($result)"
case _ => ""
}
s"""
|if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)$isNaNCheck) {
| throw new ArithmeticException($op1 + " * " + $op2 + " caused overflow.");
|}
""".stripMargin
}
}

// Common base trait for Divide and Remainder, since these two classes are almost identical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)

override def checkOverflowCode(result: String, op1: String, op2: String): String = ""
}

/**
Expand Down Expand Up @@ -83,6 +85,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)

override def checkOverflowCode(result: String, op1: String, op2: String): String = ""
}

/**
Expand Down Expand Up @@ -115,6 +119,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)

override def checkOverflowCode(result: String, op1: String, op2: String): String = ""
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ARITHMETIC_OPERATION_OVERFLOW_CHECK = buildConf("spark.sql.arithmetic.checkOverflow")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check is vague here, how about spark.sql.arithmeticOperations.failOnOverFlow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the suggestion, I'll update accordingly

.doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " +
"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
"result is returned.")
.internal()
.booleanConf
.createWithDefault(false)

val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode}

import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.internal.SQLConf

/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
Expand Down Expand Up @@ -228,6 +229,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null)) {
longVal / POW_10(_scale)
} else {
if (SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK)) {
// This will throw an exception if overflow occurs
if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) {
throw new ArithmeticException("Overflow")
}
}
decimalVal.longValue()
}
}
Expand Down Expand Up @@ -456,6 +463,9 @@ object Decimal {
private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE)
private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE)

private val LONG_MAX_BIG_DEC = BigDecimal.valueOf(JLong.MAX_VALUE)
private val LONG_MIN_BIG_DEC = BigDecimal.valueOf(JLong.MIN_VALUE)

def apply(value: Double): Decimal = new Decimal().set(value)

def apply(value: Long): Decimal = new Decimal().set(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Add(positiveIntLit, negativeIntLit), -1)
checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)

DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe)
}
}
}
}

Expand Down Expand Up @@ -100,8 +104,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt)
checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)

DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe)
}
}
}
}

Expand All @@ -118,8 +126,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt)
checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)

DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) {
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe)
}
}
}
}

Expand Down Expand Up @@ -376,4 +388,28 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
assert(ctx2.inlinedMutableStates.size == 1)
}

test("SPARK-24598: overflow on BigInt returns wrong result") {
val maxLongLiteral = Literal(Long.MaxValue)
val minLongLiteral = Literal(Long.MinValue)
val e1 = Add(maxLongLiteral, Literal(1L))
val e2 = Subtract(maxLongLiteral, Literal(-1L))
val e3 = Multiply(maxLongLiteral, Literal(2L))
val e4 = Add(minLongLiteral, minLongLiteral)
val e5 = Subtract(minLongLiteral, maxLongLiteral)
val e6 = Multiply(minLongLiteral, minLongLiteral)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we also test UnaryMinus? BTW do you know why the previous mistake in UnaryMinus code was not caught by the existing tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it was my fault, I was not testing "normal" cases with the flag turned on. I fixed it. UnaryMinus cases are already tested in its UT, I think it is useless to add them here too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, thanks!

withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "caused overflow")
}
}
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") {
checkEvaluation(e1, Long.MinValue)
checkEvaluation(e2, Long.MinValue)
checkEvaluation(e3, -2L)
checkEvaluation(e4, 0L)
checkEvaluation(e5, 1L)
checkEvaluation(e6, 0L)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -956,6 +957,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
}

test("SPARK-24598: Cast to long should fail on overflow") {
val overflowCast = cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType)
val nonOverflowCast = cast(Literal.create(Decimal(Long.MaxValue)), LongType)
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") {
checkExceptionInExpression[ArithmeticException](overflowCast, "Overflow")
checkEvaluation(nonOverflowCast, Long.MaxValue)
}
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") {
checkEvaluation(overflowCast, Long.MinValue)
checkEvaluation(nonOverflowCast, Long.MaxValue)
}
}

test("up-cast") {
def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
case (_, dt: DecimalType) => dt.isWiderThan(from)
Expand Down
Loading