Skip to content

Commit e0efd21

Browse files
iRakson and cloud-fan
authored and committed
[SPARK-30292][SQL] Throw Exception when invalid string is cast to numeric type in ANSI mode
### What changes were proposed in this pull request? If `spark.sql.ansi.enabled` is set, throw an exception when a string cast to any numeric type does not follow the ANSI SQL standard. ### Why are the changes needed? The ANSI SQL standard does not allow invalid strings to be cast to numeric types and requires an exception in that case. Currently Spark SQL returns NULL in such cases. Before: `select cast('str' as decimal) => NULL` After: `select cast('str' as decimal) => invalid input syntax for type numeric: str` These results are after setting `spark.sql.ansi.enabled=true`. ### Does this PR introduce any user-facing change? Yes. When ANSI mode is on, users now get a NumberFormatException for invalid strings. ### How was this patch tested? Unit tests added. Closes #26933 from iRakson/castDecimalANSI. Lead-authored-by: root1 <[email protected]> Co-authored-by: iRakson <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 88fc8db commit e0efd21

8 files changed

Lines changed: 278 additions & 118 deletions

File tree

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,52 @@ public boolean toByte(IntWrapper intWrapper) {
12941294
return false;
12951295
}
12961296

1297+
/**
1298+
* Parses this UTF8String (trimmed if needed) to a long. This method is used when ANSI is enabled.
1299+
*
1300+
* @return the long value produced by {@code toLong(LongWrapper)} when it reports success
1301+
* @throws NumberFormatException if {@code toLong(LongWrapper)} reports failure
1302+
*/
1303+
public long toLongExact() {
1304+
LongWrapper result = new LongWrapper();
1305+
if (toLong(result)) {
1306+
return result.value;
1307+
}
1308+
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
1309+
}
1310+
1311+
/**
1312+
* Parses this UTF8String (trimmed if needed) to an int. This method is used when ANSI is enabled.
1313+
*
1314+
* @return the int value produced by {@code toInt(IntWrapper)} when it reports success
1315+
* @throws NumberFormatException if {@code toInt(IntWrapper)} reports failure
1316+
*/
1317+
public int toIntExact() {
1318+
IntWrapper result = new IntWrapper();
1319+
if (toInt(result)) {
1320+
return result.value;
1321+
}
1322+
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
1323+
}
1324+
1325+
/**
 * Parses this UTF8String to a short by first parsing it with {@code toIntExact()} and then
 * checking that the value survives a narrowing cast to short unchanged.
 *
 * @return the parsed value as a short
 * @throws NumberFormatException if {@code toIntExact()} throws, or if the parsed int does not
 *         round-trip through the short cast (i.e. it is outside the short range)
 */
public short toShortExact() {
1326+
int value = this.toIntExact();
1327+
short result = (short) value;
1328+
// The cast truncates silently, so equality with the original int is the range check.
if (result == value) {
1329+
return result;
1330+
}
1331+
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
1332+
}
1333+
1334+
/**
 * Parses this UTF8String to a byte by first parsing it with {@code toIntExact()} and then
 * checking that the value survives a narrowing cast to byte unchanged.
 *
 * @return the parsed value as a byte
 * @throws NumberFormatException if {@code toIntExact()} throws, or if the parsed int does not
 *         round-trip through the byte cast (i.e. it is outside the byte range)
 */
public byte toByteExact() {
1335+
int value = this.toIntExact();
1336+
byte result = (byte) value;
1337+
// The cast truncates silently, so equality with the original int is the range check.
if (result == value) {
1338+
return result;
1339+
}
1340+
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
1341+
}
1342+
12971343
@Override
12981344
public String toString() {
12991345
return new String(getBytes(), StandardCharsets.UTF_8);

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
482482

483483
// LongConverter
484484
private[this] def castToLong(from: DataType): Any => Any = from match {
485+
case StringType if ansiEnabled =>
486+
buildCast[UTF8String](_, _.toLongExact())
485487
case StringType =>
486488
val result = new LongWrapper()
487489
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
@@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
499501

500502
// IntConverter
501503
private[this] def castToInt(from: DataType): Any => Any = from match {
504+
case StringType if ansiEnabled =>
505+
buildCast[UTF8String](_, _.toIntExact())
502506
case StringType =>
503507
val result = new IntWrapper()
504508
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
@@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
518522

519523
// ShortConverter
520524
private[this] def castToShort(from: DataType): Any => Any = from match {
525+
case StringType if ansiEnabled =>
526+
buildCast[UTF8String](_, _.toShortExact())
521527
case StringType =>
522528
val result = new IntWrapper()
523529
buildCast[UTF8String](_, s => if (s.toShort(result)) {
@@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
559565

560566
// ByteConverter
561567
private[this] def castToByte(from: DataType): Any => Any = from match {
568+
case StringType if ansiEnabled =>
569+
buildCast[UTF8String](_, _.toByteExact())
562570
case StringType =>
563571
val result = new IntWrapper()
564572
buildCast[UTF8String](_, s => if (s.toByte(result)) {
@@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
636644
// Please refer to https://github.com/apache/spark/pull/26640
637645
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
638646
} catch {
639-
case _: NumberFormatException => null
647+
case _: NumberFormatException =>
648+
if (ansiEnabled) {
649+
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
650+
} else {
651+
null
652+
}
640653
})
641654
case BooleanType =>
642655
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
@@ -664,7 +677,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
664677
val doubleStr = s.toString
665678
try doubleStr.toDouble catch {
666679
case _: NumberFormatException =>
667-
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
680+
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
681+
if(ansiEnabled && d == null) {
682+
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
683+
} else {
684+
d
685+
}
668686
}
669687
})
670688
case BooleanType =>
@@ -684,7 +702,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
684702
val floatStr = s.toString
685703
try floatStr.toFloat catch {
686704
case _: NumberFormatException =>
687-
Cast.processFloatingPointSpecialLiterals(floatStr, true)
705+
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
706+
if (ansiEnabled && f == null) {
707+
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
708+
} else {
709+
f
710+
}
688711
}
689712
})
690713
case BooleanType =>
@@ -1128,12 +1151,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
11281151
from match {
11291152
case StringType =>
11301153
(c, evPrim, evNull) =>
1154+
// Java statement emitted into the generated catch block. `$c` expands to the generated Java
// expression for the input string, so it is concatenated at runtime (`" + $c`) to report the
// actual offending value instead of embedding the expression text inside the message literal.
val handleException = if (ansiEnabled) {
1155+
s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
1156+
} else {
1157+
s"$evNull =true;"
1158+
}
11311159
code"""
11321160
try {
11331161
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
11341162
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
11351163
} catch (java.lang.NumberFormatException e) {
1136-
$evNull = true;
1164+
$handleException
11371165
}
11381166
"""
11391167
case BooleanType =>
@@ -1355,6 +1383,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13551383
}
13561384

13571385
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
1386+
case StringType if ansiEnabled =>
1387+
(c, evPrim, evNull) => code"$evPrim = $c.toByteExact();"
13581388
case StringType =>
13591389
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
13601390
(c, evPrim, evNull) =>
@@ -1386,6 +1416,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13861416
private[this] def castToShortCode(
13871417
from: DataType,
13881418
ctx: CodegenContext): CastFunction = from match {
1419+
case StringType if ansiEnabled =>
1420+
(c, evPrim, evNull) => code"$evPrim = $c.toShortExact();"
13891421
case StringType =>
13901422
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
13911423
(c, evPrim, evNull) =>
@@ -1415,6 +1447,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14151447
}
14161448

14171449
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
1450+
case StringType if ansiEnabled =>
1451+
(c, evPrim, evNull) => code"$evPrim = $c.toIntExact();"
14181452
case StringType =>
14191453
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
14201454
(c, evPrim, evNull) =>
@@ -1443,9 +1477,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14431477
}
14441478

14451479
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
1480+
case StringType if ansiEnabled =>
1481+
(c, evPrim, evNull) => code"$evPrim = $c.toLongExact();"
14461482
case StringType =>
14471483
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
1448-
14491484
(c, evPrim, evNull) =>
14501485
code"""
14511486
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
@@ -1476,14 +1511,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14761511
case StringType =>
14771512
val floatStr = ctx.freshVariable("floatStr", StringType)
14781513
(c, evPrim, evNull) =>
1514+
// Java statement emitted where the special-literal lookup yields null. `$c` expands to the
// generated Java expression for the input string, so it is concatenated at runtime (`" + $c`)
// to report the actual offending value instead of the expression text.
val handleNull = if (ansiEnabled) {
1515+
s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
1516+
} else {
1517+
s"$evNull = true;"
1518+
}
14791519
code"""
14801520
final String $floatStr = $c.toString();
14811521
try {
14821522
$evPrim = Float.valueOf($floatStr);
14831523
} catch (java.lang.NumberFormatException e) {
14841524
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
14851525
if (f == null) {
1486-
$evNull = true;
1526+
$handleNull
14871527
} else {
14881528
$evPrim = f.floatValue();
14891529
}
@@ -1507,14 +1547,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
15071547
case StringType =>
15081548
val doubleStr = ctx.freshVariable("doubleStr", StringType)
15091549
(c, evPrim, evNull) =>
1550+
// Java statement emitted where the special-literal lookup yields null. `$c` expands to the
// generated Java expression for the input string, so it is concatenated at runtime (`" + $c`)
// to report the actual offending value instead of the expression text.
val handleNull = if (ansiEnabled) {
1551+
s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
1552+
} else {
1553+
s"$evNull = true;"
1554+
}
15101555
code"""
15111556
final String $doubleStr = $c.toString();
15121557
try {
15131558
$evPrim = Double.valueOf($doubleStr);
15141559
} catch (java.lang.NumberFormatException e) {
15151560
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
15161561
if (d == null) {
1517-
$evNull = true;
1562+
$handleNull
15181563
} else {
15191564
$evPrim = d.doubleValue();
15201565
}

0 commit comments

Comments
 (0)