Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
*/
def canCast(from: DataType, to: DataType): Boolean

/**
 * Returns the error message if casting from one type to another one is invalid.
 *
 * Implementations build the text from the child's data type and the target data type.
 * The input type check wraps this message in a `TypeCheckFailure` when `canCast`
 * rejects the conversion.
 */
def typeCheckFailureMessage: String

override def toString: String = {
val ansi = if (ansiEnabled) "ansi_" else ""
s"${ansi}cast($child as ${dataType.simpleString})"
Expand All @@ -271,8 +276,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
if (canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}")
TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
}
}

Expand Down Expand Up @@ -1755,6 +1759,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
} else {
Cast.canCast(from, to)
}

/**
 * Error message reported when this cast fails the type check.
 *
 * With ANSI mode off the plain "cannot cast" text is used. With ANSI mode on the message is
 * delegated to [[AnsiCast.typeCheckFailureMessage]], passing the ANSI-enabled configuration
 * key with the value "false" as the fallback setting.
 */
override def typeCheckFailureMessage: String = {
  val sourceType = child.dataType
  if (!ansiEnabled) {
    s"cannot cast ${sourceType.catalogString} to ${dataType.catalogString}"
  } else {
    AnsiCast.typeCheckFailureMessage(sourceType, dataType, SQLConf.ANSI_ENABLED.key, "false")
  }
}
}

/**
Expand All @@ -1774,6 +1784,14 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St
// The ANSI flag is fixed to true for this expression, regardless of any other setting.
override protected val ansiEnabled: Boolean = true

// Castability is decided entirely by the rules defined on the AnsiCast object.
override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to)

// This expression is currently only used for table insertion, which is why the message
// points at the store assignment policy as the fallback configuration.
// If this expression gets used in more scenarios, the type check failure message should be
// updated accordingly.
override def typeCheckFailureMessage: String = {
  val fallbackKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
  val fallbackValue = SQLConf.StoreAssignmentPolicy.LEGACY.toString
  AnsiCast.typeCheckFailureMessage(child.dataType, dataType, fallbackKey, fallbackValue)
}

}

object AnsiCast {
Expand Down Expand Up @@ -1876,6 +1894,35 @@ object AnsiCast {

case _ => false
}

/**
 * Builds the type check failure message for a cast from `from` to `to`.
 *
 * The message depends on the pair of types:
 *  - numeric to timestamp: suggests the TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS
 *    functions;
 *  - array to string: suggests ARRAY_JOIN or the fallback configuration;
 *  - any other pair accepted by `Cast.canCast`: suggests the fallback configuration;
 *  - everything else: the plain "cannot cast" text.
 *
 * @param from the data type being cast from
 * @param to the data type being cast to
 * @param fallbackConfKey configuration key named in the hint
 * @param fallbackConfValue configuration value named in the hint
 * @return the message text
 */
def typeCheckFailureMessage(
    from: DataType,
    to: DataType,
    fallbackConfKey: String,
    fallbackConfValue: String): String = {
  val source = from.catalogString
  val target = to.catalogString

  // scalastyle:off line.size.limit
  val message = (from, to) match {
    case (_: NumericType, TimestampType) =>
      s"""
         | cannot cast ${source} to ${target}.
         | To convert values from ${source} to ${target}, you can use functions TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS instead.
         |""".stripMargin

    case (_: ArrayType, StringType) =>
      s"""
         | cannot cast ${source} to ${target} with ANSI mode on.
         | If you have to cast ${source} to ${target}, you can use the function ARRAY_JOIN or set ${fallbackConfKey} as ${fallbackConfValue}.
         |""".stripMargin

    // The non-ANSI rules accept this pair, so pointing at the fallback configuration helps.
    case _ if Cast.canCast(from, to) =>
      s"""
         | cannot cast ${source} to ${target} with ANSI mode on.
         | If you have to cast ${source} to ${target}, you can set ${fallbackConfKey} as ${fallbackConfValue}.
         |""".stripMargin

    case _ =>
      s"cannot cast ${source} to ${target}"
  }
  // scalastyle:on line.size.limit

  message
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.parallel.immutable.ParVector
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
Expand Down Expand Up @@ -841,12 +842,28 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}

/**
 * Text that the type check failure message is expected to contain when `verifyCastFailure`
 * is called without an explicit expected message. Supplied by each concrete suite.
 */
protected def setConfigurationHint: String

/**
 * Asserts that type checking `c` fails and that the failure message has the expected content.
 *
 * @param c the cast expression whose input type check is expected to fail
 * @param optionalExpectedMsg when defined, the failure message must contain this text;
 *   otherwise it must contain both "with ANSI mode on" and `setConfigurationHint`
 */
private def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = {
  // Match on the result type instead of isInstanceOf/asInstanceOf so that an unexpected
  // result is reported with its actual value.
  val message = c.checkInputDataTypes() match {
    case failure: TypeCheckFailure => failure.message
    case other => fail(s"Expected the type check of $c to fail, but got: $other")
  }

  val expectedFragments = optionalExpectedMsg
    .map(Seq(_))
    .getOrElse(Seq("with ANSI mode on", setConfigurationHint))

  expectedFragments.foreach { fragment =>
    // The clue shows the real message when an expectation is not met.
    assert(message.contains(fragment), s"Missing '$fragment' in: $message")
  }
}

test("ANSI mode: disallow type conversions between Numeric types and Timestamp type") {
  import DataTypeTestUtils.numericTypes
  checkInvalidCastFromNumericType(TimestampType)
  val timestampLiteral = Literal(1L, TimestampType)
  // verifyCastFailure already asserts that the type check fails, so a separate bare
  // isFailure assertion on the same cast is not needed.
  numericTypes.foreach { numericType =>
    verifyCastFailure(cast(timestampLiteral, numericType))
  }
}

Expand All @@ -855,7 +872,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
checkInvalidCastFromNumericType(DateType)
val dateLiteral = Literal(1, DateType)
numericTypes.foreach { numericType =>
assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure)
verifyCastFailure(cast(dateLiteral, numericType))
}
}

Expand All @@ -880,9 +897,9 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
}

test("ANSI mode: disallow casting complex types as String type") {
  // verifyCastFailure asserts the type check failure itself and also checks the message,
  // so each cast is verified exactly once.
  verifyCastFailure(cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType))
  verifyCastFailure(cast(Literal.create(Map(1 -> "a")), StringType))
  verifyCastFailure(cast(Literal.create((1, "a", 0.1)), StringType))
}

test("cast from invalid string to numeric should throw NumberFormatException") {
Expand Down Expand Up @@ -1489,6 +1506,9 @@ class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
case _ => Cast(Literal(v), targetType, timeZoneId)
}
}

// Hint expected in the failure message: the ANSI-enabled configuration key set to false.
override def setConfigurationHint: String = {
  val ansiKey = SQLConf.ANSI_ENABLED.key
  s"set $ansiKey as false"
}
}

/**
Expand All @@ -1511,6 +1531,10 @@ class AnsiCastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}

// Hint expected in the failure message: the store assignment policy key set to LEGACY.
override def setConfigurationHint: String = {
  val policyKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
  val legacyPolicy = SQLConf.StoreAssignmentPolicy.LEGACY.toString
  s"set $policyKey as $legacyPolicy"
}
}

/**
Expand All @@ -1533,4 +1557,8 @@ class AnsiCastSuiteWithAnsiModeOff extends AnsiCastSuiteBase {
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}

// Hint expected in the failure message: the store assignment policy key set to LEGACY.
override def setConfigurationHint: String = {
  val policyKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
  val legacyPolicy = SQLConf.StoreAssignmentPolicy.LEGACY.toString
  s"set $policyKey as $legacyPolicy"
}
}