diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c451eb2b877d..078f1d194107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -831,6 +831,8 @@ object TypeCoercion { * 2. Turns Add/Subtract of TimestampType/DateType/IntegerType * and TimestampType/IntegerType/DateType to DateAdd/DateSub/SubtractDates and * to SubtractTimestamps. + * 3. Turns Multiply/Divide of CalendarIntervalType and NumericType + * to MultiplyInterval/DivideInterval */ object DateTimeOperations extends Rule[LogicalPlan] { @@ -846,6 +848,12 @@ object TypeCoercion { Cast(TimeAdd(l, r), l.dataType) case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => Cast(TimeSub(l, r), l.dataType) + case Multiply(l @ CalendarIntervalType(), r @ NumericType()) => + MultiplyInterval(l, r) + case Multiply(l @ NumericType(), r @ CalendarIntervalType()) => + MultiplyInterval(r, l) + case Divide(l @ CalendarIntervalType(), r @ NumericType()) => + DivideInterval(l, r) case Add(l @ DateType(), r @ IntegerType()) => DateAdd(l, r) case Add(l @ IntegerType(), r @ DateType()) => DateAdd(r, l) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index c3a3b3cb58f4..e37681bca604 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -45,45 +45,45 @@ abstract class ExtractIntervalPart( } case class ExtractIntervalMillenniums(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getMillenniums, "getMillenniums") + extends ExtractIntervalPart(child, IntegerType, getMillenniums, "getMillenniums") case class ExtractIntervalCenturies(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getCenturies, "getCenturies") + extends ExtractIntervalPart(child, IntegerType, getCenturies, "getCenturies") case class ExtractIntervalDecades(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getDecades, "getDecades") + extends ExtractIntervalPart(child, IntegerType, getDecades, "getDecades") case class ExtractIntervalYears(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") + extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") case class ExtractIntervalQuarters(child: Expression) - extends ExtractIntervalPart(child, ByteType, getQuarters, "getQuarters") + extends ExtractIntervalPart(child, ByteType, getQuarters, "getQuarters") case class ExtractIntervalMonths(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") + extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") case class ExtractIntervalDays(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") + extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") case class ExtractIntervalHours(child: Expression) - extends ExtractIntervalPart(child, LongType, getHours, "getHours") + extends ExtractIntervalPart(child, LongType, getHours, "getHours") case class ExtractIntervalMinutes(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") + extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") case class ExtractIntervalSeconds(child: Expression) - extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") + extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") case class ExtractIntervalMilliseconds(child: Expression) - extends ExtractIntervalPart(child, DecimalType(8, 3), getMilliseconds, "getMilliseconds") + extends ExtractIntervalPart(child, DecimalType(8, 3), getMilliseconds, "getMilliseconds") case class ExtractIntervalMicroseconds(child: Expression) - extends ExtractIntervalPart(child, LongType, getMicroseconds, "getMicroseconds") + extends ExtractIntervalPart(child, LongType, getMicroseconds, "getMicroseconds") // Number of seconds in 10000 years is 315576000001 (30 days per one month) // which is 12 digits + 6 digits for the fractional part of seconds. case class ExtractIntervalEpoch(child: Expression) - extends ExtractIntervalPart(child, DecimalType(18, 6), getEpoch, "getEpoch") + extends ExtractIntervalPart(child, DecimalType(18, 6), getEpoch, "getEpoch") object ExtractIntervalPart { @@ -109,3 +109,47 @@ object ExtractIntervalPart { case _ => errorHandleFunc } } + +abstract class IntervalNumOperation( + interval: Expression, + num: Expression, + operation: (CalendarInterval, Double) => CalendarInterval, + operationName: String) + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def left: Expression = interval + override def right: Expression = num + + override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DoubleType) + override def dataType: DataType = CalendarIntervalType + + override def nullable: Boolean = true + + override def nullSafeEval(interval: Any, num: Any): Any = { + try { + operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } catch { + case _: java.lang.ArithmeticException => null + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (interval, num) => { + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + s""" + try { + ${ev.value} = $iu.$operationName($interval, $num); + } catch (java.lang.ArithmeticException e) { + ${ev.isNull} = true; + } + """ + }) + } + + override def prettyName: String = operationName + "_interval" +} + +case class MultiplyInterval(interval: Expression, num: Expression) + extends IntervalNumOperation(interval, num, multiply, "multiply") + +case class DivideInterval(interval: Expression, num: Expression) + extends IntervalNumOperation(interval, num, divide, "divide") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 73e9f37c9452..1f63f1e14b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -365,4 +365,27 @@ object IntervalUtils { def isNegative(interval: CalendarInterval, daysPerMonth: Int = 31): Boolean = { getDuration(interval, TimeUnit.MICROSECONDS, daysPerMonth) < 0 } + + /** + * Makes an interval from months, days and micros with the fractional part by + * adding the month fraction to days and the days fraction to micros. + */ + private def fromDoubles( + monthsWithFraction: Double, + daysWithFraction: Double, + microsWithFraction: Double): CalendarInterval = { + val truncatedMonths = Math.toIntExact(monthsWithFraction.toLong) + val days = daysWithFraction + DAYS_PER_MONTH * (monthsWithFraction - truncatedMonths) + val truncatedDays = Math.toIntExact(days.toLong) + val micros = microsWithFraction + DateTimeUtils.MICROS_PER_DAY * (days - truncatedDays) + new CalendarInterval(truncatedMonths, truncatedDays, micros.round) + } + + def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { + fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) + } + + def divide(interval: CalendarInterval, num: Double): CalendarInterval = { + fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 75bb460e2575..c7371a7911df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1597,6 +1597,27 @@ class TypeCoercionSuite extends AnalysisTest { Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), Cast(100, DecimalType(34, 24))), Cast(1, IntegerType))) } + + test("rule for interval operations") { + val dateTimeOperations = TypeCoercion.DateTimeOperations + val interval = Literal(new CalendarInterval(0, 0, 0)) + + Seq( + Literal(10.toByte, ByteType), + Literal(10.toShort, ShortType), + Literal(10, IntegerType), + Literal(10L, LongType), + Literal(Decimal(10), DecimalType.SYSTEM_DEFAULT), + Literal(10.5.toFloat, FloatType), + Literal(10.5, DoubleType)).foreach { num => + ruleTest(dateTimeOperations, Multiply(interval, num), + MultiplyInterval(interval, num)) + ruleTest(dateTimeOperations, Multiply(num, interval), + MultiplyInterval(interval, num)) + ruleTest(dateTimeOperations, Divide(interval, num), + DivideInterval(interval, num)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 0c292e11485a..d0610e7260fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.language.implicitConversions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString import org.apache.spark.sql.types.Decimal class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { implicit def interval(s: String): Literal = { - Literal(IntervalUtils.fromString("interval " + s)) + Literal(fromString("interval " + s)) } test("millenniums") { @@ -191,4 +191,39 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ExtractIntervalEpoch("1 second 1 millisecond 1 microsecond"), Decimal(1.001001, 18, 6)) } + + test("multiply") { + def check(interval: String, num: Double, expected: String): Unit = { + checkEvaluation( + MultiplyInterval(Literal(fromString(interval)), Literal(num)), + if (expected == null) null else fromString(expected)) + } + + check("0 seconds", 10, "0 seconds") + check("10 hours", 0, "0 hours") + check("12 months 1 microseconds", 2, "2 years 2 microseconds") + check("-5 year 3 seconds", 3, "-15 years 9 seconds") + check("1 year 1 second", 0.5, "6 months 500 milliseconds") + check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") + check("2 months 4 seconds", -0.5, "-1 months -2 seconds") + check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") + check("2 months", Int.MaxValue, null) + } + + test("divide") { + def check(interval: String, num: Double, expected: String): Unit = { + checkEvaluation( + DivideInterval(Literal(fromString(interval)), Literal(num)), + if (expected == null) null else fromString(expected)) + } + + check("0 seconds", 10, "0 seconds") + check("12 months 3 milliseconds", 2, "6 months 0.0015 seconds") + check("-5 year 3 seconds", 3, "-1 years -8 months 1 seconds") + check("6 years -7 seconds", 3, "2 years -2.333333 seconds") + check("2 years -8 seconds", 0.5, "4 years -16 seconds") + check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds") + check("1 month 3 microsecond", 1.5, "20 days 2 microseconds") + check("2 months", 0, null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 9bdd5aac28a5..7015475c4b20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.util import java.util.concurrent.TimeUnit import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.IntervalUtils.{fromDayTimeString, fromString, fromYearMonthString} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{MICROS_PER_MILLIS, MICROS_PER_SECOND} +import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.unsafe.types.CalendarInterval._ class IntervalUtilsSuite extends SparkFunSuite { @@ -34,7 +34,7 @@ class IntervalUtilsSuite extends SparkFunSuite { testSingleUnit("HouR", 3, 0, 0, 3 * MICROS_PER_HOUR) testSingleUnit("MiNuTe", 3, 0, 0, 3 * MICROS_PER_MINUTE) testSingleUnit("Second", 3, 0, 0, 3 * MICROS_PER_SECOND) - testSingleUnit("MilliSecond", 3, 0, 0, 3 * MICROS_PER_MILLI) + testSingleUnit("MilliSecond", 3, 0, 0, 3 * MICROS_PER_MILLIS) testSingleUnit("MicroSecond", 3, 0, 0, 3) for (input <- Seq(null, "", " ")) { @@ -125,7 +125,7 @@ class IntervalUtilsSuite extends SparkFunSuite { new CalendarInterval( 0, 10, - 12 * MICROS_PER_MINUTE + 888 * MICROS_PER_MILLI)) + 12 * MICROS_PER_MINUTE + 888 * MICROS_PER_MILLIS)) assert(fromDayTimeString("-3 0:0:0") === new CalendarInterval(0, -3, 0L)) try { @@ -186,4 +186,43 @@ class IntervalUtilsSuite extends SparkFunSuite { assert(!isNegative("1 year -360 days", 31)) assert(!isNegative("-1 year 380 days", 31)) } + + test("multiply by num") { + var interval = new CalendarInterval(0, 0, 0) + assert(interval === multiply(interval, 0)) + interval = new CalendarInterval(123, 456, 789) + assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === multiply(interval, 42)) + interval = new CalendarInterval(-123, -456, -789) + assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === multiply(interval, 42)) + assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) === + multiply(new CalendarInterval(1, 5, 0), 1.5)) + assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) === + multiply(new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2)) + try { + multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) + fail("Expected to throw an exception on months overflow") + } catch { + case e: ArithmeticException => + assert(e.getMessage.contains("overflow")) + } + } + + test("divide by num") { + var interval = new CalendarInterval(0, 0, 0) + assert(interval === divide(interval, 10)) + interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) === + divide(interval, 2)) + assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === divide(interval, 0.5)) + interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === divide(interval, 2)) + assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === divide(interval, 0.5)) + try { + divide(new CalendarInterval(123, 456, 789), 0) + fail("Expected to throw an exception on divide by zero") + } catch { + case e: ArithmeticException => + assert(e.getMessage.contains("overflow")) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 0e22af1fbdf2..8aa173d4dbb6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -36,3 +36,8 @@ select date '2001-10-01' - 7; select date '2001-10-01' - date '2001-09-28'; select date'2020-01-01' - timestamp'2019-10-06 10:11:12.345678'; select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01'; + +-- interval operations +select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15'); +select interval 4 month 2 weeks 3 microseconds * 1.5; +select (timestamp'2019-10-15' - timestamp'2019-10-14') / 1.5; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index f85531196c20..781db39b876a 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 20 -- !query 0 @@ -145,3 +145,27 @@ select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01' struct -- !query 16 output interval -2078 hours -48 minutes -47.654322 seconds + + +-- !query 17 +select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15') +-- !query 17 schema +struct +-- !query 17 output +interval 30 hours 33 minutes 36.003006 seconds + + +-- !query 18 +select interval 4 month 2 weeks 3 microseconds * 1.5 +-- !query 18 schema +struct +-- !query 18 output +interval 6 months 21 days 0.000005 seconds + + +-- !query 19 +select (timestamp'2019-10-15' - timestamp'2019-10-14') / 1.5 +-- !query 19 schema +struct +-- !query 19 output +interval 16 hours