Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ singleTableSchema
: colTypeList EOF
;

singleInterval
: INTERVAL? multiUnitsInterval EOF
;

statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
}

// LongConverter
Expand Down Expand Up @@ -1216,7 +1216,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, evNull) =>
code"""$evPrim = $util.stringToInterval($c);
code"""$evPrim = $util.safeStringToInterval($c);
if(${evPrim} == null) {
${evNull} = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class TimeWindow(
timeColumn: Expression,
Expand Down Expand Up @@ -103,7 +104,7 @@ object TimeWindow {
* precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
val cal = IntervalUtils.fromString(interval)
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}

override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
}

/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
Expand Down Expand Up @@ -1870,7 +1866,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
case "INTERVAL" =>
val interval = try {
IntervalUtils.fromString(value)
IntervalUtils.stringToInterval(UTF8String.fromString(value))
} catch {
case e: IllegalArgumentException =>
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
Expand Down Expand Up @@ -2069,22 +2065,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val units = ctx.intervalUnit().asScala.map { unit =>
val u = unit.getText.toLowerCase(Locale.ROOT)
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
}.map(IntervalUtils.IntervalUnit.withName).toArray

val values = ctx.intervalValue().asScala.map { value =>
if (value.STRING() != null) {
string(value.STRING())
} else {
value.getText
}
}.toArray

val units = ctx.intervalUnit().asScala
val values = ctx.intervalValue().asScala
try {
IntervalUtils.fromUnitStrings(units, values)
assert(units.length == values.length)
val kvs = units.indices.map { i =>
val u = units(i).getText
val v = if (values(i).STRING() != null) {
string(values(i).STRING())
} else {
values(i).getText
}
UTF8String.fromString(" " + v + " " + u)
}
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
} catch {
case i: IllegalArgumentException =>
val e = new ParseException(i.getMessage, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {

/**
* Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
* string is not a valid interval format.
*/
def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
astBuilder.visitSingleInterval(parser.singleInterval())
}

/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -101,34 +100,6 @@ object IntervalUtils {
Decimal(result, 18, 6)
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def fromString(str: String): CalendarInterval = {
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
try {
CatalystSqlParser.parseInterval(str)
} catch {
case e: ParseException =>
val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
ex.setStackTrace(e.getStackTrace)
throw ex
}
}

/**
* A safe version of `fromString`. It returns null for invalid input string.
*/
def safeFromString(str: String): CalendarInterval = {
try {
fromString(str)
} catch {
case _: IllegalArgumentException => null
}
}

private def toLongWithRange(
fieldName: IntervalUnit,
s: String,
Expand Down Expand Up @@ -250,46 +221,6 @@ object IntervalUtils {
}
}

def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
assert(units.length == values.length)
var months: Int = 0
var days: Int = 0
var microseconds: Long = 0
var i = 0
while (i < units.length) {
try {
units(i) match {
case YEAR =>
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
case MONTH =>
months = Math.addExact(months, values(i).toInt)
case WEEK =>
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
case DAY =>
days = Math.addExact(days, values(i).toInt)
case HOUR =>
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
microseconds = Math.addExact(microseconds, hoursUs)
case MINUTE =>
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
microseconds = Math.addExact(microseconds, minutesUs)
case SECOND =>
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
case MILLISECOND =>
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
microseconds = Math.addExact(microseconds, millisUs)
case MICROSECOND =>
microseconds = Math.addExact(microseconds, values(i).toLong)
}
} catch {
case e: Exception =>
throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
}
i += 1
}
new CalendarInterval(months, days, microseconds)
}

// Parses a string with nanoseconds, truncates the result and returns microseconds
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
if (nanosStr != null) {
Expand All @@ -305,30 +236,6 @@ object IntervalUtils {
}
}

/**
* Parse second_nano string in ss.nnnnnnnnn format to microseconds
*/
private def parseSecondNano(secondNano: String): Long = {
def parseSeconds(secondsStr: String): Long = {
toLongWithRange(
SECOND,
secondsStr,
Long.MinValue / MICROS_PER_SECOND,
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
}

secondNano.split("\\.") match {
case Array(secondsStr) => parseSeconds(secondsStr)
case Array("", nanosStr) => parseNanos(nanosStr, false)
case Array(secondsStr, nanosStr) =>
val seconds = parseSeconds(secondsStr)
Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
case _ =>
throw new IllegalArgumentException(
"Interval string does not match second-nano format of ss.nnnnnnnnn")
}
}

/**
* Gets interval duration
*
Expand Down Expand Up @@ -452,20 +359,40 @@ object IntervalUtils {
private final val millisStr = unitToUtf8(MILLISECOND)
private final val microsStr = unitToUtf8(MICROSECOND)

/**
* A safe version of `stringToInterval`. It returns null for invalid input string.
*/
def safeStringToInterval(input: UTF8String): CalendarInterval = {
try {
stringToInterval(input)
} catch {
case _: IllegalArgumentException => null
}
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
var state = PREFIX
def throwIAE(msg: String, e: Exception = null) = {
throw new IllegalArgumentException(s"Error parsing interval, $msg", e)
}

if (input == null) {
return null
throwIAE("interval string cannot be null")
}
// scalastyle:off caselocale .toLowerCase
val s = input.trim.toLowerCase
// scalastyle:on
val bytes = s.getBytes
if (bytes.isEmpty) {
return null
throwIAE("interval string cannot be empty")
}
var state = PREFIX

var i = 0
var currentValue: Long = 0
var isNegative: Boolean = false
Expand All @@ -482,13 +409,17 @@ object IntervalUtils {
}
}

def nextWord: UTF8String = {
s.substring(i, s.numBytes()).subStringIndex(UTF8String.blankString(1), 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for error reporting, so perf doesn't matter. Can we make it more readable? like

s.substring(i, s.numBytes()).toString.split("\\s+").head

}

while (i < bytes.length) {
val b = bytes(i)
state match {
case PREFIX =>
if (s.startsWith(intervalStr)) {
if (s.numBytes() == intervalStr.numBytes()) {
return null
throwIAE("interval string cannot be empty")
} else {
i += intervalStr.numBytes()
}
Expand Down Expand Up @@ -518,10 +449,10 @@ object IntervalUtils {
isNegative = false
case '.' =>
isNegative = false
fractionScale = (NANOS_PER_SECOND / 10).toInt
i += 1
fractionScale = (NANOS_PER_SECOND / 10).toInt
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE( s"unrecognized sign '$nextWord'")
}
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
case VALUE =>
Expand All @@ -530,13 +461,13 @@ object IntervalUtils {
try {
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
case ' ' => state = TRIM_BEFORE_UNIT
case '.' =>
fractionScale = (NANOS_PER_SECOND / 10).toInt
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE(s"invalid value '$nextWord'")
}
i += 1
case VALUE_FRACTIONAL_PART =>
Expand All @@ -547,14 +478,16 @@ object IntervalUtils {
case ' ' =>
fraction /= NANOS_PER_MICROS.toInt
state = TRIM_BEFORE_UNIT
case _ => return null
case _ if '0' <= b && b <= '9' =>
throwIAE(s"invalid value fractional part '$fraction$nextWord' out of range")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about interval can only support nanosecond precision, '$nextWord' is out of range

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW can we really implement nextWord correctly? We need to know where we start to parse a number, seems we don't track it now.

Copy link
Member Author

@yaooqinn yaooqinn Nov 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about interval can only support nanosecond precision, '$nextWord' is out of range

the is not suitable, 0.9999999999 the nextword will be 9 only but '$fraction$nextWord' is the exact 9999999999

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd improve this.

case _ => throwIAE(s"invalid value '$nextWord' in fractional part")
}
i += 1
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
case UNIT_BEGIN =>
// Checks that only seconds can have the fractional part
if (b != 's' && fractionScale >= 0) {
return null
throwIAE(s"'$nextWord' with fractional part is unsupported")
}
if (isNegative) {
currentValue = -currentValue
Expand Down Expand Up @@ -598,26 +531,26 @@ object IntervalUtils {
} else if (s.matchAt(microsStr, i)) {
microseconds = Math.addExact(microseconds, currentValue)
i += microsStr.numBytes()
} else return null
case _ => return null
} else throwIAE(s"invalid unit '$nextWord'")
case _ => throwIAE(s"invalid unit '$nextWord'")
}
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
state = UNIT_SUFFIX
case UNIT_SUFFIX =>
b match {
case 's' => state = UNIT_END
case ' ' => state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit suffix '$nextWord'")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invalid unit '$nextWord' is better if we can implement nextword correctly. Or we should introduce "currentWord"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the nextword represents which character the error occurs, the only special case is for nanoseconds out of range, the fraction + nextword could just exactly show the out of range number

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we change the logic of unit parsing we can extract the whole part like case _ if b>= 'A' $$ b<= 'z' => unit = s.substring(i, s.numBytes()).subStringIndex(UTF8String.blankString(1), 1) than do unit case matching and error capture.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for error reporting, usually backtracing is necessary. For example, it's better to tell users that 123a is not valid, instead of just saying a is not valid.

}
i += 1
case UNIT_END =>
b match {
case ' ' =>
i += 1
state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit suffix '$nextWord'")
}
}
}
Expand Down
Loading