Commit 94f10e1

wangyum, kazuyukitanimura
authored and committed
[SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results (apache#358)
* [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results

### What changes were proposed in this pull request?
This PR fixes inaccurate Decimal multiplication and division results.

### Why are the changes needed?
Decimal multiplication and division results may be inaccurate due to rounding issues.

#### Multiplication:
```
scala> sql("select -14120025096157587712113961295153.858047 * -0.4652").show(truncate=false)
+----------------------------------------------------+
|(-14120025096157587712113961295153.858047 * -0.4652)|
+----------------------------------------------------+
|6568635674732509803675414794505.574764              |
+----------------------------------------------------+
```
The correct answer is `6568635674732509803675414794505.574763`.

Please note that the last digit is `3` instead of `4`, as
```
scala> java.math.BigDecimal("-14120025096157587712113961295153.858047").multiply(java.math.BigDecimal("-0.4652"))
val res21: java.math.BigDecimal = 6568635674732509803675414794505.5747634644
```
Since the fractional part `.574763` is followed by `4644`, it should not be rounded up.

#### Division:
```
scala> sql("select -0.172787979 / 533704665545018957788294905796.5").show(truncate=false)
+-------------------------------------------------+
|(-0.172787979 / 533704665545018957788294905796.5)|
+-------------------------------------------------+
|-3.237521E-31                                    |
+-------------------------------------------------+
```
The correct answer is `-3.237520E-31`.

Please note that the last digit is `0` instead of `1`, as
```
scala> java.math.BigDecimal("-0.172787979").divide(java.math.BigDecimal("533704665545018957788294905796.5"), 100, java.math.RoundingMode.DOWN)
val res22: java.math.BigDecimal = -3.237520489418037889998826491401059986665344697406144511563561222578738E-31
```
Since the fractional part `.237520` is followed by `4894...`, it should not be rounded up.

### Does this PR introduce _any_ user-facing change?
Yes, users will see correct Decimal multiplication and division results. Directly multiplying and dividing with `org.apache.spark.sql.types.Decimal()` (not via SQL) will return at most 39 digits instead of at most 38, and will round down instead of rounding half-up.

### How was this patch tested?
Test added

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#43678 from kazuyukitanimura/SPARK-45786.

Authored-by: Kazuyuki Tanimura <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 5ef3a84)
Signed-off-by: Dongjoon Hyun <[email protected]>

* [SPARK-45786][SQL][FOLLOWUP][TEST] Fix Decimal random number tests with ANSI enabled

### What changes were proposed in this pull request?
This follow-up PR fixes the test for SPARK-45786 that is failing in GHA with SPARK_ANSI_SQL_MODE=true.

### Why are the changes needed?
The issue was discovered in apache#43678 (comment).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Test updated

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#43853 from kazuyukitanimura/SPARK-45786-FollowUp.

Authored-by: Kazuyuki Tanimura <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 949de34)
Signed-off-by: Dongjoon Hyun <[email protected]>

---------

Signed-off-by: Dongjoon Hyun <[email protected]>
Co-authored-by: Kazuyuki Tanimura <[email protected]>
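
The double rounding described in the commit message can be reproduced with plain `java.math.BigDecimal`, without Spark. The following is a minimal sketch, assuming that the old code path rounded half-up to 38 significant digits before the result was rounded again to the SQL result scale (6 in the multiplication example), and that the new path truncates to 39 digits first. The object name `DoubleRoundingDemo` and the `before`/`after` values are illustrative only and are not part of the patch.

```scala
import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

object DoubleRoundingDemo {
  private val a = new JBigDecimal("-14120025096157587712113961295153.858047")
  private val b = new JBigDecimal("-0.4652")

  // Scale of the SQL result in the multiplication example (six fractional digits).
  private val resultScale = 6

  // Before: round half-up to 38 significant digits, then half-up again to the result scale.
  // The first rounding turns ...5747634|644 into ...5747635, and the second rounds that 5 up.
  val before: JBigDecimal =
    a.multiply(b, new MathContext(38, RoundingMode.HALF_UP))
      .setScale(resultScale, RoundingMode.HALF_UP)

  // After: truncate to 39 significant digits, then round half-up once to the result scale.
  val after: JBigDecimal =
    a.multiply(b, new MathContext(39, RoundingMode.DOWN))
      .setScale(resultScale, RoundingMode.HALF_UP)

  def main(args: Array[String]): Unit = {
    println(s"before: ${before.toPlainString}")
    println(s"after:  ${after.toPlainString}")
  }
}
```

Truncating first keeps one extra digit of the exact product, so the single half-up rounding that follows sees `...574763|46` rather than an already rounded `...574763|5`.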
1 parent fc23429 commit 94f10e1

3 files changed

Lines changed: 128 additions & 9 deletions

sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 6 additions & 2 deletions
@@ -499,7 +499,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
-      DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
+      DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
 
   def % (that: Decimal): Decimal =
     if (that.isZero) null
@@ -547,7 +547,11 @@ object Decimal {
 
   val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
 
-  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+  // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
+  // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
+  // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
+  // the last extra digit.
+  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)
 
   private[sql] val ZERO = Decimal(0)
   private[sql] val ONE = Decimal(1)
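
The effect of the `/` change can be sketched with `java.math.BigDecimal` alone, using the division example from the commit message. This is an illustration under stated assumptions, not the Spark code path: `MaxScale = 38` is assumed to match `DecimalType.MAX_SCALE`, the final scale is taken from the expected value `-3.237520E-31`, and the object name `DivisionRoundingDemo` is hypothetical.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object DivisionRoundingDemo {
  // Assumption: mirrors DecimalType.MAX_SCALE.
  private val MaxScale = 38

  private val num = new JBigDecimal("-0.172787979")
  private val den = new JBigDecimal("533704665545018957788294905796.5")

  // Scale of the expected SQL result (-3.237520E-31 has scale 37).
  private val resultScale = new JBigDecimal("-3.237520E-31").scale

  // Before: divide at MAX_SCALE with HALF_UP, then round half-up again to the result scale.
  val before: JBigDecimal =
    num.divide(den, MaxScale, RoundingMode.HALF_UP)
      .setScale(resultScale, RoundingMode.HALF_UP)

  // After: divide at MAX_SCALE + 1 with DOWN, then round half-up once to the result scale.
  val after: JBigDecimal =
    num.divide(den, MaxScale + 1, RoundingMode.DOWN)
      .setScale(resultScale, RoundingMode.HALF_UP)

  def main(args: Array[String]): Unit = {
    println(s"before: $before")
    println(s"after:  $after")
  }
}
```

Because `/` passes `MATH_CONTEXT.getRoundingMode`, switching the context to `RoundingMode.DOWN` changes division as well, which is why the scale argument also grows by one digit.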

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 115 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.RoundingMode
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
@@ -231,6 +232,120 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
+    // Some known cases
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
+        Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
+      ),
+      Decimal(BigDecimal("6568635674732509803675414794505.574763"))
+    )
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
+        Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
+      ),
+      Decimal(BigDecimal("1367249507675382200.164877854336665327"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
+        Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
+      ),
+      Decimal(BigDecimal("-3.237520E-31"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
+        Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
+      ),
+      Decimal(BigDecimal("7.21642358550E-25"))
+    )
+
+    // Random tests
+    val rand = scala.util.Random
+    def makeNum(p: Int, s: Int): String = {
+      val int1 = rand.nextLong()
+      val int2 = rand.nextLong().abs
+      val frac1 = rand.nextLong().abs
+      val frac2 = rand.nextLong().abs
+      s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s)
+    }
+
+    (0 until 100).foreach { _ =>
+      val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
+      val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
+      val p2 = rand.nextInt(38) + 1
+      val s2 = rand.nextInt(p2 + 1)
+
+      val n1 = makeNum(p1, s1)
+      val n2 = makeNum(p2, s2)
+
+      val mulActual = Multiply(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
+
+      val divActual = Divide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val divExact = new java.math.BigDecimal(n1)
+        .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
+
+      val remActual = Remainder(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
+
+      val quotActual = IntegralDivide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val quotExact =
+        new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
+
+      Seq(true, false).foreach { allowPrecLoss =>
+        withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
+          val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
+          val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
+          val mulExpected =
+            if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
+          checkEvaluationOrException(mulActual, mulExpected)
+
+          val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
+          val divExpected =
+            if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
+          checkEvaluationOrException(divActual, divExpected)
+
+          val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
+          val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
+          val remExpected =
+            if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
+          checkEvaluationOrException(remActual, remExpected)
+
+          val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
+          val quotExpected =
+            if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
+          checkEvaluationOrException(quotActual, quotExpected.toLong)
+        }
+      }
+
+      def checkEvaluationOrException(actual: BinaryArithmetic, expected: Any): Unit =
+        if (SQLConf.get.ansiEnabled && expected == null) {
+          checkExceptionInExpression[SparkArithmeticException](actual,
+            "NUMERIC_VALUE_OUT_OF_RANGE")
+        } else {
+          checkEvaluation(actual, expected)
+        }
+    }
+  }
+
   private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
     testFunc(_.toDouble)
     testFunc(Decimal(_))
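
The random test builds its oracle in two steps: it generates operand strings with `makeNum`, then rounds the exact `java.math.BigDecimal` result once to the result scale and treats anything wider than 38 digits as overflow. Below is a minimal Spark-free sketch of that idea. It assumes `MaxPrecision = 38` matches `DecimalType.MAX_PRECISION`, models the `null` overflow case as `None`, and takes the result scale as a plain parameter instead of calling `resultDecimalType`. The object name `RandomDecimalOracle` and the helper `expected` are hypothetical, and `makeNum` takes an explicit `Random` so that runs can be seeded.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}
import scala.util.Random

object RandomDecimalOracle {
  // Assumption: mirrors DecimalType.MAX_PRECISION.
  val MaxPrecision = 38

  // Same construction as makeNum in the test: concatenate random longs and cut the
  // integer part to p - s digits (plus one character for a leading minus sign)
  // and the fractional part to at most s digits.
  def makeNum(rand: Random, p: Int, s: Int): String = {
    val int1 = rand.nextLong()
    val int2 = rand.nextLong().abs
    val frac1 = rand.nextLong().abs
    val frac2 = rand.nextLong().abs
    s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s)
  }

  // Round the exact result once to the result scale; None stands in for the
  // null (overflow) expectation used in the test.
  def expected(exact: JBigDecimal, resultScale: Int): Option[JBigDecimal] = {
    val rounded = exact.setScale(resultScale, RoundingMode.HALF_UP)
    if (rounded.precision > MaxPrecision) None else Some(rounded)
  }

  def main(args: Array[String]): Unit = {
    val rand = new Random(42L)
    val n1 = makeNum(rand, 20, 5)
    val n2 = makeNum(rand, 10, 3)
    val exact = new JBigDecimal(n1).multiply(new JBigDecimal(n2))
    println(s"$n1 * $n2 -> ${expected(exact, 8)}")
  }
}
```

Rounding the exact value only once is what makes the oracle independent of the intermediate precision used inside `Decimal`.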

sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out

Lines changed: 7 additions & 7 deletions
@@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000"
+    "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000"
+    "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
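
The only difference in these golden-file values is one extra trailing fractional digit (39 instead of 38), which is consistent with the division now being computed at `DecimalType.MAX_SCALE + 1`. The sketch below shows that effect with `java.math.BigDecimal` only. The operands are hypothetical and chosen just to yield a quotient of 10^36; the actual queries behind these error messages are not shown in this diff. `MaxScale = 38` is an assumed stand-in for `DecimalType.MAX_SCALE`.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object ErrorValueScaleDemo {
  // Assumption: mirrors DecimalType.MAX_SCALE.
  private val MaxScale = 38

  // Hypothetical operands; any division with an overflowing quotient would do.
  private val num = new JBigDecimal("1E+36")
  private val den = JBigDecimal.ONE

  // Intermediate quotient as rendered before and after the change.
  val before: String = num.divide(den, MaxScale, RoundingMode.HALF_UP).toPlainString
  val after: String = num.divide(den, MaxScale + 1, RoundingMode.DOWN).toPlainString

  // Number of digits after the decimal point in a plain decimal string.
  def fractionDigits(s: String): Int = s.split('.')(1).length

  def main(args: Array[String]): Unit = {
    println(s"before: $before (${fractionDigits(before)} fractional digits)")
    println(s"after:  $after (${fractionDigits(after)} fractional digits)")
  }
}
```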
