From a5ade70afe7601db16ec24956f270feb4499ee42 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 25 May 2017 18:21:15 +0800 Subject: [PATCH 01/13] Support TRUNC (number) --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/datetimeExpressions.scala | 80 ---------- .../spark/sql/catalyst/expressions/misc.scala | 137 ++++++++++++++++++ .../sql/catalyst/util/BigDecimalUtils.scala | 56 +++++++ .../expressions/DateExpressionsSuite.scala | 21 --- .../expressions/MiscExpressionsSuite.scala | 48 ++++++ .../catalyst/util/BigDecimalUtilsSuite.scala | 33 +++++ .../org/apache/spark/sql/functions.scala | 2 +- .../resources/sql-tests/inputs/datetime.sql | 9 ++ .../resources/sql-tests/inputs/operators.sql | 6 +- .../sql-tests/results/datetime.sql.out | 30 +++- .../sql-tests/results/operators.sql.out | 19 ++- 12 files changed, 337 insertions(+), 106 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7521a7e12432c..e42b3f3961589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -355,7 +355,6 @@ object FunctionRegistry { expression[ParseToDate]("to_date"), expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), - expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), @@ -388,6 +387,7 @@ object FunctionRegistry { expression[CurrentDatabase]("current_database"), expression[CallMethodViaReflection]("reflect"), expression[CallMethodViaReflection]("java_method"), + expression[Trunc]("trunc"), // grouping sets expression[Cube]("cube"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 43ca2cff58825..17df4044e3bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1227,86 +1227,6 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child: override def dataType: DataType = TimestampType } -/** - * Returns date truncated to the unit specified by the format. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.", - extended = """ - Examples: - > SELECT _FUNC_('2009-02-12', 'MM'); - 2009-02-01 - > SELECT _FUNC_('2015-10-27', 'YEAR'); - 2015-01-01 - """) -// scalastyle:on line.size.limit -case class TruncDate(date: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = date - override def right: Expression = format - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) - override def dataType: DataType = DateType - override def nullable: Boolean = true - override def prettyName: String = "trunc" - - private lazy val truncLevel: Int = - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - - override def eval(input: InternalRow): Any = { - val level = if (format.foldable) { - truncLevel - } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - } - if (level == -1) { - // unknown format - null - } else { - val d = date.eval(input) - if (d == null) { - null - } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) - } - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - - if (format.foldable) { - if (truncLevel == -1) { - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") - } else { - val d = date.genCode(ctx) - ev.copy(code = s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); - }""") - } - } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") - s""" - int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { - ${ev.isNull} = true; - } else { - ${ev.value} = $dtu.truncDate($dateVal, $form); - } - """ - }) - } - } -} - /** * Returns the number of days from startDate to endDate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index bb9368cf6d774..bfddfd6307f90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{BigDecimalUtils, DateTimeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Print the result of an expression to stderr (used for debugging codegen). @@ -104,3 +106,138 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def nullable: Boolean = false override def prettyName: String = "current_database" } + +/** + * Returns date truncated to the unit specified by the format or + * numeric truncated to scale decimal places. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(data, fmt) - Returns `data` truncated by the format model `fmt`. + If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. + If `data` is DoubleType, returns `data` truncated to `fmt` decimal places. + """, + extended = """ + Examples: + > SELECT _FUNC_('2009-02-12', 'MM'); + 2009-02-01 + > SELECT _FUNC_('2015-10-27', 'YEAR'); + 2015-01-01 + > SELECT _FUNC_(1234567891.1234567891, 4); + 1234567891.1234 + > SELECT _FUNC_(1234567891.1234567891, -4); + 1234560000 + """) +// scalastyle:on line.size.limit +case class Trunc(data: Expression, format: Expression = Literal(0)) + extends BinaryExpression with ImplicitCastInputTypes { + + def this(numeric: Expression) = { + this(numeric, Literal(0)) + } + + override def left: Expression = data + override def right: Expression = format + + override def dataType: DataType = data.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DateType), TypeCollection(StringType, IntegerType)) + + override def nullable: Boolean = true + override def prettyName: String = "trunc" + + private lazy val truncFormat: Int = dataType match { + case doubleType: DoubleType => + format.eval().asInstanceOf[Int] + case dateType: DateType => + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + + override def eval(input: InternalRow): Any = { + val d = data.eval(input) + val form = format.eval() + if (null == d || null == form) { + null + } else { + dataType match { + case doubleType: DoubleType => + val scale = if (format.foldable) { + truncFormat + } else { + format.eval().asInstanceOf[Int] + } + BigDecimalUtils.trunc(d.asInstanceOf[Double], scale).doubleValue() + case dateType: DateType => + val level = if (format.foldable) { + truncFormat + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (level == -1) { + // unknown format + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + } + } + } + + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + + dataType match { + case doubleType: DoubleType => + val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + val d = data.genCode(ctx) + ev.copy(code = s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $bdu.trunc(${d.value}, $truncFormat).doubleValue(); + }""") + } else { + nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { + s"${ev.value} = $bdu.trunc($doubleVal, $fmt).doubleValue();" + }) + } + case dateType: DateType => + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + if (truncFormat == -1) { + ev.copy(code = s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + } else { + val d = data.genCode(ctx) + ev.copy(code = s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); + }""") + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.value} = $dtu.truncDate($dateVal, $form); + } + """ + }) + } + } + + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala new file mode 100644 index 0000000000000..c7d26b534ace9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.math.{BigDecimal => JBigDecimal} + +/** + * Helper functions for BigDecimal. + */ +object BigDecimalUtils { + + /** + * Returns double type input truncated to scale decimal places. + */ + def trunc(input: Double, scale: Int): JBigDecimal = { + trunc(JBigDecimal.valueOf(input), scale) + } + + /** + * Returns BigDecimal type input truncated to scale decimal places. + */ + def trunc(input: JBigDecimal, scale: Int): JBigDecimal = { + + val pow = if (scale >= 0) { + JBigDecimal.valueOf(Math.pow(10, scale)) + } else { + JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) + } + + val out = if (scale > 0) { + val longValue = input.multiply(pow).longValue() + JBigDecimal.valueOf(longValue).divide(pow) + } else if (scale == 0) { + JBigDecimal.valueOf(input.longValue()) + } else { + val longValue = input.divide(pow).longValue() + JBigDecimal.valueOf(longValue).multiply(pow) + } + out + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 4ce68538c87a1..c409b80f176ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -513,27 +513,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } - test("function trunc") { - def testTrunc(input: Date, fmt: String, expected: Date): Unit = { - checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), - expected) - checkEvaluation( - TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), - expected) - } - val date = Date.valueOf("2015-07-22") - Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => - testTrunc(date, fmt, Date.valueOf("2015-01-01")) - } - Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => - testTrunc(date, fmt, Date.valueOf("2015-07-01")) - } - testTrunc(date, "DD", null) - testTrunc(date, null, null) - testTrunc(null, "MON", null) - testTrunc(null, null, null) - } - test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a26d070a99c52..790dfc5d4e986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Date + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -39,4 +41,50 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) } + test("trunc") { + + // numeric + def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = { + checkEvaluation(Trunc(Literal.create(input, DoubleType), + Literal.create(fmt, IntegerType)), + expected) + checkEvaluation(Trunc(Literal.create(input, DoubleType), + NonFoldableLiteral.create(fmt, IntegerType)), + expected) + } + + testTruncNumber(1234567891.1234567891, 4, 1234567891.1234) + testTruncNumber(1234567891.1234567891, -4, 1234560000) + testTruncNumber(1234567891.1234567891, 0, 1234567891) + + checkEvaluation(Trunc(Literal.create(1D, DoubleType), + NonFoldableLiteral.create(null, IntegerType)), + null) + checkEvaluation(Trunc(Literal.create(null, DoubleType), + NonFoldableLiteral.create(1, IntegerType)), + null) + checkEvaluation(Trunc(Literal.create(null, DoubleType), + NonFoldableLiteral.create(null, IntegerType)), + null) + + // date + def testTruncDate(input: Date, fmt: String, expected: Date): Unit = { + checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + val date = Date.valueOf("2015-07-22") + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => + testTruncDate(date, fmt, Date.valueOf("2015-01-01")) + } + Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => + testTruncDate(date, fmt, Date.valueOf("2015-07-01")) + } + testTruncDate(date, "DD", null) + testTruncDate(date, null, null) + testTruncDate(null, "MON", null) + testTruncDate(null, null, null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala new file mode 100644 index 0000000000000..9fb7134f24da4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.BigDecimalUtils._ + +class BigDecimalUtilsSuite extends SparkFunSuite { + + test("trunc number") { + val bg = 1234567891.1234567891D + assert(trunc(bg, 4) === 1234567891.1234) + assert(trunc(bg, -4) === 1234560000) + assert(trunc(bg, 0) === 1234567891) + assert(trunc(12345.5554f, 0) === 1234567891) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 36c0f18b6e2e3..11612535e8c91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2733,7 +2733,7 @@ object functions { * @since 1.5.0 */ def trunc(date: Column, format: String): Column = withExpr { - TruncDate(date.expr, Literal(format)) + Trunc(date.expr, Literal(format)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index e957f693a983f..79575c0e5aca9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -6,3 +6,12 @@ select current_date = current_date(), current_timestamp = current_timestamp(); select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd'); select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); + +-- trunc date +select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), + trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY'); +select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), + trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), + trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM'); +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index f7167472b05c6..261fdc000b398 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -1,4 +1,3 @@ - -- unary minus and plus select -100; select +230; @@ -73,3 +72,8 @@ select floor(0); select floor(1); select floor(1234567890123456); select floor(12345678901234567); + +-- trunc number +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), + trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891); +select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 13e1e48b038ad..4f3faa94b8454 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 +-- Number of queries: 6 -- !query 0 @@ -24,3 +24,31 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 struct -- !query 2 output NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 + + +-- !query 3 +select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), + trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY') +-- !query 3 schema +struct +-- !query 3 output +2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 + + +-- !query 4 +select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), + trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), + trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') +-- !query 4 schema +struct +-- !query 4 output +2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 + + +-- !query 5 +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null) +-- !query 5 schema +struct +-- !query 5 output +NULL NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fe52005aa91da..3475e91c42ef2 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 50 -- !query 0 @@ -396,3 +396,20 @@ select floor(12345678901234567) struct -- !query 47 output 12345678901234567 + + +-- !query 48 +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), + trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891) +-- !query 48 schema +struct +-- !query 48 output +1.2345678911234E9 1.23456E9 1.2345678911234E9 1.234567891E9 1.234567891E9 + + +-- !query 49 +select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null) +-- !query 49 schema +struct +-- !query 49 output +NULL NULL NULL From c63856b666e0992bf464c726f78b02e2521a41b5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 25 May 2017 22:47:08 +0800 Subject: [PATCH 02/13] Fix failed tests --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 6 +++--- .../apache/spark/sql/catalyst/util/BigDecimalUtils.scala | 4 ++-- .../spark/sql/catalyst/util/BigDecimalUtilsSuite.scala | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index bfddfd6307f90..adcc9f475feb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -168,7 +168,7 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) } else { format.eval().asInstanceOf[Int] } - BigDecimalUtils.trunc(d.asInstanceOf[Double], scale).doubleValue() + BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) case dateType: DateType => val level = if (format.foldable) { truncFormat @@ -199,11 +199,11 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $bdu.trunc(${d.value}, $truncFormat).doubleValue(); + ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); }""") } else { nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { - s"${ev.value} = $bdu.trunc($doubleVal, $fmt).doubleValue();" + s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" }) } case dateType: DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala index c7d26b534ace9..50de5ebd9e16a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala @@ -27,8 +27,8 @@ object BigDecimalUtils { /** * Returns double type input truncated to scale decimal places. */ - def trunc(input: Double, scale: Int): JBigDecimal = { - trunc(JBigDecimal.valueOf(input), scale) + def trunc(input: Double, scale: Int): Double = { + trunc(JBigDecimal.valueOf(input), scale).doubleValue() } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala index 9fb7134f24da4..87b66af34c0dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala @@ -27,7 +27,6 @@ class BigDecimalUtilsSuite extends SparkFunSuite { assert(trunc(bg, 4) === 1234567891.1234) assert(trunc(bg, -4) === 1234560000) assert(trunc(bg, 0) === 1234567891) - assert(trunc(12345.5554f, 0) === 1234567891) } } From 224c8672a666a1cdded844f7d0a160c4ce77f6a4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 26 May 2017 08:09:01 +0800 Subject: [PATCH 03/13] Add test for python api --- python/pyspark/sql/functions.py | 13 ++++++++++--- .../main/scala/org/apache/spark/sql/functions.scala | 12 ++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d9b86aff63fa0..33a23b1eaec13 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,17 +1028,24 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(date, format): +def trunc(date, format=0): """ Returns date truncated to the unit specified by the format. :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) - >>> df.select(trunc(df.d, 'year').alias('year')).collect() + >>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] - >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + >>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] + >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d']) + >>> df.select(trunc(df.d, 4).alias('positive')).collect() + [Row(positive=1234567891.1234)] + >>> df.select(trunc(df.d, -4).alias('negative')).collect() + [Row(negative=1234560000.0)] + >>> df.select(trunc(df.d).alias('zero')).collect() + [Row(negative=1234567891.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 11612535e8c91..91db7ba27a918 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2067,6 +2067,18 @@ object functions { */ def radians(columnName: String): Column = radians(Column(columnName)) + /** + * returns number truncated by specified decimal places. + * + * @param scale: 4. -4, 0 + * + * @group math_funcs + * @since 2.3.0 + */ + def trunc(db: Column, scale: Int = 0): Column = withExpr { + Trunc(db.expr, Literal(scale)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions ////////////////////////////////////////////////////////////////////////////////////////////// From e7e6e5b436be61b174253ea6c520a8fe2ad87383 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 26 May 2017 08:12:28 +0800 Subject: [PATCH 04/13] Update comment. --- python/pyspark/sql/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 33a23b1eaec13..15f78b4e46095 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1030,7 +1030,8 @@ def to_timestamp(col, format=None): @since(1.5) def trunc(date, format=0): """ - Returns date truncated to the unit specified by the format. + Returns date truncated to the unit specified by the format or + number truncated by specified decimal places.. :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' From 7157820ae48528059e8008535a37ae18b6df5530 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 26 May 2017 10:37:48 +0800 Subject: [PATCH 05/13] Fix PySpark unit test issue --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 15f78b4e46095..58daa6204e979 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1046,7 +1046,7 @@ def trunc(date, format=0): >>> df.select(trunc(df.d, -4).alias('negative')).collect() [Row(negative=1234560000.0)] >>> df.select(trunc(df.d).alias('zero')).collect() - [Row(negative=1234567891.0)] + [Row(zero=1234567891.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) From 3d92a48a54c5bf0222c032fc5c205ea1792630a2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 14 Jun 2017 23:17:38 +0800 Subject: [PATCH 06/13] Refactor code. --- python/pyspark/sql/functions.py | 8 +- .../spark/sql/catalyst/expressions/misc.scala | 132 ++++++++++-------- .../expressions/MiscExpressionsSuite.scala | 1 - .../resources/sql-tests/inputs/datetime.sql | 10 +- .../sql-tests/results/datetime.sql.out | 32 ++++- .../sql-tests/results/operators.sql.out | 19 ++- 6 files changed, 128 insertions(+), 74 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 58daa6204e979..7bb46f44848da 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,17 +1028,17 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(date, format=0): +def trunc(date, format): """ Returns date truncated to the unit specified by the format or - number truncated by specified decimal places.. + number truncated by specified decimal places. :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) - >>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect() + >>> df.select(trunc(df.d, 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] - >>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect() + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d']) >>> df.select(trunc(df.d, 4).alias('positive')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 88152bc2bb5b5..ae4ffec971473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -141,45 +141,57 @@ case class Uuid() extends LeafExpression { // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(data, fmt) - Returns `data` truncated by the format model `fmt`. - If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. - If `data` is DoubleType, returns `data` truncated to `fmt` decimal places. + _FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`. + If `data` is DateType/StringType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. + If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places. """, extended = """ Examples: > SELECT _FUNC_('2009-02-12', 'MM'); - 2009-02-01 + 2009-02-01. > SELECT _FUNC_('2015-10-27', 'YEAR'); 2015-01-01 + > SELECT _FUNC_('2015-10-27'); + 2015-10-01 > SELECT _FUNC_(1234567891.1234567891, 4); 1234567891.1234 > SELECT _FUNC_(1234567891.1234567891, -4); 1234560000 - """) + > SELECT _FUNC_(1234567891.1234567891); + 1234567891 + """) // scalastyle:on line.size.limit -case class Trunc(data: Expression, format: Expression = Literal(0)) +case class Trunc(data: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { - def this(numeric: Expression) = { - this(numeric, Literal(0)) + def this(data: Expression) = { + this(data, Literal( + if (data.dataType.isInstanceOf[DecimalType] || data.dataType.isInstanceOf[DoubleType]) { + 0 + } else { + "MM" + })) } override def left: Expression = data override def right: Expression = format + val isTruncNumber = format.dataType.isInstanceOf[IntegerType] + override def dataType: DataType = data.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DateType), TypeCollection(StringType, IntegerType)) + Seq(TypeCollection(DateType, StringType, DoubleType, DecimalType), + TypeCollection(StringType, IntegerType)) override def nullable: Boolean = true + override def prettyName: String = "trunc" - private lazy val truncFormat: Int = dataType match { - case doubleType: DoubleType => - format.eval().asInstanceOf[Int] - case dateType: DateType => - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + private lazy val truncFormat: Int = if (isTruncNumber) { + format.eval().asInstanceOf[Int] + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } override def eval(input: InternalRow): Any = { @@ -188,73 +200,70 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) if (null == d || null == form) { null } else { - dataType match { - case doubleType: DoubleType => - val scale = if (format.foldable) { - truncFormat - } else { - format.eval().asInstanceOf[Int] - } - BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) - case dateType: DateType => - val level = if (format.foldable) { - truncFormat - } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - } - if (level == -1) { - // unknown format - null - } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) - } + if (isTruncNumber) { + val scale = if (format.foldable) truncFormat else format.eval().asInstanceOf[Int] + data.dataType match { + case DoubleType => BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) + case DecimalType.Fixed(_, _) => + BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale) + } + } else { + val level = if (format.foldable) { + truncFormat + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (level == -1) { + // unknown format + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + } } } - } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - dataType match { - case doubleType: DoubleType => - val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") + if (isTruncNumber) { + val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { - val d = data.genCode(ctx) - ev.copy(code = s""" + if (format.foldable) { + val d = data.genCode(ctx) + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); }""") - } else { - nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { - s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" - }) - } - case dateType: DateType => - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + } else { + nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { + s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" + }) + } + } else { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { - if (truncFormat == -1) { - ev.copy(code = s""" + if (format.foldable) { + if (truncFormat == -1) { + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") - } else { - val d = data.genCode(ctx) - ev.copy(code = s""" + } else { + val d = data.genCode(ctx) + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); }""") - } - } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") - s""" + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" int $form = $dtu.parseTruncLevel($fmt); if ($form == -1) { ${ev.isNull} = true; @@ -262,9 +271,8 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) ${ev.value} = $dtu.truncDate($dateVal, $form); } """ - }) - } + }) + } } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 0797f8cbc8046..dcf58526b757a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -47,7 +47,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("trunc") { - // numeric def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = { checkEvaluation(Trunc(Literal.create(input, DoubleType), diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 1bc072720c0d9..f9f8351f08ded 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -10,10 +10,10 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); -- trunc date -select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), - trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), - trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY'); -select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), - trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), +select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), + trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY'); +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM'); select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index a28b91c77324b..9f0e3176ac6ae 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 7 -- !query 0 @@ -32,3 +32,33 @@ select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27') struct -- !query 3 output 7 5 7 NULL 6 + + +-- !query 4 +select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), + trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY') +-- !query 4 schema +struct<> +-- !query 4 output +java.lang.ClassCastException +org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer + + +-- !query 5 +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), + trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') +-- !query 5 schema +struct<> +-- !query 5 output +java.lang.ClassCastException +org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer + + +-- !query 6 +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null) +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 51ccf764d952f..83dededf83764 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 51 +-- Number of queries: 53 -- !query 0 @@ -420,3 +420,20 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> -- !query 50 output 1 NULL 0 NULL NULL NULL + + +-- !query 51 +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), + trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891) +-- !query 51 schema +struct +-- !query 51 output +1234567891.1234 1234560000 1234567891.1234 1234567891 1234567891 + + +-- !query 52 +select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null) +-- !query 52 schema +struct +-- !query 52 output +NULL NULL NULL From b391b6a3e51229e501982fe184685d7c2e185172 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Jun 2017 12:48:25 +0800 Subject: [PATCH 07/13] Add comment. --- python/pyspark/sql/functions.py | 6 +++--- .../spark/sql/catalyst/expressions/misc.scala | 15 +++++---------- .../spark/sql/catalyst/util/BigDecimalUtils.scala | 3 ++- .../test/resources/sql-tests/inputs/operators.sql | 3 +-- .../resources/sql-tests/results/datetime.sql.out | 12 +++++------- .../resources/sql-tests/results/operators.sql.out | 7 +++---- 6 files changed, 19 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7bb46f44848da..18b2ce0c5dbe7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,7 +1028,7 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(date, format): +def trunc(data, format): """ Returns date truncated to the unit specified by the format or number truncated by specified decimal places. @@ -1045,11 +1045,11 @@ def trunc(date, format): [Row(positive=1234567891.1234)] >>> df.select(trunc(df.d, -4).alias('negative')).collect() [Row(negative=1234560000.0)] - >>> df.select(trunc(df.d).alias('zero')).collect() + >>> df.select(trunc(df.d, 0).alias('zero')).collect() [Row(zero=1234567891.0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + return Column(sc._jvm.functions.trunc(_to_java_column(data), format)) @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ae4ffec971473..82babc9be0798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -142,7 +142,7 @@ case class Uuid() extends LeafExpression { @ExpressionDescription( usage = """ _FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`. - If `data` is DateType/StringType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. + If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places. """, extended = """ @@ -151,8 +151,8 @@ case class Uuid() extends LeafExpression { 2009-02-01. > SELECT _FUNC_('2015-10-27', 'YEAR'); 2015-01-01 - > SELECT _FUNC_('2015-10-27'); - 2015-10-01 + > SELECT _FUNC_('1989-03-13'); + 1989-03-01 > SELECT _FUNC_(1234567891.1234567891, 4); 1234567891.1234 > SELECT _FUNC_(1234567891.1234567891, -4); @@ -165,12 +165,7 @@ case class Trunc(data: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { def this(data: Expression) = { - this(data, Literal( - if (data.dataType.isInstanceOf[DecimalType] || data.dataType.isInstanceOf[DoubleType]) { - 0 - } else { - "MM" - })) + this(data, Literal(if (data.dataType.isInstanceOf[DateType]) "MM" else 0)) } override def left: Expression = data @@ -181,7 +176,7 @@ case class Trunc(data: Expression, format: Expression) override def dataType: DataType = data.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DateType, StringType, DoubleType, DecimalType), + Seq(TypeCollection(DateType, DoubleType, DecimalType), TypeCollection(StringType, IntegerType)) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala index 50de5ebd9e16a..931ed20ca75b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala @@ -35,7 +35,8 @@ object BigDecimalUtils { * Returns BigDecimal type input truncated to scale decimal places. */ def trunc(input: JBigDecimal, scale: Int): JBigDecimal = { - + // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 + // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) val pow = if (scale >= 0) { JBigDecimal.valueOf(Math.pow(10, scale)) } else { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index d1c4fea496b99..b7509c01cbe2b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -81,6 +81,5 @@ select 1 > 0.00001; select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); -- trunc number -select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), - trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891); +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891); select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 9f0e3176ac6ae..378f4f3fcf552 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -39,10 +39,9 @@ select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY') -- !query 4 schema -struct<> +struct -- !query 4 output -java.lang.ClassCastException -org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer +2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 -- !query 5 @@ -50,15 +49,14 @@ select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') -- !query 5 schema -struct<> +struct -- !query 5 output -java.lang.ClassCastException -org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer +2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 -- !query 6 select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null) -- !query 6 schema -struct +struct -- !query 6 output NULL NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 83dededf83764..fb36e6c6a89d9 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -423,12 +423,11 @@ struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NUL -- !query 51 -select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), - trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891) +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891) -- !query 51 schema -struct +struct -- !query 51 output -1234567891.1234 1234560000 1234567891.1234 1234567891 1234567891 +1234567891.1234 1234560000 1234567891 1234567891 -- !query 52 From d40a46fbbd3c69f590ca5f0d604da84d218fef2e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 22 Jun 2017 18:00:46 +0800 Subject: [PATCH 08/13] Redefine inputTypes --- python/pyspark/sql/functions.py | 4 +- .../spark/sql/catalyst/expressions/misc.scala | 77 +++++++++++-------- .../resources/sql-tests/inputs/datetime.sql | 13 ++-- .../resources/sql-tests/inputs/operators.sql | 5 +- .../sql-tests/results/datetime.sql.out | 30 +++++--- .../sql-tests/results/operators.sql.out | 31 +++++++- 6 files changed, 106 insertions(+), 54 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 975ba6c721918..0e39ea6a360a9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1036,9 +1036,9 @@ def trunc(data, format): :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) - >>> df.select(trunc(df.d, 'year').alias('year')).collect() + >>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] - >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + >>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d']) >>> df.select(trunc(df.d, 4).alias('positive')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 82babc9be0798..4aa7134c56d3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -162,7 +162,7 @@ case class Uuid() extends LeafExpression { """) // scalastyle:on line.size.limit case class Trunc(data: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ExpectsInputTypes { def this(data: Expression) = { this(data, Literal(if (data.dataType.isInstanceOf[DateType]) "MM" else 0)) @@ -171,22 +171,32 @@ case class Trunc(data: Expression, format: Expression) override def left: Expression = data override def right: Expression = format - val isTruncNumber = format.dataType.isInstanceOf[IntegerType] - override def dataType: DataType = data.dataType - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DateType, DoubleType, DecimalType), + override def inputTypes: Seq[AbstractDataType] = dataType match { + case NullType => Seq(dataType, TypeCollection(StringType, IntegerType)) + case DateType => Seq(dataType, StringType) + case DoubleType | DecimalType.Fixed(_, _) => Seq(dataType, IntegerType) + case _ => Seq(TypeCollection(DateType, DoubleType, DecimalType), TypeCollection(StringType, IntegerType)) + } override def nullable: Boolean = true override def prettyName: String = "trunc" + private val isTruncNumber = + (dataType.isInstanceOf[DoubleType] || dataType.isInstanceOf[DecimalType]) && + format.dataType.isInstanceOf[IntegerType] + private val isTruncDate = + dataType.isInstanceOf[DateType] && format.dataType.isInstanceOf[StringType] + private lazy val truncFormat: Int = if (isTruncNumber) { format.eval().asInstanceOf[Int] - } else { + } else if (isTruncDate) { DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } else { + 0 } override def eval(input: InternalRow): Any = { @@ -202,7 +212,7 @@ case class Trunc(data: Expression, format: Expression) case DecimalType.Fixed(_, _) => BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale) } - } else { + } else if (isTruncDate) { val level = if (format.foldable) { truncFormat } else { @@ -214,6 +224,8 @@ case class Trunc(data: Expression, format: Expression) } else { DateTimeUtils.truncDate(d.asInstanceOf[Int], level) } + } else { + null } } } @@ -226,48 +238,49 @@ case class Trunc(data: Expression, format: Expression) if (format.foldable) { val d = data.genCode(ctx) ev.copy(code = s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); - }""") + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); + }""") } else { - nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { - s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" - }) + nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => s"${ev.value} = $bdu.trunc($doubleVal, $fmt);") } - } else { + } else if (isTruncDate) { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { if (truncFormat == -1) { ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """) } else { val d = data.genCode(ctx) ev.copy(code = s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); - }""") + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); + }""") } } else { nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { val form = ctx.freshName("form") s""" - int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { - ${ev.isNull} = true; - } else { - ${ev.value} = $dtu.truncDate($dateVal, $form); - } - """ + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.value} = $dtu.truncDate($dateVal, $form); + } + """ }) } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => s"${ev.isNull} = true;") } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index f9f8351f08ded..99cec3af3b39b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -10,10 +10,11 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); -- trunc date -select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), - trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), - trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY'); -select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), - trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), +select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), + trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY'); +select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), + trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM'); -select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null); +select trunc(to_date('2015-07-22'), 'DD'), trunc(to_date('2015-07-22'), null); +select trunc(null, 'MON'), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 58266f46108d3..b78693c77605d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -92,6 +92,9 @@ select abs(-3.13), abs('-2.19'); -- positive/negative select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); --- trunc number +-- trunc select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891); select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null); +select trunc(1234567891.1234567891, 'yyyy'); +select trunc(to_date('2015-07-22'), 4); +select trunc('2015-07-22', 4); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 378f4f3fcf552..0a609c75ce537 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 8 -- !query 0 @@ -35,28 +35,36 @@ struct +struct -- !query 4 output 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 -- !query 5 -select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), - trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), +select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), + trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') -- !query 5 schema -struct +struct -- !query 5 output 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 -- !query 6 -select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null) +select trunc(to_date('2015-07-22'), 'DD'), trunc(to_date('2015-07-22'), null) -- !query 6 schema -struct +struct -- !query 6 output -NULL NULL NULL NULL +NULL NULL + + +-- !query 7 +select trunc(null, 'MON'), trunc(null, null) +-- !query 7 schema +struct +-- !query 7 output +NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fbb63ea247beb..def1772058efd 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 59 +-- Number of queries: 62 -- !query 0 @@ -481,6 +481,33 @@ struct +struct -- !query 58 output NULL NULL NULL + + +-- !query 59 +select trunc(1234567891.1234567891, 'yyyy') +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve 'trunc(1234567891.1234567891BD, 'yyyy')' due to data type mismatch: argument 2 requires int type, however, ''yyyy'' is of string type.; line 1 pos 7 + + +-- !query 60 +select trunc(to_date('2015-07-22'), 4) +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve 'trunc(to_date('2015-07-22'), 4)' due to data type mismatch: argument 2 requires string type, however, '4' is of int type.; line 1 pos 7 + + +-- !query 61 +select trunc('2015-07-22', 4) +-- !query 61 schema +struct<> +-- !query 61 output +org.apache.spark.sql.AnalysisException +cannot resolve 'trunc('2015-07-22', 4)' due to data type mismatch: argument 1 requires (date or double or decimal) type, however, ''2015-07-22'' is of string type.; line 1 pos 7 From 88d1e38c8fe55c52aaa64754b66c56472141a682 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 23 Jun 2017 10:22:01 +0800 Subject: [PATCH 09/13] Fix error tests --- .../test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 3a8694839bb24..d22b107a9fd84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -427,11 +427,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") checkAnswer( - df.select(trunc(col("t"), "YY")), + df.select(trunc(to_date(col("t")), "YY")), Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) checkAnswer( - df.selectExpr("trunc(t, 'Month')"), + df.selectExpr("trunc(to_date(t), 'Month')"), Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) } From 7fee61b1e084a1ae9966e7ad62b1509085b24151 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 23 Jun 2017 13:41:07 +0800 Subject: [PATCH 10/13] Fix R error tests --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 911b73b9ee551..c4bdd38c77d3b 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1382,8 +1382,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) - c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + - trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") + c23 <- trunc(to_date(c), "year") + trunc(to_date(c), "yyyy") + trunc(to_date(c), "yy") + + trunc(to_date(c), "month") + trunc(to_date(c), "mon") + trunc(to_date(c), "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From f8b1f4426836e01e755e73253cb3011d4d6d1bee Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 27 Jun 2017 18:54:24 +0800 Subject: [PATCH 11/13] Support timestamp and string type. --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 +- python/pyspark/sql/functions.py | 15 +-- .../spark/sql/catalyst/expressions/misc.scala | 100 +++++++++++------- ...{BigDecimalUtils.scala => MathUtils.scala} | 2 +- .../expressions/MiscExpressionsSuite.scala | 30 +++--- ...lUtilsSuite.scala => MathUtilsSuite.scala} | 4 +- .../org/apache/spark/sql/functions.scala | 12 ++- .../resources/sql-tests/inputs/datetime.sql | 11 +- .../resources/sql-tests/inputs/operators.sql | 1 + .../sql-tests/results/datetime.sql.out | 30 ++++-- .../sql-tests/results/operators.sql.out | 15 ++- .../apache/spark/sql/DateFunctionsSuite.scala | 4 +- 12 files changed, 136 insertions(+), 92 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/{BigDecimalUtils.scala => MathUtils.scala} (98%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/{BigDecimalUtilsSuite.scala => MathUtilsSuite.scala} (90%) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index c4bdd38c77d3b..911b73b9ee551 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1382,8 +1382,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) - c23 <- trunc(to_date(c), "year") + trunc(to_date(c), "yyyy") + trunc(to_date(c), "yy") + - trunc(to_date(c), "month") + trunc(to_date(c), "mon") + trunc(to_date(c), "mm") + c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da2cbb74c8937..ebe7d572f2b4c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,17 +1028,18 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(data, format): +def trunc(data, truncParam): """ - Returns date truncated to the unit specified by the format or - number truncated by specified decimal places. + Returns date truncated to the unit specified by the truncParam or + numeric truncated by specified decimal places. - :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + :param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date + and any int for numeric. >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) - >>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect() + >>> df.select(trunc(df.d, 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] - >>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect() + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d']) >>> df.select(trunc(df.d, 4).alias('positive')).collect() @@ -1049,7 +1050,7 @@ def trunc(data, format): [Row(zero=1234567891.0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.trunc(_to_java_column(data), format)) + return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam)) @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4aa7134c56d3a..17311a837c9db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,7 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{BigDecimalUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, MathUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -141,9 +141,9 @@ case class Uuid() extends LeafExpression { // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`. - If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. - If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places. + _FUNC_(data[, trunc_param]) - Returns `data` truncated by the format model `trunc_param`. + If `data` is date/timestamp/string type, returns `data` with the time portion of the day truncated to the unit specified by the format model `trunc_param`. If `trunc_param` is omitted, then the default `trunc_param` is 'MM'. + If `data` is decimal/double type, returns `data` truncated to `trunc_param` decimal places. If `trunc_param` is omitted, then the default `trunc_param` is 0. """, extended = """ Examples: @@ -161,68 +161,87 @@ case class Uuid() extends LeafExpression { 1234567891 """) // scalastyle:on line.size.limit -case class Trunc(data: Expression, format: Expression) +case class Trunc(data: Expression, truncExpr: Expression) extends BinaryExpression with ExpectsInputTypes { def this(data: Expression) = { - this(data, Literal(if (data.dataType.isInstanceOf[DateType]) "MM" else 0)) + this(data, Literal( + if (data.dataType.isInstanceOf[DateType] || + data.dataType.isInstanceOf[TimestampType] || + data.dataType.isInstanceOf[StringType]) { + "MM" + } else { + 0 + }) + ) } override def left: Expression = data - override def right: Expression = format - - override def dataType: DataType = data.dataType - - override def inputTypes: Seq[AbstractDataType] = dataType match { - case NullType => Seq(dataType, TypeCollection(StringType, IntegerType)) - case DateType => Seq(dataType, StringType) - case DoubleType | DecimalType.Fixed(_, _) => Seq(dataType, IntegerType) - case _ => Seq(TypeCollection(DateType, DoubleType, DecimalType), - TypeCollection(StringType, IntegerType)) + override def right: Expression = truncExpr + + private val isTruncNumber = truncExpr.dataType.isInstanceOf[IntegerType] + private val isTruncDate = truncExpr.dataType.isInstanceOf[StringType] + + override def dataType: DataType = if (isTruncDate) DateType else data.dataType + + override def inputTypes: Seq[AbstractDataType] = data.dataType match { + case NullType => + Seq(dataType, TypeCollection(StringType, IntegerType)) + case DateType | TimestampType | StringType => + Seq(TypeCollection(DateType, TimestampType, StringType), StringType) + case DoubleType | DecimalType.Fixed(_, _) => + Seq(TypeCollection(DoubleType, DecimalType), IntegerType) + case _ => + Seq(TypeCollection(DateType, StringType, TimestampType, DoubleType, DecimalType), + TypeCollection(StringType, IntegerType)) } override def nullable: Boolean = true override def prettyName: String = "trunc" - private val isTruncNumber = - (dataType.isInstanceOf[DoubleType] || dataType.isInstanceOf[DecimalType]) && - format.dataType.isInstanceOf[IntegerType] - private val isTruncDate = - dataType.isInstanceOf[DateType] && format.dataType.isInstanceOf[StringType] private lazy val truncFormat: Int = if (isTruncNumber) { - format.eval().asInstanceOf[Int] + truncExpr.eval().asInstanceOf[Int] } else if (isTruncDate) { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + DateTimeUtils.parseTruncLevel(truncExpr.eval().asInstanceOf[UTF8String]) } else { 0 } override def eval(input: InternalRow): Any = { val d = data.eval(input) - val form = format.eval() - if (null == d || null == form) { + val truncParam = truncExpr.eval() + if (null == d || null == truncParam) { null } else { if (isTruncNumber) { - val scale = if (format.foldable) truncFormat else format.eval().asInstanceOf[Int] + val scale = if (truncExpr.foldable) truncFormat else truncExpr.eval().asInstanceOf[Int] data.dataType match { - case DoubleType => BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) + case DoubleType => MathUtils.trunc(d.asInstanceOf[Double], scale) case DecimalType.Fixed(_, _) => - BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale) + MathUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale) } } else if (isTruncDate) { - val level = if (format.foldable) { + val level = if (truncExpr.foldable) { truncFormat } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + DateTimeUtils.parseTruncLevel(truncExpr.eval().asInstanceOf[UTF8String]) } if (level == -1) { // unknown format null } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + data.dataType match { + case DateType => DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + case TimestampType => + val ts = DateTimeUtils.timestampToString(d.asInstanceOf[Long]) + val dt = DateTimeUtils.stringToDate(UTF8String.fromString(ts)) + if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null + case StringType => + val dt = DateTimeUtils.stringToDate(d.asInstanceOf[UTF8String]) + if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null + } } } else { null @@ -233,9 +252,9 @@ case class Trunc(data: Expression, format: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (isTruncNumber) { - val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") + val bdu = MathUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { + if (truncExpr.foldable) { val d = data.genCode(ctx) ev.copy(code = s""" ${d.code} @@ -245,12 +264,13 @@ case class Trunc(data: Expression, format: Expression) ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); }""") } else { - nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => s"${ev.value} = $bdu.trunc($doubleVal, $fmt);") + nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) => + s"${ev.value} = $bdu.trunc($doubleVal, $truncParam);") } } else if (isTruncDate) { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { + if (truncExpr.foldable) { if (truncFormat == -1) { ev.copy(code = s""" boolean ${ev.isNull} = true; @@ -268,19 +288,19 @@ case class Trunc(data: Expression, format: Expression) } } else { nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") + val truncParam = ctx.freshName("truncParam") s""" - int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { + int $truncParam = $dtu.parseTruncLevel($fmt); + if ($truncParam == -1) { ${ev.isNull} = true; } else { - ${ev.value} = $dtu.truncDate($dateVal, $form); + ${ev.value} = $dtu.truncDate($dateVal, $truncParam); } """ }) } } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => s"${ev.isNull} = true;") + nullSafeCodeGen(ctx, ev, (dataVal, fmt) => s"${ev.isNull} = true;") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 931ed20ca75b6..cc826545fbd41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JBigDecimal} /** * Helper functions for BigDecimal. */ -object BigDecimalUtils { +object MathUtils { /** * Returns double type input truncated to scale decimal places. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index dcf58526b757a..6af21ca05cd6c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -46,9 +46,8 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(evaluate(Uuid()) !== evaluate(Uuid())) } - test("trunc") { - // numeric - def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = { + test("trunc numeric") { + def test(input: Double, fmt: Int, expected: Double): Unit = { checkEvaluation(Trunc(Literal.create(input, DoubleType), Literal.create(fmt, IntegerType)), expected) @@ -57,9 +56,11 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expected) } - testTruncNumber(1234567891.1234567891, 4, 1234567891.1234) - testTruncNumber(1234567891.1234567891, -4, 1234560000) - testTruncNumber(1234567891.1234567891, 0, 1234567891) + test(1234567891.1234567891, 4, 1234567891.1234) + test(1234567891.1234567891, -4, 1234560000) + test(1234567891.1234567891, 0, 1234567891) + test(0.123, -1, 0) + test(0.123, 0, 0) checkEvaluation(Trunc(Literal.create(1D, DoubleType), NonFoldableLiteral.create(null, IntegerType)), @@ -70,9 +71,10 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Trunc(Literal.create(null, DoubleType), NonFoldableLiteral.create(null, IntegerType)), null) + } - // date - def testTruncDate(input: Date, fmt: String, expected: Date): Unit = { + test("trunc date") { + def test(input: Date, fmt: String, expected: Date): Unit = { checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)), expected) checkEvaluation( @@ -81,14 +83,14 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } val date = Date.valueOf("2015-07-22") Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => - testTruncDate(date, fmt, Date.valueOf("2015-01-01")) + test(date, fmt, Date.valueOf("2015-01-01")) } Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => - testTruncDate(date, fmt, Date.valueOf("2015-07-01")) + test(date, fmt, Date.valueOf("2015-07-01")) } - testTruncDate(date, "DD", null) - testTruncDate(date, null, null) - testTruncDate(null, "MON", null) - testTruncDate(null, null, null) + test(date, "DD", null) + test(date, null, null) + test(null, "MON", null) + test(null, null, null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala similarity index 90% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala index 87b66af34c0dc..a3afe26bb408b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/BigDecimalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.BigDecimalUtils._ +import org.apache.spark.sql.catalyst.util.MathUtils._ -class BigDecimalUtilsSuite extends SparkFunSuite { +class MathUtilsSuite extends SparkFunSuite { test("trunc number") { val bg = 1234567891.1234567891D diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 49d4901c2e802..ee6b0abc02f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2068,15 +2068,17 @@ object functions { def radians(columnName: String): Column = radians(Column(columnName)) /** - * returns number truncated by specified decimal places. - * - * @param scale: 4. -4, 0 + * Returns numeric truncated by specified decimal places. + * If scale is positive or 0, numeric is truncated to the absolute value of scale number + * of places to the right of the decimal point. + * If scale is negative, numeric is truncated to the absolute value of scale + 1 number + * of places to the left of the decimal point. * * @group math_funcs * @since 2.3.0 */ - def trunc(db: Column, scale: Int = 0): Column = withExpr { - Trunc(db.expr, Literal(scale)) + def trunc(numeric: Column, scale: Int): Column = withExpr { + Trunc(numeric.expr, Literal(scale)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 99cec3af3b39b..740e98d9e6ecd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -10,11 +10,12 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); -- trunc date -select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), - trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), +select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), + trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY'); -select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), - trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM'); -select trunc(to_date('2015-07-22'), 'DD'), trunc(to_date('2015-07-22'), null); +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null); +select trunc('2015-07-2200', 'DD'), trunc('123', null); select trunc(null, 'MON'), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index b78693c77605d..f40450284ca45 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -98,3 +98,4 @@ select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null); select trunc(1234567891.1234567891, 'yyyy'); select trunc(to_date('2015-07-22'), 4); select trunc('2015-07-22', 4); +select trunc(false, 4); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 0a609c75ce537..21852396b424a 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 @@ -35,36 +35,44 @@ struct +struct -- !query 4 output 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 2015-01-01 -- !query 5 -select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), - trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') -- !query 5 schema -struct +struct -- !query 5 output 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 2015-07-01 -- !query 6 -select trunc(to_date('2015-07-22'), 'DD'), trunc(to_date('2015-07-22'), null) +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null) -- !query 6 schema -struct +struct -- !query 6 output NULL NULL -- !query 7 -select trunc(null, 'MON'), trunc(null, null) +select trunc('2015-07-2200', 'DD'), trunc('123', null) -- !query 7 schema -struct +struct -- !query 7 output NULL NULL + + +-- !query 8 +select trunc(null, 'MON'), trunc(null, null) +-- !query 8 schema +struct +-- !query 8 output +NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index def1772058efd..8ec0bf8566af7 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 62 +-- Number of queries: 63 -- !query 0 @@ -481,7 +481,7 @@ struct +struct -- !query 58 output NULL NULL NULL @@ -510,4 +510,13 @@ select trunc('2015-07-22', 4) struct<> -- !query 61 output org.apache.spark.sql.AnalysisException -cannot resolve 'trunc('2015-07-22', 4)' due to data type mismatch: argument 1 requires (date or double or decimal) type, however, ''2015-07-22'' is of string type.; line 1 pos 7 +cannot resolve 'trunc('2015-07-22', 4)' due to data type mismatch: argument 2 requires string type, however, '4' is of int type.; line 1 pos 7 + + +-- !query 62 +select trunc(false, 4) +-- !query 62 schema +struct<> +-- !query 62 output +org.apache.spark.sql.AnalysisException +cannot resolve 'trunc(false, 4)' due to data type mismatch: argument 1 requires (date or string or timestamp or double or decimal) type, however, 'false' is of boolean type.; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index d22b107a9fd84..3a8694839bb24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -427,11 +427,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") checkAnswer( - df.select(trunc(to_date(col("t")), "YY")), + df.select(trunc(col("t"), "YY")), Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) checkAnswer( - df.selectExpr("trunc(to_date(t), 'Month')"), + df.selectExpr("trunc(t, 'Month')"), Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) } From 3d40c366892303cd0de8259b31aebe7a748d89e6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 2 Aug 2017 14:18:35 +0800 Subject: [PATCH 12/13] codegen support String and Timestamp type. --- .../spark/sql/catalyst/expressions/misc.scala | 70 +++++++++++++++---- .../expressions/MiscExpressionsSuite.scala | 52 +++++++++++--- 2 files changed, 98 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 17311a837c9db..8a6caf4f7fa39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -274,29 +274,69 @@ case class Trunc(data: Expression, truncExpr: Expression) if (truncFormat == -1) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """) + int ${ev.value} = ${ctx.defaultValue(DateType)};""") } else { val d = data.genCode(ctx) - ev.copy(code = s""" + val dt = ctx.freshName("dt") + val pre = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); - }""") + int ${ev.value} = ${ctx.defaultValue(DateType)};""" + data.dataType match { + case DateType => + ev.copy(code = pre + s""" + if (!${ev.isNull}) { + ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); + }""") + case TimestampType => + val ts = ctx.freshName("ts") + ev.copy(code = pre + s""" + String $ts = $dtu.timestampToString(${d.value}); + scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts)); + if (!${ev.isNull}) { + ${ev.value} = $dtu.truncDate((Integer)dt.get(), $truncFormat); + }""") + case StringType => + ev.copy(code = pre + s""" + scala.Option $dt = $dtu.stringToDate(${d.value}); + if (!${ev.isNull} && $dt.isDefined()) { + ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncFormat); + }""") + } } } else { nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { val truncParam = ctx.freshName("truncParam") - s""" - int $truncParam = $dtu.parseTruncLevel($fmt); - if ($truncParam == -1) { - ${ev.isNull} = true; - } else { - ${ev.value} = $dtu.truncDate($dateVal, $truncParam); - } - """ + val dt = ctx.freshName("dt") + val pre = s"int $truncParam = $dtu.parseTruncLevel($fmt);" + data.dataType match { + case DateType => + pre + s""" + if ($truncParam == -1) { + ${ev.isNull} = true; + } else { + ${ev.value} = $dtu.truncDate($dateVal, $truncParam); + }""" + case TimestampType => + val ts = ctx.freshName("ts") + pre + s""" + String $ts = $dtu.timestampToString($dateVal); + scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts)); + if ($truncParam == -1 || $dt.isEmpty()) { + ${ev.isNull} = true; + } else { + ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam); + }""" + case StringType => + pre + s""" + scala.Option $dt = $dtu.stringToDate($dateVal); + ${ev.value} = ${ctx.defaultValue(DateType)}; + if ($truncParam == -1 || $dt.isEmpty()) { + ${ev.isNull} = true; + } else { + ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam); + }""" + } }) } } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 6af21ca05cd6c..c65bc72f67fc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Date +import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -74,23 +74,57 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("trunc date") { - def test(input: Date, fmt: String, expected: Date): Unit = { + def testDate(input: Date, fmt: String, expected: Date): Unit = { checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)), expected) checkEvaluation( Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), expected) } - val date = Date.valueOf("2015-07-22") + + def testString(input: String, fmt: String, expected: Date): Unit = { + checkEvaluation(Trunc(Literal.create(input, StringType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + Trunc(Literal.create(input, StringType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + + def testTimestamp(input: Timestamp, fmt: String, expected: Date): Unit = { + checkEvaluation(Trunc(Literal.create(input, TimestampType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + Trunc(Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + + val dateStr = "2015-07-22" + val date = Date.valueOf(dateStr) + val ts = new Timestamp(date.getTime) + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => - test(date, fmt, Date.valueOf("2015-01-01")) + testDate(date, fmt, Date.valueOf("2015-01-01")) + testString(dateStr, fmt, Date.valueOf("2015-01-01")) + testTimestamp(ts, fmt, Date.valueOf("2015-01-01")) } Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => - test(date, fmt, Date.valueOf("2015-07-01")) + testDate(date, fmt, Date.valueOf("2015-07-01")) + testString(dateStr, fmt, Date.valueOf("2015-07-01")) + testTimestamp(ts, fmt, Date.valueOf("2015-07-01")) } - test(date, "DD", null) - test(date, null, null) - test(null, "MON", null) - test(null, null, null) + testDate(date, "DD", null) + testDate(date, null, null) + testDate(null, "MON", null) + testDate(null, null, null) + + testString(dateStr, "DD", null) + testString(dateStr, null, null) + testString(null, "MON", null) + testString(null, null, null) + + testTimestamp(ts, "DD", null) + testTimestamp(ts, null, null) + testTimestamp(null, "MON", null) + testTimestamp(null, null, null) } } From 931f07de787081cdd6822dbf396ec1b8d205f25e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 2 Aug 2017 20:42:48 +0800 Subject: [PATCH 13/13] Revert trunc(date, format). --- python/pyspark/sql/functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ebe7d572f2b4c..51bb4557e8ee2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,12 +1028,12 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(data, truncParam): +def trunc(date, format): """ - Returns date truncated to the unit specified by the truncParam or + Returns date truncated to the unit specified by the format or numeric truncated by specified decimal places. - :param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date and any int for numeric. >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) @@ -1050,7 +1050,7 @@ def trunc(data, truncParam): [Row(zero=1234567891.0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam)) + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) @since(1.5)