-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-30292][SQL] Throw Exception when invalid string is cast to numeric type in ANSI mode #26933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
3ed7795
1686c6a
69ee231
74809d0
a336084
c0f8baf
f46181d
d3ffa3c
c7dbeef
7d0faa6
d454452
4b0149c
40afc54
2f845c3
0cb4edc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -482,6 +482,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // LongConverter | ||
| private[this] def castToLong(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| val result = new LongWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toLong(result)) result.value | ||
iRakson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")) | ||
| case StringType => | ||
| val result = new LongWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) | ||
|
|
@@ -499,6 +503,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // IntConverter | ||
| private[this] def castToInt(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| val result = new IntWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toInt(result)) result.value | ||
| else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")) | ||
iRakson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| case StringType => | ||
| val result = new IntWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) | ||
|
|
@@ -523,7 +531,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| buildCast[UTF8String](_, s => if (s.toShort(result)) { | ||
| result.value.toShort | ||
| } else { | ||
| null | ||
| if (ansiEnabled) { | ||
| throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| null | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) | ||
|
|
@@ -564,7 +576,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| buildCast[UTF8String](_, s => if (s.toByte(result)) { | ||
| result.value.toByte | ||
| } else { | ||
| null | ||
| if (ansiEnabled) { | ||
| throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| null | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) | ||
|
|
@@ -636,7 +652,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| // Please refer to https://github.com/apache/spark/pull/26640 | ||
| changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target) | ||
| } catch { | ||
| case _: NumberFormatException => null | ||
| case _: NumberFormatException => | ||
| if (ansiEnabled) { | ||
| throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| null | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) | ||
|
|
@@ -659,6 +680,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // DoubleConverter | ||
| private[this] def castToDouble(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, s => { | ||
| val doubleStr = s.toString | ||
| try doubleStr.toDouble catch { | ||
| case _: NumberFormatException => | ||
| val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) | ||
| if(d == null) { | ||
| throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| d.asInstanceOf[Double].doubleValue() | ||
| } | ||
| } | ||
| }) | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => { | ||
| val doubleStr = s.toString | ||
|
|
@@ -679,6 +713,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // FloatConverter | ||
| private[this] def castToFloat(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, s => { | ||
| val floatStr = s.toString | ||
| try floatStr.toFloat catch { | ||
| case _: NumberFormatException => | ||
| val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) | ||
| if (f == null) { | ||
|
||
| throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| f.asInstanceOf[Float].floatValue() | ||
| } | ||
| } | ||
| }) | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => { | ||
| val floatStr = s.toString | ||
|
|
@@ -1133,7 +1180,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim())); | ||
| ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} | ||
| } catch (java.lang.NumberFormatException e) { | ||
| $evNull = true; | ||
| if ($ansiEnabled) { | ||
cloud-fan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evNull =true; | ||
| } | ||
| } | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1354,7 +1405,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| """ | ||
| } | ||
|
|
||
| private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { | ||
| private[this] def castToByteCode( | ||
| from: DataType, | ||
| ctx: CodegenContext): CastFunction = from match { | ||
| case StringType => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
|
|
@@ -1363,7 +1416,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| if ($c.toByte($wrapper)) { | ||
| $evPrim = (byte) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| if ($ansiEnabled) { | ||
iRakson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| } | ||
| $wrapper = null; | ||
| """ | ||
|
|
@@ -1394,7 +1451,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| if ($c.toShort($wrapper)) { | ||
| $evPrim = (short) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| if ($ansiEnabled) { | ||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| } | ||
| $wrapper = null; | ||
| """ | ||
|
|
@@ -1414,7 +1475,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| (c, evPrim, evNull) => code"$evPrim = (short) $c;" | ||
| } | ||
|
|
||
| private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { | ||
| private[this] def castToIntCode( | ||
| from: DataType, | ||
| ctx: CodegenContext): CastFunction = from match { | ||
| case StringType if ansiEnabled => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); | ||
| if ($c.toInt($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } | ||
| $wrapper = null; | ||
| """ | ||
| case StringType => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
|
|
@@ -1442,7 +1517,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| (c, evPrim, evNull) => code"$evPrim = (int) $c;" | ||
| } | ||
|
|
||
| private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { | ||
| private[this] def castToLongCode( | ||
| from: DataType, | ||
| ctx: CodegenContext): CastFunction = from match { | ||
| case StringType => | ||
| val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) | ||
|
|
||
|
|
@@ -1452,7 +1529,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| if ($c.toLong($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| if ($ansiEnabled) { | ||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| } | ||
| $wrapper = null; | ||
| """ | ||
|
|
@@ -1473,6 +1554,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType if ansiEnabled => | ||
| val floatStr = ctx.freshVariable("floatStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $floatStr = $c.toString(); | ||
| try { | ||
| $evPrim = Float.valueOf($floatStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); | ||
| if (f == null) { | ||
|
||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evPrim = f.floatValue(); | ||
| } | ||
| } | ||
| """ | ||
| case StringType => | ||
| val floatStr = ctx.freshVariable("floatStr", StringType) | ||
| (c, evPrim, evNull) => | ||
|
|
@@ -1504,6 +1601,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType if ansiEnabled => | ||
| val doubleStr = ctx.freshVariable("doubleStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $doubleStr = $c.toString(); | ||
| try { | ||
| $evPrim = Double.valueOf($doubleStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); | ||
| if (d == null) { | ||
| throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evPrim = d.doubleValue(); | ||
| } | ||
| } | ||
| """ | ||
| case StringType => | ||
| val doubleStr = ctx.freshVariable("doubleStr", StringType) | ||
| (c, evPrim, evNull) => | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.