Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value
else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s"))
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
Expand All @@ -499,6 +503,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value
else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s"))
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
Expand All @@ -523,7 +531,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => if (s.toShort(result)) {
result.value.toShort
} else {
null
if (ansiEnabled) {
throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
} else {
null
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
Expand Down Expand Up @@ -564,7 +576,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => if (s.toByte(result)) {
result.value.toByte
} else {
null
if (ansiEnabled) {
throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
} else {
null
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
Expand Down Expand Up @@ -636,7 +652,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// Please refer to https://github.com/apache/spark/pull/26640
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
} catch {
case _: NumberFormatException => null
case _: NumberFormatException =>
if (ansiEnabled) {
throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
} else {
null
}
})
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
Expand All @@ -659,6 +680,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(d == null) {
throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
} else {
d.asInstanceOf[Double].doubleValue()
}
}
})
case StringType =>
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
Expand All @@ -679,6 +713,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, s => {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (f == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is too much code duplication. How about unifying these 2 cases?

val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (f == null && ansiEnabled) {
  throw ...
} else {
  f
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
} else {
f.asInstanceOf[Float].floatValue()
}
}
})
case StringType =>
buildCast[UTF8String](_, s => {
val floatStr = s.toString
Expand Down Expand Up @@ -1133,7 +1180,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
if ($ansiEnabled) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this will generate java code with if-else, we can do better

val handleException = if (ansiEnabled) {
  s"throw new NumberFormatException("invalid input syntax for type numeric: $c");"
} else {
  s"$evNull =true;"
}
code"""
  ...
  } catch (java.lang.NumberFormatException e) {
    $handleException
  }
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evNull =true;
}
}
"""
case BooleanType =>
Expand Down Expand Up @@ -1354,7 +1405,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
"""
}

private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
private[this] def castToByteCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
Expand All @@ -1363,7 +1416,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
if ($ansiEnabled) {
throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evNull = true;
}
}
$wrapper = null;
"""
Expand Down Expand Up @@ -1394,7 +1451,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
if ($ansiEnabled) {
throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evNull = true;
}
}
$wrapper = null;
"""
Expand All @@ -1414,7 +1475,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
}

private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
private[this] def castToIntCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
}
$wrapper = null;
"""
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1442,7 +1517,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
}

private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
private[this] def castToLongCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])

Expand All @@ -1452,7 +1529,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
if ($ansiEnabled) {
throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evNull = true;
}
}
$wrapper = null;
"""
Expand All @@ -1473,6 +1554,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType if ansiEnabled =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can unify the code a little bit

val handleNull = if (ansiEnabled) {
  s"throw ..."
} else {
  s"$evNull = true;"
}
...
code"""
  ...
  if (f == null) {
    $handleNull
  } else ...
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evPrim = f.floatValue();
}
}
"""
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1504,6 +1601,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType if ansiEnabled =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
throw new IllegalArgumentException("invalid input syntax for type numeric: $c");
} else {
$evPrim = d.doubleValue();
}
}
"""
case StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
Expand Down
Loading