Skip to content

Commit 921c22b

Browse files
aokolnychyi authored and dbtsai committed
[SPARK-26706][SQL] Fix illegalNumericPrecedence for ByteType
This PR contains a minor change in `Cast$mayTruncate` that fixes its logic for bytes. Right now, `mayTruncate(ByteType, LongType)` returns `false` while `mayTruncate(ShortType, LongType)` returns `true`. Consequently, `spark.range(1, 3).as[Byte]` and `spark.range(1, 3).as[Short]` behave differently. Potentially, this bug can silently corrupt someone's data. ```scala // executes silently even though Long is converted into Byte spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte] .map(b => b - 1) .show() +-----+ |value| +-----+ | -12| | -11| | -10| | -9| | -8| | -7| | -6| | -5| | -4| | -3| +-----+ // throws an AnalysisException: Cannot up cast `id` from bigint to smallint as it may truncate spark.range(Long.MaxValue - 10, Long.MaxValue).as[Short] .map(s => s - 1) .show() ``` This PR comes with a set of unit tests. Closes apache#23632 from aokolnychyi/cast-fix. Authored-by: Anton Okolnychyi <[email protected]> Signed-off-by: DB Tsai <[email protected]>
1 parent f36d0c5 commit 921c22b

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ object Cast {
131131
private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
132132
val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
133133
val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
134-
toPrecedence > 0 && fromPrecedence > toPrecedence
134+
toPrecedence >= 0 && fromPrecedence > toPrecedence
135135
}
136136

137137
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
2323
import org.apache.spark.SparkFunSuite
2424
import org.apache.spark.sql.Row
2525
import org.apache.spark.sql.catalyst.InternalRow
26+
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
2627
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2728
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
2829
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -953,4 +954,39 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
953954
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
954955
checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
955956
}
957+
958+
test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
959+
assert(!Cast.mayTruncate(ByteType, ByteType))
960+
assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
961+
assert(Cast.mayTruncate(ShortType, ByteType))
962+
assert(Cast.mayTruncate(IntegerType, ByteType))
963+
assert(Cast.mayTruncate(LongType, ByteType))
964+
assert(Cast.mayTruncate(FloatType, ByteType))
965+
assert(Cast.mayTruncate(DoubleType, ByteType))
966+
assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
967+
}
968+
969+
test("canSafeCast and mayTruncate must be consistent for numeric types") {
970+
import DataTypeTestUtils._
971+
972+
def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
973+
case (_, dt: DecimalType) => dt.isWiderThan(from)
974+
case (dt: DecimalType, _) => dt.isTighterThan(to)
975+
case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
976+
}
977+
978+
numericTypes.foreach { from =>
979+
val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
980+
981+
safeTargetTypes.foreach { to =>
982+
assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
983+
assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
984+
}
985+
986+
unsafeTargetTypes.foreach { to =>
987+
assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
988+
assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
989+
}
990+
}
991+
}
956992
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1567,6 +1567,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
15671567
val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
15681568
checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
15691569
}
1570+
1571+
test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
1572+
val thrownException = intercept[AnalysisException] {
1573+
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
1574+
.map(b => b - 1)
1575+
.collect()
1576+
}
1577+
assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
1578+
}
15701579
}
15711580

15721581
case class TestDataUnion(x: Int, y: Int, z: Int)

0 commit comments

Comments (0)