Commit 821db48

[SPARK-26233][SQL] CheckOverflow when encoding a decimal value
When we encode a Decimal from an external source, we don't check for overflow. That check is useful not only to enforce that the value can be represented in the specified range, but also because it adjusts the underlying data to the right precision/scale. Since our code generation assumes that a decimal has exactly the precision and scale of its data type, failing to enforce this can lead to corrupted output/results in subsequent transformations.

How was this patch tested: added UT.

Closes #23210 from mgaido91/SPARK-26233.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
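The core idea of the fix can be sketched outside Spark: enforcing a declared (precision, scale) both rescales the underlying value and rejects values that overflow the precision. This is a minimal stand-alone sketch; `checkOverflow` is a hypothetical stand-in for Spark's `CheckOverflow` expression, not its actual implementation (Spark's rounding and null-on-overflow behavior differ in detail).

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object DecimalOverflowSketch {
  // Hypothetical analogue of CheckOverflow: force the value to the declared
  // scale, then reject it (None) if the result exceeds the declared precision.
  def checkOverflow(value: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
    val adjusted = value.setScale(scale, RoundingMode.HALF_UP)
    if (adjusted.precision > precision) None else Some(adjusted)
  }

  def main(args: Array[String]): Unit = {
    // A value with scale 4 is rescaled to the declared scale 8.
    println(checkOverflow(new JBigDecimal("1.1111"), 38, 8))  // Some(1.11110000)
    // A value that does not fit in precision 3 overflows.
    println(checkOverflow(new JBigDecimal("123.45"), 3, 2))   // None
  }
}
```

Note how the first call changes the stored representation (`1.1111` becomes `1.11110000`): without this step, downstream code that assumes the data already matches the type's scale would misread the value, which is exactly the corruption this commit prevents.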
1 parent c9fd14c commit 821db48

2 files changed

Lines changed: 11 additions & 2 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 2 additions & 2 deletions
@@ -108,12 +108,12 @@ object RowEncoder {
         returnNullable = false)
 
     case d: DecimalType =>
-      StaticInvoke(
+      CheckOverflow(StaticInvoke(
         Decimal.getClass,
         d,
         "fromDecimal",
         inputObject :: Nil,
-        returnNullable = false)
+        returnNullable = false), d)
 
     case StringType =>
       StaticInvoke(

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -1547,6 +1547,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.where($"city".contains(new java.lang.Character('A'))),
       Seq(Row("Amsterdam")))
   }
+
+  test("SPARK-26233: serializer should enforce decimal precision and scale") {
+    val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8))))
+    val encoder = RowEncoder(s)
+    implicit val uEnc = encoder
+    val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111)))
+    checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
+      Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
