@@ -18,7 +18,8 @@ package com.nvidia.spark.rapids.shims.v2

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.{DateUtils, ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus

@@ -126,4 +127,15 @@ trait Spark30XShims extends SparkShims
ss.sparkContext.defaultParallelism
}

override def getSpecialDate(name: String, unit: DType): Scalar = unit match {
case DType.TIMESTAMP_DAYS =>
Scalar.timestampDaysFromInt(DateUtils.specialDatesDays(name))
case DType.TIMESTAMP_SECONDS =>
Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name))
case DType.TIMESTAMP_MICROSECONDS =>
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name))
case _ =>
throw new IllegalArgumentException(s"unsupported DType: $unit")
}

}
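For context, a minimal sketch of what this Spark 3.0/3.1 shim hands back (assuming the imports shown above; the day arithmetic is illustrative — the real values come from DateUtils at runtime):

    // Sketch only: DateUtils.specialDatesDays maps special names to days since
    // the epoch, e.g. "epoch" -> 0 and "tomorrow" -> the current date plus one.
    val epoch = ShimLoader.getSparkShims.getSpecialDate(DateUtils.EPOCH, DType.TIMESTAMP_DAYS)
    try {
      assert(epoch.isValid) // a real TIMESTAMP_DAYS scalar holding day 0
    } finally {
      epoch.close() // cudf Scalars wrap native memory and must be closed
    }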
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus
@@ -137,6 +138,12 @@ trait Spark32XShims extends SparkShims
Spark32XShimsUtils.leafNodeDefaultParallelism(ss)
}

override def getSpecialDate(name: String, unit: DType): Scalar = unit match {
case DType.TIMESTAMP_DAYS => Scalar.fromNull(DType.TIMESTAMP_DAYS)
case DType.TIMESTAMP_SECONDS => Scalar.fromNull(DType.TIMESTAMP_SECONDS)
case DType.TIMESTAMP_MICROSECONDS => Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS)
case _ => throw new IllegalArgumentException(s"unsupported DType: $unit")
}
}
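By contrast, the 3.2 shim above returns invalid (null) scalars: Spark 3.2 stopped parsing special date strings such as "now" in casts, so matching rows should become null. A minimal sketch of the observable difference, under the same assumptions as above:

    // Sketch only: the same call resolved against the Spark 3.2 shim.
    val now = ShimLoader.getSparkShims.getSpecialDate(DateUtils.NOW, DType.TIMESTAMP_DAYS)
    try {
      assert(!now.isValid) // null scalar, so replaceSpecialDates nulls "now" rows
    } finally {
      now.close()
    }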

// TODO dedupe utils inside shims
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala (76 changes: 27 additions & 49 deletions)
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import java.text.SimpleDateFormat
import java.time.DateTimeException

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
@@ -27,6 +28,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates
import org.apache.spark.sql.rapids.RegexReplace
import org.apache.spark.sql.types._

@@ -805,27 +807,6 @@ object GpuCast extends Arm {
}
}

/**
* Replace special date strings such as "now" with timestampDays. This method does not
* close the `input` ColumnVector.
*/
def specialDateOr(
input: ColumnVector,
special: String,
value: Int,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampDaysFromInt(value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/**
* Parse dates that match the provided length and format. This method does not
* close the `input` ColumnVector.
@@ -928,8 +909,6 @@
*/
private def castStringToDate(input: ColumnVector): ColumnVector = {

val specialDates = DateUtils.specialDatesDays

withResource(sanitizeStringToDate(input)) { sanitizedInput =>

// convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
@@ -938,8 +917,18 @@
convertFixedLenDateOrNull(sanitizedInput, 4, "%Y")))

// handle special dates like "epoch", "now", etc.
specialDates.foldLeft(converted)((prev, specialDate) =>
specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev))
// `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case
// an exception occurs before replaceSpecialDates runs.
closeOnExcept(converted) { timeStampVector =>
val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY,
DateUtils.YESTERDAY, DateUtils.TOMORROW)
val specialValues = mutable.ListBuffer.empty[Scalar]
withResource(specialValues) { _ =>
specialDates.foreach(
specialValues += ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_DAYS))
replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues)
}
}
}
}
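A hypothetical walk-through of the rewritten castStringToDate path (data and day numbers are illustrative):

    // input:  ColumnVector.fromStrings("2021-01-01", "now", "not-a-date")
    // after convertDateOrNull / convertFixedLenDateOrNull: [18628, null, null]
    // after replaceSpecialDates: the "now" row takes the shim's scalar —
    //   the current date in days on Spark 3.0/3.1, null on Spark 3.2+.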

@@ -955,27 +944,6 @@
}
}

/**
* Replace special date strings such as "now" with timestampMicros. This method does not
* close the `input` ColumnVector.
*/
private def specialTimestampOr(
input: ColumnVector,
special: String,
value: Long,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/**
* Parse dates that match the provided regex. This method does not close the `input`
* ColumnVector.
@@ -1059,7 +1027,6 @@
val today = DateUtils.currentDate()
val todayStr = new SimpleDateFormat("yyyy-MM-dd")
.format(today * DateUtils.ONE_DAY_SECONDS * 1000L)
val specialDates = DateUtils.specialDatesMicros

var sanitizedInput = input.incRefCount()

@@ -1093,8 +1060,19 @@
convertFixedLenTimestampOrNull(sanitizedInput, 4, "%Y"))))

// handle special dates like "epoch", "now", etc.
val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
// `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case
// an exception occurs before replaceSpecialDates runs.
val finalResult = closeOnExcept(converted) { timeStampVector =>
val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY,
DateUtils.YESTERDAY, DateUtils.TOMORROW)
val specialValues = mutable.ListBuffer.empty[Scalar]
withResource(specialValues) { _ =>
specialDates.foreach(
specialValues +=
ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_MICROSECONDS))
replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues)
}
}

Contributor:
It seems expensive to replace the special dates with null with Spark 3.2 rather than just skip handling them at all. They will already be ignored by is_timestamp.

I will pull this PR locally today and experiment with it and come back with more detailed suggestions.
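For illustration, a sketch of the reviewer's alternative — skipping the replacement pass entirely when special dates are unsupported, since those rows were already nulled by the is_timestamp gate. The `specialDatesSupported` flag is an assumed capability switch, not an existing shim API:

    // Hypothetical sketch, not part of this PR: `specialDatesSupported` is an
    // assumed flag on the shim, e.g. false for Spark 3.2+.
    val result = if (!specialDatesSupported) {
      converted // "now", "epoch", etc. already failed is_timestamp and are null
    } else {
      closeOnExcept(converted) { timeStampVector =>
        val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY,
          DateUtils.YESTERDAY, DateUtils.TOMORROW)
        val specialValues = mutable.ListBuffer.empty[Scalar]
        withResource(specialValues) { _ =>
          specialDates.foreach(specialValues +=
            ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_MICROSECONDS))
          replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues)
        }
      }
    }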

if (ansiMode) {
// When ANSI mode is enabled, we need to throw an exception if any values could not be
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import java.net.URI
import java.nio.ByteBuffer

import ai.rapids.cudf.{DType, Scalar}
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.{FileStatus, Path}
@@ -266,6 +267,8 @@ trait SparkShims {
def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean

def leafNodeDefaultParallelism(ss: SparkSession): Int

def getSpecialDate(name: String, unit: DType): Scalar
}

abstract class SparkCommonShims extends SparkShims {
@@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids

import java.util.concurrent.TimeUnit

import scala.collection.mutable

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader}
import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -531,20 +533,30 @@ object GpuToTimestamp extends Arm {
FIX_SINGLE_DIGIT_SECOND
)

def daysScalarSeconds(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name))
}

def daysScalarMicros(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name))
}

def daysEqual(col: ColumnVector, name: String): ColumnVector = {
withResource(Scalar.fromString(name)) { scalarName =>
col.equalTo(scalarName)
}
}

/**
* Replace special date strings such as "now" with their corresponding timestamp values.
* This method closes `chronoVector` but does not close `stringVector` or `specialValues`.
*/
def replaceSpecialDates(
stringVector: ColumnVector,
chronoVector: ColumnVector,
specialNames: Seq[String],
specialValues: Seq[Scalar]): ColumnVector = {
specialValues.zip(specialNames).foldLeft(chronoVector) { case (buffer, (scalar, name)) =>
withResource(buffer) { bufVector =>
withResource(daysEqual(stringVector, name)) { isMatch =>
isMatch.ifElse(scalar, bufVector)
}
}
}
}
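A minimal usage sketch of replaceSpecialDates with hypothetical data, assuming an Arm scope and the cudf Java column builders (`timestampDaysFromBoxedInts` stands in for format-based parsing output). Note that the pre-parsed column is consumed while the string column and scalars are not:

    // Sketch only: replace rows equal to "epoch" with day 0.
    withResource(ColumnVector.fromStrings("2021-01-01", "epoch")) { strs =>
      // Pretend parsing produced [18628, null]; the "epoch" row did not match.
      val parsed = ColumnVector.timestampDaysFromBoxedInts(18628, null)
      withResource(Scalar.timestampDaysFromInt(0)) { epochDay =>
        // `parsed` is closed inside replaceSpecialDates; the result is [18628, 0].
        withResource(replaceSpecialDates(strs, parsed, Seq(DateUtils.EPOCH), Seq(epochDay))) {
          result => assert(result.getRowCount == 2)
        }
      }
    }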

def isTimestamp(col: ColumnVector, sparkFormat: String, strfFormat: String) : ColumnVector = {
if (CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat)) {
// the cuDF `is_timestamp` function is less restrictive than Spark's behavior for UnixTime
@@ -572,50 +584,29 @@
lhs: GpuColumnVector,
sparkFormat: String,
strfFormat: String,
dtype: DType,
daysScalar: String => Scalar,
asTimestamp: (ColumnVector, String) => ColumnVector): ColumnVector = {
dtype: DType): ColumnVector = {

// `tsVector` will be closed in replaceSpecialDates
val tsVector = withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTs =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(lhs.getBase.asTimestamp(dtype, strfFormat)) { tsVec =>
isTs.ifElse(tsVec, nullValue)
}
}
}

// in addition to date/timestamp strings, we also need to check for special dates and null
// values, since anything else is invalid and should throw an error or be converted to null
// depending on the policy
withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTimestamp =>
withResource(daysEqual(lhs.getBase, DateUtils.EPOCH)) { isEpoch =>
withResource(daysEqual(lhs.getBase, DateUtils.NOW)) { isNow =>
withResource(daysEqual(lhs.getBase, DateUtils.TODAY)) { isToday =>
withResource(daysEqual(lhs.getBase, DateUtils.YESTERDAY)) { isYesterday =>
withResource(daysEqual(lhs.getBase, DateUtils.TOMORROW)) { isTomorrow =>
withResource(lhs.getBase.isNull) { _ =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(asTimestamp(lhs.getBase, strfFormat)) { converted =>
withResource(daysScalar(DateUtils.EPOCH)) { epoch =>
withResource(daysScalar(DateUtils.NOW)) { now =>
withResource(daysScalar(DateUtils.TODAY)) { today =>
withResource(daysScalar(DateUtils.YESTERDAY)) { yesterday =>
withResource(daysScalar(DateUtils.TOMORROW)) { tomorrow =>
withResource(isTomorrow.ifElse(tomorrow, nullValue)) { a =>
withResource(isYesterday.ifElse(yesterday, a)) { b =>
withResource(isToday.ifElse(today, b)) { c =>
withResource(isNow.ifElse(now, c)) { d =>
withResource(isEpoch.ifElse(epoch, d)) { e =>
isTimestamp.ifElse(converted, e)
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY,
DateUtils.YESTERDAY, DateUtils.TOMORROW)
val specialValues = mutable.ListBuffer.empty[Scalar]

withResource(specialValues) { _ =>
closeOnExcept(tsVector) { _ =>
specialDates.foreach(
specialValues += ShimLoader.getSparkShims.getSpecialDate(_, dtype))
replaceSpecialDates(lhs.getBase, tsVector, specialDates, specialValues)
}
}
}
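Taken together, the rewrite collapses the deeply nested withResource chain into one gated parse plus one replacement pass. A hypothetical end-to-end illustration for format "yyyy-MM-dd":

    // input strings:                              ["2021-01-01", "now", "junk"]
    // tsVector (is_timestamp gate + asTimestamp): [2021-01-01T00:00:00Z, null, null]
    // after replaceSpecialDates: the "now" row takes the shim's scalar —
    //   the current timestamp on Spark 3.0/3.1, null (unchanged) on Spark 3.2+.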
@@ -760,9 +751,7 @@ abstract class GpuToTimestamp
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_MICROSECONDS,
daysScalarMicros,
(col, strfFormat) => col.asTimestampMicroseconds(strfFormat))
DType.TIMESTAMP_MICROSECONDS)
}
} else { // Timestamp or DateType
lhs.getBase.asTimestampMicroseconds()
@@ -811,9 +800,7 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp {
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_SECONDS,
daysScalarSeconds,
(col, strfFormat) => col.asTimestampSeconds(strfFormat))
DType.TIMESTAMP_SECONDS)
}
} else if (lhs.dataType() == DateType){
lhs.getBase.asTimestampSeconds()
@@ -241,11 +241,8 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterEach
.repartition(2)
.withColumn("c1", unix_timestamp(col("c0"), "yyyy-MM-dd HH:mm:ss"))
}
val startTimeSeconds = System.currentTimeMillis()/1000L
val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
assert(cpuNowSeconds >= startTimeSeconds)
assert(gpuNowSeconds >= startTimeSeconds)
// CPU ran first so cannot have a greater value than the GPU run (but could be the same second)
assert(cpuNowSeconds <= gpuNowSeconds)
}

Contributor:
What is the reason for removing these assertions? Are they no longer valid?

Collaborator (author):
Yes, under Spark 3.2+ the result will be zero instead of the current time, since NOW is no longer being parsed in 3.2.

Collaborator:
Then have the test check the version, and have it check that "now" is replaced with a 0 instead. Or just have the test skip entirely if it is 3.2+.
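A sketch of the version-gated test the reviewer asks for. The `isSpark320OrLater` helper and test wiring are assumptions for illustration, not the suite's actual API:

    // Hypothetical sketch, not part of this PR:
    test("unix_timestamp parses 'now' (version-gated sketch)") {
      val startTimeSeconds = System.currentTimeMillis() / 1000L
      val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
      val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
      if (isSpark320OrLater) { // assumed version-check helper
        // Spark 3.2+ no longer parses "now", so unix_timestamp yields null,
        // which reads back as 0 through asInstanceOf[Long].
        assert(cpuNowSeconds == 0)
        assert(gpuNowSeconds == 0)
      } else {
        assert(cpuNowSeconds >= startTimeSeconds)
        assert(gpuNowSeconds >= startTimeSeconds)
        // CPU ran first, so it cannot observe a later second than the GPU run.
        assert(cpuNowSeconds <= gpuNowSeconds)
      }
    }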
@@ -534,4 +531,3 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterEach
)

}