Skip to content

Commit 3037d4a

Browse files
committed
[SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL
1 parent c2aeddf commit 3037d4a

8 files changed

Lines changed: 206 additions & 42 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -93,41 +93,46 @@ object DecimalPrecision extends TypeCoercionRule {
9393
case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
9494

9595
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
96-
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
96+
val resultScale = max(s1, s2)
97+
val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
98+
resultScale)
9799
CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
98100

99101
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
100-
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
102+
val resultScale = max(s1, s2)
103+
val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
104+
resultScale)
101105
CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
102106

103107
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
104-
val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
108+
val resultType = DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
105109
val widerType = widerDecimalType(p1, s1, p2, s2)
106110
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
107111
resultType)
108112

109113
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
110-
var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
111-
var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
112-
val diff = (intDig + decDig) - DecimalType.MAX_SCALE
113-
if (diff > 0) {
114-
decDig -= diff / 2 + 1
115-
intDig = DecimalType.MAX_SCALE - decDig
116-
}
117-
val resultType = DecimalType.bounded(intDig + decDig, decDig)
114+
// From https://msdn.microsoft.com/en-us/library/ms190476.aspx
115+
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
116+
// Scale: max(6, s1 + p2 + 1)
117+
val intDig = p1 - s1 + s2
118+
val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
119+
val prec = intDig + scale
120+
val resultType = DecimalType.adjustPrecisionScale(prec, scale)
118121
val widerType = widerDecimalType(p1, s1, p2, s2)
119122
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
120123
resultType)
121124

122125
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
123-
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
126+
val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2),
127+
max(s1, s2))
124128
// resultType may have lower precision, so we cast them into wider type first.
125129
val widerType = widerDecimalType(p1, s1, p2, s2)
126130
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
127131
resultType)
128132

129133
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
130-
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
134+
val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2),
135+
max(s1, s2))
131136
// resultType may have lower precision, so we cast them into wider type first.
132137
val widerType = widerDecimalType(p1, s1, p2, s2)
133138
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
@@ -243,17 +248,43 @@ object DecimalPrecision extends TypeCoercionRule {
243248
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
244249
// and fixed-precision decimals in an expression with floats / doubles to doubles
245250
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
246-
(left.dataType, right.dataType) match {
247-
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
248-
b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
249-
case (DecimalType.Fixed(p, s), t: IntegralType) =>
250-
b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
251-
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
252-
b.makeCopy(Array(left, Cast(right, DoubleType)))
253-
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
254-
b.makeCopy(Array(Cast(left, DoubleType), right))
255-
case _ =>
256-
b
257-
}
251+
nondecimalLiteralAndDecimal(b).lift((left, right)).getOrElse(
252+
nondecimalNonliteralAndDecimal(b).applyOrElse((left.dataType, right.dataType),
253+
(_: (DataType, DataType)) => b))
258254
}
255+
256+
/**
257+
* Type coercion for BinaryOperator in which one side is a non-decimal literal numeric, and the
258+
* other side is a decimal.
259+
*/
260+
private def nondecimalLiteralAndDecimal(
261+
b: BinaryOperator): PartialFunction[(Expression, Expression), Expression] = {
262+
// Promote literal integers inside a binary expression with fixed-precision decimals to
263+
// decimals. The precision and scale are the ones needed by the integer value.
264+
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType]
265+
&& l.dataType.isInstanceOf[IntegralType] =>
266+
b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r))
267+
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType]
268+
&& r.dataType.isInstanceOf[IntegralType] =>
269+
b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r))))
270+
}
271+
272+
/**
273+
* Type coercion for BinaryOperator in which one side is a non-decimal non-literal numeric, and
274+
* the other side is a decimal.
275+
*/
276+
private def nondecimalNonliteralAndDecimal(
277+
b: BinaryOperator): PartialFunction[(DataType, DataType), Expression] = {
278+
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
279+
// and fixed-precision decimals in an expression with floats / doubles to doubles
280+
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
281+
b.makeCopy(Array(Cast(b.left, DecimalType.forType(t)), b.right))
282+
case (DecimalType.Fixed(_, _), t: IntegralType) =>
283+
b.makeCopy(Array(b.left, Cast(b.right, DecimalType.forType(t))))
284+
case (t, DecimalType.Fixed(_, _)) if isFloat(t) =>
285+
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
286+
case (DecimalType.Fixed(_, _), t) if isFloat(t) =>
287+
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
288+
}
289+
259290
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ object Literal {
5858
case s: Short => Literal(s, ShortType)
5959
case s: String => Literal(UTF8String.fromString(s), StringType)
6060
case b: Boolean => Literal(b, BooleanType)
61-
case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale))
61+
case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
6262
case d: JavaBigDecimal =>
6363
Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
6464
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag
2323

2424
import org.apache.spark.annotation.InterfaceStability
2525
import org.apache.spark.sql.AnalysisException
26-
import org.apache.spark.sql.catalyst.expressions.Expression
26+
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
2727

2828

2929
/**
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
117117
val MAX_SCALE = 38
118118
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
119119
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
120+
val MINIMUM_ADJUSTED_SCALE = 6
120121

121122
// The decimal types compatible with other numeric types
122123
private[sql] val ByteDecimal = DecimalType(3, 0)
@@ -136,10 +137,54 @@ object DecimalType extends AbstractDataType {
136137
case DoubleType => DoubleDecimal
137138
}
138139

140+
private[sql] def forLiteral(literal: Literal): DecimalType = literal.value match {
141+
case v: Short => fromBigDecimal(BigDecimal(v))
142+
case v: Int => fromBigDecimal(BigDecimal(v))
143+
case v: Long => fromBigDecimal(BigDecimal(v))
144+
case _ => forType(literal.dataType)
145+
}
146+
147+
private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = {
148+
DecimalType(Math.max(d.precision, d.scale), d.scale)
149+
}
150+
139151
private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
140152
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
141153
}
142154

155+
// scalastyle:off line.size.limit
156+
/**
157+
* Decimal implementation is based on Hive's one, which is itself inspired by SQL Server's.
158+
* In particular, when a result precision is greater than {@link #MAX_PRECISION}, the
159+
* corresponding scale is reduced to prevent the integral part of a result from being truncated.
160+
*
161+
* For further reference, please see
162+
* https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/.
163+
*
164+
* @param precision
165+
* @param scale
166+
* @return
167+
*/
168+
// scalastyle:on line.size.limit
169+
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
170+
// Assumptions:
171+
// precision >= scale
172+
// scale >= 0
173+
if (precision <= MAX_PRECISION) {
174+
// Adjustment only needed when we exceed max precision
175+
DecimalType(precision, scale)
176+
} else {
177+
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
178+
val intDigits = precision - scale
179+
// If original scale less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
180+
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
181+
val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
182+
val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)
183+
184+
DecimalType(MAX_PRECISION, adjustedScale)
185+
}
186+
}
187+
143188
override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT
144189

145190
override private[sql] def acceptsType(other: DataType): Boolean = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
408408
assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
409409
assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
410410
assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
411-
assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
412-
assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
411+
assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11))
412+
assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6))
413413
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
414414
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
415415
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
136136

137137
test("maximum decimals") {
138138
for (expr <- Seq(d1, d2, i, u)) {
139-
checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
140-
checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
139+
checkType(Add(expr, u), DecimalType(38, 17))
140+
checkType(Subtract(expr, u), DecimalType(38, 17))
141141
}
142142

143-
checkType(Multiply(d1, u), DecimalType(38, 19))
144-
checkType(Multiply(d2, u), DecimalType(38, 20))
145-
checkType(Multiply(i, u), DecimalType(38, 18))
146-
checkType(Multiply(u, u), DecimalType(38, 36))
143+
checkType(Multiply(d1, u), DecimalType(38, 16))
144+
checkType(Multiply(d2, u), DecimalType(38, 14))
145+
checkType(Multiply(i, u), DecimalType(38, 7))
146+
checkType(Multiply(u, u), DecimalType(38, 6))
147147

148-
checkType(Divide(u, d1), DecimalType(38, 18))
149-
checkType(Divide(u, d2), DecimalType(38, 19))
150-
checkType(Divide(u, i), DecimalType(38, 23))
151-
checkType(Divide(u, u), DecimalType(38, 18))
148+
checkType(Divide(u, d1), DecimalType(38, 17))
149+
checkType(Divide(u, d2), DecimalType(38, 16))
150+
checkType(Divide(u, i), DecimalType(38, 18))
151+
checkType(Divide(u, u), DecimalType(38, 6))
152152

153153
checkType(Remainder(d1, u), DecimalType(19, 18))
154154
checkType(Remainder(d2, u), DecimalType(21, 18))
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
-- tests for decimals handling in operations
2+
-- Spark draws its inspiration from the Hive implementation
3+
create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet;
4+
5+
insert into decimals_test values(1, 100.0, 999.0);
6+
insert into decimals_test values(2, 12345.123, 12345.123);
7+
insert into decimals_test values(3, 0.1234567891011, 1234.1);
8+
insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789);
9+
10+
-- test decimal operations
11+
select id, a+b, a-b, a*b, a/b from decimals_test order by id;
12+
13+
-- test operations between decimals and constants
14+
select id, a*10, b/10 from decimals_test order by id;
15+
16+
drop table decimals_test;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- Number of queries: 8
3+
4+
5+
-- !query 0
6+
create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet
7+
-- !query 0 schema
8+
struct<>
9+
-- !query 0 output
10+
11+
12+
13+
-- !query 1
14+
insert into decimals_test values(1, 100.0, 999.0)
15+
-- !query 1 schema
16+
struct<>
17+
-- !query 1 output
18+
19+
20+
21+
-- !query 2
22+
insert into decimals_test values(2, 12345.123, 12345.123)
23+
-- !query 2 schema
24+
struct<>
25+
-- !query 2 output
26+
27+
28+
29+
-- !query 3
30+
insert into decimals_test values(3, 0.1234567891011, 1234.1)
31+
-- !query 3 schema
32+
struct<>
33+
-- !query 3 output
34+
35+
36+
37+
-- !query 4
38+
insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789)
39+
-- !query 4 schema
40+
struct<>
41+
-- !query 4 output
42+
43+
44+
45+
-- !query 5
46+
select id, a+b, a-b, a*b, a/b from decimals_test order by id
47+
-- !query 5 schema
48+
struct<id:int,(a + b):decimal(38,17),(a - b):decimal(38,17),(a * b):decimal(38,6),(a / b):decimal(38,6)>
49+
-- !query 5 output
50+
1 1099 -899 99900 0.1001
51+
2 24690.246 0 152402061.885129 1
52+
3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001
53+
4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109
54+
55+
56+
-- !query 6
57+
select id, a*10, b/10 from decimals_test order by id
58+
-- !query 6 schema
59+
struct<id:int,(CAST(a AS DECIMAL(38,18)) * CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,15),(CAST(b AS DECIMAL(38,18)) / CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,18)>
60+
-- !query 6 output
61+
1 1000 99.9
62+
2 123451.23 1234.5123
63+
3 1.234567891011 123.41
64+
4 1234567891234567890 0.112345678912345679
65+
66+
67+
-- !query 7
68+
drop table decimals_test
69+
-- !query 7 schema
70+
struct<>
71+
-- !query 7 output
72+

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,15 +1526,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
15261526
checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
15271527
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
15281528
checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
1529-
Row(null))
1529+
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
15301530

15311531
checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
15321532
checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
15331533
checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
15341534
checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
1535-
Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
1535+
Row(BigDecimal("3.4333333333333333333", new MathContext(38))))
15361536
checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
1537-
Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
1537+
Row(BigDecimal("3.4333333333333333333", new MathContext(38))))
15381538
}
15391539

15401540
test("SPARK-10215 Div of Decimal returns null") {

0 commit comments

Comments
 (0)