Skip to content

Commit 0d268e9

Browse files
committed
[SPARK-26706][SQL] Fix Cast$mayTruncate for bytes
1 parent f92d276 commit 0d268e9

3 files changed

Lines changed: 46 additions & 1 deletion

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ object Cast {
131131
private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
132132
val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
133133
val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
134-
toPrecedence > 0 && fromPrecedence > toPrecedence
134+
toPrecedence >= 0 && fromPrecedence > toPrecedence
135135
}
136136

137137
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.util.Random
2525
import org.apache.spark.SparkFunSuite
2626
import org.apache.spark.sql.Row
2727
import org.apache.spark.sql.catalyst.InternalRow
28+
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
2829
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2930
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
3031
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -955,4 +956,39 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
955956
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
956957
checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
957958
}
959+
960+
test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
961+
assert(!Cast.mayTruncate(ByteType, ByteType))
962+
assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
963+
assert(Cast.mayTruncate(ShortType, ByteType))
964+
assert(Cast.mayTruncate(IntegerType, ByteType))
965+
assert(Cast.mayTruncate(LongType, ByteType))
966+
assert(Cast.mayTruncate(FloatType, ByteType))
967+
assert(Cast.mayTruncate(DoubleType, ByteType))
968+
assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
969+
}
970+
971+
test("canSafeCast and mayTruncate must be consistent for numeric types") {
972+
import DataTypeTestUtils._
973+
974+
def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
975+
case (_, dt: DecimalType) => dt.isWiderThan(from)
976+
case (dt: DecimalType, _) => dt.isTighterThan(to)
977+
case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
978+
}
979+
980+
numericTypes.foreach { from =>
981+
val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
982+
983+
safeTargetTypes.foreach { to =>
984+
assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
985+
assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
986+
}
987+
988+
unsafeTargetTypes.foreach { to =>
989+
assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
990+
assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
991+
}
992+
}
993+
}
958994
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
16781678
assert(serializer.serializer.size == 1)
16791679
checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
16801680
}
1681+
1682+
test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
1683+
val thrownException = intercept[AnalysisException] {
1684+
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
1685+
.map(b => b - 1)
1686+
.collect()
1687+
}
1688+
assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
1689+
}
16811690
}
16821691

16831692
case class TestDataUnion(x: Int, y: Int, z: Int)

0 commit comments

Comments
 (0)