29 changes: 26 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
@@ -20,6 +20,8 @@ import java.time.LocalDate

import scala.collection.mutable.ListBuffer

import ai.rapids.cudf.{DType, Scalar}

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays

@@ -57,7 +59,13 @@ object DateUtils {
val YESTERDAY = "yesterday"
val TOMORROW = "tomorrow"

def specialDatesDays: Map[String, Int] = {
private lazy val isSpark320OrLater: Boolean = {
ShimLoader.getSparkShims.getSparkShimVersion.toString >= "3.2"
}

def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
Map(
EPOCH -> 0,
@@ -68,7 +76,9 @@
)
}

def specialDatesSeconds: Map[String, Long] = {
def specialDatesSeconds: Map[String, Long] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtils.currentTimestamp()
Map(
@@ -80,7 +90,9 @@
)
}

def specialDatesMicros: Map[String, Long] = {
def specialDatesMicros: Map[String, Long] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtils.currentTimestamp()
Map(
@@ -92,6 +104,17 @@
)
}

def fetchSpecialDates(unit: DType): Map[String, Scalar] = unit match {
case DType.TIMESTAMP_DAYS =>
DateUtils.specialDatesDays.map { case (k, v) => k -> Scalar.timestampDaysFromInt(v) }
case DType.TIMESTAMP_SECONDS =>
DateUtils.specialDatesSeconds.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) }
case DType.TIMESTAMP_MICROSECONDS =>
DateUtils.specialDatesMicros.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) }
case _ =>
throw new IllegalArgumentException(s"unsupported DType: $unit")
}

def currentDate(): Int = localDateToDays(LocalDate.now())

case class FormatKeywordToReplace(word: String, startIndex: Int, endIndex: Int)
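The new `DateUtils.fetchSpecialDates` helper allocates a fresh cuDF `Scalar` per special date, so the caller owns the returned map and must close every value in it. A minimal caller sketch (an assumed usage example, not code from this PR; on Spark 3.2+ the map is simply empty):

```scala
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.DateUtils

// Fetch the special-date scalars for day-resolution timestamps and make
// sure each Scalar is closed, whether or not the map is empty (3.2+).
val specials = DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS)
try {
  // Keys are the special strings recognized before Spark 3.2:
  // "epoch", "now", "today", "yesterday", "tomorrow".
  specials.keys.foreach(println)
} finally {
  specials.values.foreach(_.close())
}
```

The callers below follow the same pattern, closing the scalars with `withResource` once the replacement is done.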
75 changes: 26 additions & 49 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import java.text.SimpleDateFormat
import java.time.DateTimeException

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
@@ -27,6 +28,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates
import org.apache.spark.sql.rapids.RegexReplace
import org.apache.spark.sql.types._

@@ -788,27 +790,6 @@ object GpuCast extends Arm {
}
}

/**
* Replace special date strings such as "now" with timestampDays. This method does not
* close the `input` ColumnVector.
*/
def specialDateOr(
input: ColumnVector,
special: String,
value: Int,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampDaysFromInt(value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/** This method does not close the `input` ColumnVector. */
def convertDateOrNull(
input: ColumnVector,
@@ -884,16 +865,24 @@
*/
private def castStringToDate(sanitizedInput: ColumnVector): ColumnVector = {

val specialDates = DateUtils.specialDatesDays

// convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m",
convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y")))

// handle special dates like "epoch", "now", etc.
specialDates.foldLeft(converted)((prev, specialDate) =>
specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev))
closeOnExcept(converted) { tsVector =>
DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS) match {
case dates if dates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates
val (specialNames, specialValues) = dates.unzip
withResource(specialValues.toList) { scalars =>
replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars)
}
case _ =>
tsVector
}
}
}

private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = {
@@ -908,27 +897,6 @@
}
}

/**
* Replace special date strings such as "now" with timestampMicros. This method does not
* close the `input` ColumnVector.
*/
private def specialTimestampOr(
input: ColumnVector,
special: String,
value: Long,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/** This method does not close the `input` ColumnVector. */
private def convertTimestampOrNull(
input: ColumnVector,
@@ -1009,7 +977,6 @@
val today = DateUtils.currentDate()
val todayStr = new SimpleDateFormat("yyyy-MM-dd")
.format(today * DateUtils.ONE_DAY_SECONDS * 1000L)
val specialDates = DateUtils.specialDatesMicros

var sanitizedInput = input.incRefCount()

@@ -1027,8 +994,18 @@
convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))

// handle special dates like "epoch", "now", etc.
val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
val finalResult = closeOnExcept(converted) { tsVector =>
DateUtils.fetchSpecialDates(DType.TIMESTAMP_MICROSECONDS) match {
case dates if dates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates.
val (specialNames, specialValues) = dates.unzip
withResource(specialValues.toList) { scalars =>
replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars)
}
case _ =>
tsVector
}
}

if (ansiMode) {
// When ANSI mode is enabled, we need to throw an exception if any values could not be
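Both `castStringToDate` and `castStringToTimestamp` now lean on the Arm helper `closeOnExcept`, whose contract the inline comments assume: the wrapped resource is closed only if the body throws; on success, ownership passes onward (here, to `replaceSpecialDates`, which closes it). A hedged sketch of that contract, written against plain `AutoCloseable` rather than the plugin's actual `Arm` trait:

```scala
// Sketch of the closeOnExcept contract assumed above: close `resource`
// if `body` throws, otherwise hand the result (and ownership) to the
// caller. The real helper lives in com.nvidia.spark.rapids.Arm.
def closeOnExceptSketch[T <: AutoCloseable, V](resource: T)(body: T => V): V = {
  try {
    body(resource)
  } catch {
    case t: Throwable =>
      resource.close()
      throw t
  }
}
```

This is why the `case _ => tsVector` branch can return the vector unclosed: ownership simply transfers back to the caller.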
@@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids

import java.util.concurrent.TimeUnit

import scala.collection.mutable

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader}
import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -531,20 +533,30 @@ object GpuToTimestamp extends Arm {
FIX_SINGLE_DIGIT_SECOND
)

def daysScalarSeconds(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name))
}

def daysScalarMicros(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name))
}

def daysEqual(col: ColumnVector, name: String): ColumnVector = {
withResource(Scalar.fromString(name)) { scalarName =>
col.equalTo(scalarName)
}
}

/**
* Replace special date strings such as "now" with the corresponding timestamp
* values. This method closes `chronoVector` (when `specialValues` is non-empty)
* but does not close `stringVector` or `specialValues`.
*/
def replaceSpecialDates(
stringVector: ColumnVector,
chronoVector: ColumnVector,
specialNames: Seq[String],
specialValues: Seq[Scalar]): ColumnVector = {
specialValues.zip(specialNames).foldLeft(chronoVector) { case (buffer, (scalar, name)) =>
withResource(buffer) { bufVector =>
withResource(daysEqual(stringVector, name)) { isMatch =>
isMatch.ifElse(scalar, bufVector)
}
}
}
}

def isTimestamp(col: ColumnVector, sparkFormat: String, strfFormat: String) : ColumnVector = {
if (CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat)) {
// the cuDF `is_timestamp` function is less restrictive than Spark's behavior for UnixTime
@@ -572,50 +584,30 @@
lhs: GpuColumnVector,
sparkFormat: String,
strfFormat: String,
dtype: DType,
daysScalar: String => Scalar,
asTimestamp: (ColumnVector, String) => ColumnVector): ColumnVector = {
dtype: DType): ColumnVector = {

// `tsVector` will be closed in replaceSpecialDates
val tsVector = withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTs =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(lhs.getBase.asTimestamp(dtype, strfFormat)) { tsVec =>
isTs.ifElse(tsVec, nullValue)
}
}
}

// in addition to date/timestamp strings, we also need to check for special dates and null
// values, since anything else is invalid and should throw an error or be converted to null
// depending on the policy
withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTimestamp =>
withResource(daysEqual(lhs.getBase, DateUtils.EPOCH)) { isEpoch =>
withResource(daysEqual(lhs.getBase, DateUtils.NOW)) { isNow =>
withResource(daysEqual(lhs.getBase, DateUtils.TODAY)) { isToday =>
withResource(daysEqual(lhs.getBase, DateUtils.YESTERDAY)) { isYesterday =>
withResource(daysEqual(lhs.getBase, DateUtils.TOMORROW)) { isTomorrow =>
withResource(lhs.getBase.isNull) { _ =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(asTimestamp(lhs.getBase, strfFormat)) { converted =>
withResource(daysScalar(DateUtils.EPOCH)) { epoch =>
withResource(daysScalar(DateUtils.NOW)) { now =>
withResource(daysScalar(DateUtils.TODAY)) { today =>
withResource(daysScalar(DateUtils.YESTERDAY)) { yesterday =>
withResource(daysScalar(DateUtils.TOMORROW)) { tomorrow =>
withResource(isTomorrow.ifElse(tomorrow, nullValue)) { a =>
withResource(isYesterday.ifElse(yesterday, a)) { b =>
withResource(isToday.ifElse(today, b)) { c =>
withResource(isNow.ifElse(now, c)) { d =>
withResource(isEpoch.ifElse(epoch, d)) { e =>
isTimestamp.ifElse(converted, e)
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
closeOnExcept(tsVector) { tsVector =>
DateUtils.fetchSpecialDates(dtype) match {
case dates if dates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates
val (specialNames, specialValues) = dates.unzip
withResource(specialValues.toList) { scalars =>
replaceSpecialDates(lhs.getBase, tsVector, specialNames.toList, scalars)
}
case _ =>
tsVector
}
}
}
@@ -760,9 +752,7 @@ abstract class GpuToTimestamp
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_MICROSECONDS,
daysScalarMicros,
(col, strfFormat) => col.asTimestampMicroseconds(strfFormat))
DType.TIMESTAMP_MICROSECONDS)
}
} else { // Timestamp or DateType
lhs.getBase.asTimestampMicroseconds()
@@ -811,9 +801,7 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp {
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_SECONDS,
daysScalarSeconds,
(col, strfFormat) => col.asTimestampSeconds(strfFormat))
DType.TIMESTAMP_SECONDS)
}
} else if (lhs.dataType() == DateType){
lhs.getBase.asTimestampSeconds()
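For intuition on what `replaceSpecialDates` computes, here is a CPU-side analogue over plain Scala collections (an illustrative sketch, not plugin code): each special name overrides the already-parsed value wherever the raw input string matches it, exactly mirroring the `foldLeft` of `ifElse` calls above.

```scala
// CPU analogue of replaceSpecialDates: fold each (name, value) pair over
// the parsed results, overriding rows whose raw string equals the name.
def replaceSpecialDatesCpu(
    strings: Seq[String],
    parsed: Seq[Option[Long]],
    specials: Map[String, Long]): Seq[Option[Long]] = {
  specials.foldLeft(parsed) { case (acc, (name, value)) =>
    strings.zip(acc).map { case (s, cur) =>
      if (s == name) Some(value) else cur
    }
  }
}

// e.g. replaceSpecialDatesCpu(Seq("epoch", "2021-01-01"), Seq(None, Some(18628L)),
//   Map("epoch" -> 0L)) yields Seq(Some(0L), Some(18628L))
```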
@@ -241,11 +241,8 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE
.repartition(2)
.withColumn("c1", unix_timestamp(col("c0"), "yyyy-MM-dd HH:mm:ss"))
}
val startTimeSeconds = System.currentTimeMillis()/1000L
val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
assert(cpuNowSeconds >= startTimeSeconds)
Contributor: What is the reason for removing these assertions? Are they no longer valid?

Collaborator (Author): Yes, under Spark 3.2+ the result will be zero instead of the current time, since "now" is no longer parsed as a special value in 3.2.

Collaborator: Then have the test check the version and verify that "now" is replaced with 0 instead. Or just skip the test entirely on 3.2+.

assert(gpuNowSeconds >= startTimeSeconds)
// CPU ran first so cannot have a greater value than the GPU run (but could be the same second)
assert(cpuNowSeconds <= gpuNowSeconds)
}
@@ -534,4 +531,3 @@
)

}
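Following the reviewer's suggestion above, a version-gated variant of the removed assertions might look like the sketch below. It assumes the `ShimLoader` version check added to `DateUtils` is reachable from the test, and relies on the author's observation that "now" yields zero on 3.2+; none of this is code from the PR.

```scala
// Sketch: on pre-3.2 Spark, "now" parses to the current time, so the old
// ordering assertions hold; on 3.2+ "now" is no longer special, so the
// reviewer suggests asserting a zero result (or skipping) instead.
val isSpark320OrLater =
  ShimLoader.getSparkShims.getSparkShimVersion.toString >= "3.2"
if (isSpark320OrLater) {
  assert(cpuNowSeconds == 0 && gpuNowSeconds == 0)
} else {
  assert(cpuNowSeconds >= startTimeSeconds)
  assert(gpuNowSeconds >= startTimeSeconds)
  // CPU ran first so cannot exceed the GPU run (could be the same second)
  assert(cpuNowSeconds <= gpuNowSeconds)
}
```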