@@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
482482
483483 // LongConverter
484484 private [this ] def castToLong (from : DataType ): Any => Any = from match {
485+ case StringType if ansiEnabled =>
486+ buildCast[UTF8String ](_, _.toLongExact())
485487 case StringType =>
486488 val result = new LongWrapper ()
487489 buildCast[UTF8String ](_, s => if (s.toLong(result)) result.value else null )
@@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
499501
500502 // IntConverter
501503 private [this ] def castToInt (from : DataType ): Any => Any = from match {
504+ case StringType if ansiEnabled =>
505+ buildCast[UTF8String ](_, _.toIntExact())
502506 case StringType =>
503507 val result = new IntWrapper ()
504508 buildCast[UTF8String ](_, s => if (s.toInt(result)) result.value else null )
@@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
518522
519523 // ShortConverter
520524 private [this ] def castToShort (from : DataType ): Any => Any = from match {
525+ case StringType if ansiEnabled =>
526+ buildCast[UTF8String ](_, _.toShortExact())
521527 case StringType =>
522528 val result = new IntWrapper ()
523529 buildCast[UTF8String ](_, s => if (s.toShort(result)) {
@@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
559565
560566 // ByteConverter
561567 private [this ] def castToByte (from : DataType ): Any => Any = from match {
568+ case StringType if ansiEnabled =>
569+ buildCast[UTF8String ](_, _.toByteExact())
562570 case StringType =>
563571 val result = new IntWrapper ()
564572 buildCast[UTF8String ](_, s => if (s.toByte(result)) {
@@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
636644 // Please refer to https://github.com/apache/spark/pull/26640
637645 changePrecision(Decimal (new JavaBigDecimal (s.toString.trim)), target)
638646 } catch {
639- case _ : NumberFormatException => null
647+ case _ : NumberFormatException =>
648+ if (ansiEnabled) {
649+ throw new NumberFormatException (s " invalid input syntax for type numeric: $s" )
650+ } else {
651+ null
652+ }
640653 })
641654 case BooleanType =>
642655 buildCast[Boolean ](_, b => toPrecision(if (b) Decimal .ONE else Decimal .ZERO , target))
@@ -664,7 +677,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
664677 val doubleStr = s.toString
665678 try doubleStr.toDouble catch {
666679 case _ : NumberFormatException =>
667- Cast .processFloatingPointSpecialLiterals(doubleStr, false )
680+ val d = Cast .processFloatingPointSpecialLiterals(doubleStr, false )
681+ if (ansiEnabled && d == null ) {
682+ throw new NumberFormatException (s " invalid input syntax for type numeric: $s" )
683+ } else {
684+ d
685+ }
668686 }
669687 })
670688 case BooleanType =>
@@ -684,7 +702,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
684702 val floatStr = s.toString
685703 try floatStr.toFloat catch {
686704 case _ : NumberFormatException =>
687- Cast .processFloatingPointSpecialLiterals(floatStr, true )
705+ val f = Cast .processFloatingPointSpecialLiterals(floatStr, true )
706+ if (ansiEnabled && f == null ) {
707+ throw new NumberFormatException (s " invalid input syntax for type numeric: $s" )
708+ } else {
709+ f
710+ }
688711 }
689712 })
690713 case BooleanType =>
@@ -1128,12 +1151,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
11281151 from match {
11291152 case StringType =>
11301153 (c, evPrim, evNull) =>
// Java statement emitted into the generated catch block when the input string cannot be
// parsed as a decimal: throw when ANSI mode is enabled, otherwise mark the result as null.
// The offending value is appended with Java string concatenation so the message carries the
// runtime input, as the interpreted path does. Interpolating `$c` inside the Java string
// literal would embed the generated expression text (e.g. a variable name) instead.
val handleException = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull =true;"
}
11311159 code """
11321160 try {
11331161 Decimal $tmp = Decimal.apply(new java.math.BigDecimal( $c.toString().trim()));
11341162 ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
11351163 } catch (java.lang.NumberFormatException e) {
1136- $evNull = true;
1164+ $handleException
11371165 }
11381166 """
11391167 case BooleanType =>
@@ -1355,6 +1383,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13551383 }
13561384
13571385 private [this ] def castToByteCode (from : DataType , ctx : CodegenContext ): CastFunction = from match {
1386+ case StringType if ansiEnabled =>
1387+ (c, evPrim, evNull) => code " $evPrim = $c.toByteExact(); "
13581388 case StringType =>
13591389 val wrapper = ctx.freshVariable(" intWrapper" , classOf [UTF8String .IntWrapper ])
13601390 (c, evPrim, evNull) =>
@@ -1386,6 +1416,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13861416 private [this ] def castToShortCode (
13871417 from : DataType ,
13881418 ctx : CodegenContext ): CastFunction = from match {
1419+ case StringType if ansiEnabled =>
1420+ (c, evPrim, evNull) => code " $evPrim = $c.toShortExact(); "
13891421 case StringType =>
13901422 val wrapper = ctx.freshVariable(" intWrapper" , classOf [UTF8String .IntWrapper ])
13911423 (c, evPrim, evNull) =>
@@ -1415,6 +1447,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14151447 }
14161448
14171449 private [this ] def castToIntCode (from : DataType , ctx : CodegenContext ): CastFunction = from match {
1450+ case StringType if ansiEnabled =>
1451+ (c, evPrim, evNull) => code " $evPrim = $c.toIntExact(); "
14181452 case StringType =>
14191453 val wrapper = ctx.freshVariable(" intWrapper" , classOf [UTF8String .IntWrapper ])
14201454 (c, evPrim, evNull) =>
@@ -1443,9 +1477,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14431477 }
14441478
14451479 private [this ] def castToLongCode (from : DataType , ctx : CodegenContext ): CastFunction = from match {
1480+ case StringType if ansiEnabled =>
1481+ (c, evPrim, evNull) => code " $evPrim = $c.toLongExact(); "
14461482 case StringType =>
14471483 val wrapper = ctx.freshVariable(" longWrapper" , classOf [UTF8String .LongWrapper ])
1448-
14491484 (c, evPrim, evNull) =>
14501485 code """
14511486 UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
@@ -1476,14 +1511,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14761511 case StringType =>
14771512 val floatStr = ctx.freshVariable(" floatStr" , StringType )
14781513 (c, evPrim, evNull) =>
// Java statement emitted for the case where the string is neither a parsable float nor a
// recognized special literal: throw when ANSI mode is enabled, otherwise mark the result
// as null.
// The offending value is appended with Java string concatenation so the message carries the
// runtime input. Interpolating `$c` inside the Java string literal would embed the generated
// expression text (e.g. a variable name) instead of the value.
val handleNull = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull = true;"
}
14791519 code """
14801520 final String $floatStr = $c.toString();
14811521 try {
14821522 $evPrim = Float.valueOf( $floatStr);
14831523 } catch (java.lang.NumberFormatException e) {
14841524 final Float f = (Float) Cast.processFloatingPointSpecialLiterals( $floatStr, true);
14851525 if (f == null) {
1486- $evNull = true;
1526+ $handleNull
14871527 } else {
14881528 $evPrim = f.floatValue();
14891529 }
@@ -1507,14 +1547,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
15071547 case StringType =>
15081548 val doubleStr = ctx.freshVariable(" doubleStr" , StringType )
15091549 (c, evPrim, evNull) =>
// Java statement emitted for the case where the string is neither a parsable double nor a
// recognized special literal: throw when ANSI mode is enabled, otherwise mark the result
// as null.
// The offending value is appended with Java string concatenation so the message carries the
// runtime input. Interpolating `$c` inside the Java string literal would embed the generated
// expression text (e.g. a variable name) instead of the value.
val handleNull = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull = true;"
}
15101555 code """
15111556 final String $doubleStr = $c.toString();
15121557 try {
15131558 $evPrim = Double.valueOf( $doubleStr);
15141559 } catch (java.lang.NumberFormatException e) {
15151560 final Double d = (Double) Cast.processFloatingPointSpecialLiterals( $doubleStr, false);
15161561 if (d == null) {
1517- $evNull = true;
1562+ $handleNull
15181563 } else {
15191564 $evPrim = d.doubleValue();
15201565 }
0 commit comments