Closed

Commits (29 total; diff shows changes from 26 commits)
5c662f6
[SPARK-24598][SQL] Overflow on airthmetic operation returns incorrect…
mgaido91 Jun 20, 2018
fad75fa
fix scalastyle
mgaido91 Jun 20, 2018
8591417
fix ut failures
mgaido91 Jun 20, 2018
9c3df7d
use larger intermediate buffer for sum
mgaido91 Jun 21, 2018
ebdaf61
fix UT error
mgaido91 Jun 22, 2018
a0b862e
allow precision loss when converting decimal to long
mgaido91 Jun 22, 2018
7bba22f
Merge branch 'master' into SPARK-24598
mgaido91 Jul 16, 2018
77f26f2
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 21, 2019
74cd0a4
Handle NaN
mgaido91 Jun 22, 2019
2cfd946
Add conf flag for checking overflow
mgaido91 Jun 26, 2019
25c853c
fix
mgaido91 Jun 26, 2019
ff02dca
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 27, 2019
00fae1d
fix tests
mgaido91 Jun 27, 2019
8e9715c
change default value and fix tests
mgaido91 Jun 28, 2019
1dff779
Merge branch 'master' into SPARK-24598
mgaido91 Jul 14, 2019
38fc1f4
fix typo
mgaido91 Jul 15, 2019
0d5e510
Merge branch 'SPARK-24598' of github.com:mgaido91/spark into SPARK-24598
mgaido91 Jul 15, 2019
37e19ce
fix
mgaido91 Jul 15, 2019
eb37ee7
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jul 20, 2019
98bbf83
address comments
mgaido91 Jul 20, 2019
650ea79
fix
mgaido91 Jul 20, 2019
1d20f73
address comments
mgaido91 Jul 26, 2019
538e332
address comments
mgaido91 Jul 26, 2019
3de4bfb
fix
mgaido91 Jul 27, 2019
3baecbc
fixes
mgaido91 Jul 27, 2019
a247f9f
fix unaryminus
mgaido91 Jul 27, 2019
582d148
address comments
mgaido91 Jul 30, 2019
b809a3f
fix
mgaido91 Jul 30, 2019
ce3ed2b
address comments
mgaido91 Jul 31, 2019
@@ -56,7 +56,10 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}

private lazy val sumDataType = resultType
private lazy val sumDataType = child.dataType match {
case LongType => DecimalType.BigIntDecimal
Contributor:
This reduces the possibility of overflow, but do other databases have the same behavior? Changing the data type of an expression is a breaking change.

Contributor Author:
This is not changing the result data type of the expression; it only changes the internal buffer type, in order to let "temporary" overflows happen without any exception. Please consider the case when you have:

Long.MaxValue
100
-1000

The result should be Long.MaxValue - 900. With a buffer type larger than the returned type, we can overflow temporarily when we add Long.MaxValue and 100, and then get back to a valid value when we add -1000. So with this change we return the correct value. Other DBs behave this way too.
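
A minimal standalone sketch of the point above (illustration only, not PR code): with an overflow-checked Long accumulator the intermediate addition throws, while a wider decimal buffer tolerates the temporary overflow and still fits back into a Long at the end.

```scala
import scala.util.Try

// Illustration only (not PR code): overflow-checked Long buffer vs. a wider
// decimal buffer when summing Long.MaxValue, 100 and -1000.
object SumBufferSketch {
  def main(args: Array[String]): Unit = {
    val values = Seq(Long.MaxValue, 100L, -1000L)

    // Exact Long buffer: Math.addExact throws when adding 100 to
    // Long.MaxValue, even though the final sum fits in a Long.
    println(Try(values.foldLeft(0L)((acc, v) => Math.addExact(acc, v))))
    // Failure(java.lang.ArithmeticException: long overflow)

    // Wider buffer: no intermediate overflow, and the result
    // (Long.MaxValue - 900) converts back to a Long without loss.
    val sum = values.foldLeft(BigDecimal(0))((acc, v) => acc + BigDecimal(v))
    println(sum.toLong) // 9223372036854774907
  }
}
```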

case _ => resultType
}

private lazy val sum = AttributeReference("sum", sumDataType)()

@@ -89,5 +92,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
)
}

override lazy val evaluateExpression: Expression = sum
override lazy val evaluateExpression: Expression = {
if (sumDataType == resultType) {
sum
} else {
Cast(sum, resultType)
Contributor:
After some more thought, I think we should still use long as the buffer type to sum longs, at least by default. Adding long values is faster than adding decimal values, and we shouldn't introduce this performance regression silently.

We can have an option to use decimal as buffer type, to reduce the possibility of overflow. But it should be opt-in.

Contributor Author:
I see. I agree. Shall we revert this change here and do this in another PR?

}
}
}
@@ -35,17 +35,36 @@ import org.apache.spark.unsafe.types.CalendarInterval
""")
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
private val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

override def toString: String = s"-$child"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case ByteType | ShortType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val javaBoxedType = CodeGenerator.boxedType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val originValue = ctx.freshName("origin")
s"""
|$javaType $originValue = ($javaType)($eval);
|if ($originValue == $javaBoxedType.MIN_VALUE) {
| throw new ArithmeticException("- " + $originValue + " caused overflow.");
|}
|${ev.value} = ($javaType)(-($originValue));
""".stripMargin
})
case IntegerType | LongType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact(-($eval));"
Contributor:
wait, shouldn't it be $mathClass.negateExact($eval)?

Contributor Author:
mmmh... it should indeed... I am wondering how the test cases could pass...
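
For reference, a small standalone sketch (not PR code) of what the generated expression computes, which may be why the existing tests did not catch it: for ordinary values the two negations cancel out, and only Long.MinValue still trips the check.

```scala
// Illustration only (not PR code): negateExact(-(x)) is effectively a no-op
// for every value except MIN_VALUE, so the generated code would return x
// instead of -x when overflow checking is enabled.
object NegateExactSketch {
  def main(args: Array[String]): Unit = {
    val x = 42L
    println(Math.negateExact(-x)) // 42, not -42: the double negation cancels
    println(Math.negateExact(x))  // -42: the intended result

    // Long.MinValue: the inner unary minus silently wraps back to
    // Long.MinValue, and negateExact then throws, so the overflow
    // test case still passes.
    try Math.negateExact(-Long.MinValue)
    catch { case e: ArithmeticException => println(e.getMessage) } // long overflow
  }
}
```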

})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
@@ -117,6 +136,8 @@ case class Abs(child: Expression)

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

protected val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck

override def dataType: DataType = left.dataType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
@@ -129,17 +150,57 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

/** Name of the function for the exact version of this expression in [[Math]]. */
def exactMathMethod: String =
sys.error("BinaryArithmetics must override either exactMathMethod or genCode")

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType =>
// Overflow is handled in the CheckOverflow operator
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val tmpResult = ctx.freshName("tmpResult")
val overflowCheck = if (checkOverflow) {
val javaType = CodeGenerator.boxedType(dataType)
s"""
|if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
| throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow.");
|}
""".stripMargin
} else {
""
}
s"""
|${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
|$overflowCheck
|${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
""".stripMargin
})
case IntegerType | LongType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val operation = if (checkOverflow) {
val mathClass = classOf[Math].getName
s"$mathClass.$exactMathMethod($eval1, $eval2)"
} else {
s"$eval1 $symbol $eval2"
}
s"""
|${ev.value} = $operation;
""".stripMargin
})
case DoubleType | FloatType =>
// When Double/Float overflows, there can be 2 cases:
// - precision loss: according to SQL standard, the number is truncated;
// - returns (+/-)Infinite: same behavior also other DBs have (eg. Postgres)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
|${ev.value} = $eval1 $symbol $eval2;
""".stripMargin
})
}
}
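
For reference, a standalone sketch (not PR code) of the two Double cases described in the comment above: silent precision loss and saturation to Infinity.

```scala
// Illustration only (not PR code): Double arithmetic never throws on
// overflow; it either loses precision or saturates to +/-Infinity.
object DoubleOverflowSketch {
  def main(args: Array[String]): Unit = {
    // Precision loss: the small addend is truncated away.
    println(1e18 + 1 == 1e18) // true

    // Exceeding the representable range yields Infinity, as e.g. Postgres
    // also does for double precision columns.
    println(Double.MaxValue + Double.MaxValue) // Infinity
  }
}
```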

@@ -164,7 +225,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def calendarIntervalMethod: String = "add"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
@@ -173,6 +234,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}

override def exactMathMethod: String = "addExact"
}

@ExpressionDescription(
@@ -192,7 +255,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

override def calendarIntervalMethod: String = "subtract"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
@@ -201,6 +264,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}

override def exactMathMethod: String = "subtractExact"
}

@ExpressionDescription(
@@ -217,9 +282,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "*"
override def decimalMethod: String = "$times"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)

override def exactMathMethod: String = "multiplyExact"
}

// Common base trait for Divide and Remainder, since these two classes are almost identical
@@ -60,8 +60,13 @@
}
}

def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
if (exactNumericRequired) {
t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
} else {
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
}
}

def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
@@ -1780,6 +1780,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ARITHMETIC_OPERATION_OVERFLOW_CHECK =
Contributor:
shall we rename it to be consistent with the config name spark.sql.arithmeticOperations.failOnOverFlow?
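
For context, a usage sketch of the flag under discussion (illustration only; it assumes an existing SparkSession named spark and the config key shown in the diff below):

```scala
// Usage sketch (assumes an existing SparkSession named `spark` and the
// config key from this diff): opt in to failing on overflow.
spark.conf.set("spark.sql.arithmeticOperations.failOnOverFlow", "true")

// With the flag enabled, a long addition that overflows should throw an
// ArithmeticException instead of silently wrapping around.
spark.sql("SELECT CAST(9223372036854775807 AS BIGINT) + CAST(1 AS BIGINT)").show()
```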

buildConf("spark.sql.arithmeticOperations.failOnOverFlow")
.doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " +
"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
"result is returned.")
.internal()
.booleanConf
.createWithDefault(false)

val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
.internal()
@@ -2287,6 +2296,8 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

def arithmeticOperationOverflowCheck: Boolean = getConf(ARITHMETIC_OPERATION_OVERFLOW_CHECK)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
@@ -142,6 +142,8 @@ abstract class NumericType extends AtomicType {
// desugared by the compiler into an argument to the objects constructor. This means there is no
// longer a no argument constructor and thus the JVM cannot serialize the object anymore.
private[sql] val numeric: Numeric[InternalType]

private[sql] def exactNumeric: Numeric[InternalType] = numeric
}
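
A sketch of what an "exact" Numeric behind the exactNumeric hook above could look like (an assumption for illustration; the PR's actual LongExactNumeric may differ), delegating to java.lang.Math's *Exact methods so that overflow throws instead of wrapping:

```scala
import scala.util.Try

// Hypothetical sketch, not the PR's implementation: a Numeric[Long] whose
// arithmetic throws ArithmeticException on overflow.
object LongExactNumericSketch extends Numeric[Long] {
  def plus(x: Long, y: Long): Long = Math.addExact(x, y)
  def minus(x: Long, y: Long): Long = Math.subtractExact(x, y)
  def times(x: Long, y: Long): Long = Math.multiplyExact(x, y)
  def negate(x: Long): Long = Math.negateExact(x)
  def fromInt(x: Int): Long = x.toLong
  def toInt(x: Long): Int = x.toInt
  def toLong(x: Long): Long = x
  def toFloat(x: Long): Float = x.toFloat
  def toDouble(x: Long): Double = x.toDouble
  def compare(x: Long, y: Long): Int = java.lang.Long.compare(x, y)
  // Required on Scala 2.13; a harmless extra member on 2.12.
  def parseString(str: String): Option[Long] = Try(str.toLong).toOption
}
```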


@@ -37,6 +37,7 @@ class ByteType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ByteExactNumeric

/**
* The default size of a value of the ByteType is 1 byte.
@@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode}

import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.internal.SQLConf

/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -228,6 +229,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null)) {
longVal / POW_10(_scale)
} else {
if (SQLConf.get.arithmeticOperationOverflowCheck) {
Contributor:
shall we use the DECIMAL_OPERATIONS_NULL_ON_OVERFLOW config instead?

Contributor:
BTW we can do it in another PR; let's focus on non-decimal arithmetic operations in this PR.

Contributor:
And we need to handle more cases as well, e.g. long to int, int to byte, etc. AFAIK @gengliangwang has been working on cast recently.

Contributor Author:
IIRC, I had to do this in order to have some test cases work; it is related to the sum buffer. Honestly, I'd consider it quite challenging to decide which of these 2 configs is the right one... It may be quite counter-intuitive for a user that a sum of longs can be affected by DECIMAL_OPERATIONS_NULL_ON_OVERFLOW...

Contributor:
"I had to do this in order to have some test cases work."

Not anymore if we revert #21599 (comment)?

Contributor Author:
Yes, exactly. We can discuss this when we do the PR addressing the sum buffer data type.
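
For reference, a standalone sketch (not PR code) of the silent truncation that the range check below guards against:

```scala
import java.math.BigDecimal

// Illustration only (not PR code): without a range check, longValue() on a
// java.math.BigDecimal outside the Long range silently keeps only the low
// 64 bits, while longValueExact() throws.
object DecimalToLongSketch {
  def main(args: Array[String]): Unit = {
    val tooBig = new BigDecimal("9223372036854775808") // Long.MaxValue + 1
    println(tooBig.longValue()) // -9223372036854775808: silent wrap-around

    try tooBig.longValueExact()
    catch { case e: ArithmeticException => println(e.getMessage) } // reports the overflow
  }
}
```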

// This will throw an exception if overflow occurs
if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) {
throw new ArithmeticException("Overflow")
}
}
decimalVal.longValue()
}
}
@@ -456,6 +463,9 @@ object Decimal {
private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE)
private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE)

private val LONG_MAX_BIG_DEC = BigDecimal.valueOf(JLong.MAX_VALUE)
private val LONG_MIN_BIG_DEC = BigDecimal.valueOf(JLong.MIN_VALUE)

def apply(value: Double): Decimal = new Decimal().set(value)

def apply(value: Long): Decimal = new Decimal().set(value)
@@ -37,6 +37,7 @@ class IntegerType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = IntegerExactNumeric

/**
* The default size of a value of the IntegerType is 4 bytes.
@@ -37,6 +37,7 @@ class LongType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = LongExactNumeric

/**
* The default size of a value of the LongType is 8 bytes.
@@ -37,6 +37,7 @@ class ShortType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ShortExactNumeric

/**
* The default size of a value of the ShortType is 2 bytes.