From 5213bd60f4be0795e23362f555dcdcf1a1d060cd Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 11 Oct 2016 18:21:18 +0900
Subject: [PATCH 01/12] rand/randn allows null as input seed

---
 .../spark/sql/catalyst/expressions/randomExpressions.scala | 4 +++-
 .../apache/spark/sql/catalyst/expressions/RandomSuite.scala | 6 ++++++
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 5 +++++
 3 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index a331a5557b455..df364ba3e5372 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -21,7 +21,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.sql.types.{DataType, DoubleType, NullType}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -73,6 +73,7 @@ case class Rand(seed: Long) extends RDG {
 
   def this(seed: Expression) = this(seed match {
     case IntegerLiteral(s) => s
+    case Literal(null, NullType) => 0
     case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
   })
 
@@ -106,6 +107,7 @@ case class Randn(seed: Long) extends RDG {
 
   def this(seed: Expression) = this(seed match {
     case IntegerLiteral(s) => s
+    case Literal(null, NullType) => 0
     case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
   })
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index b7a0d44fa7e57..a7c14f13f43b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.NullType
 
 class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("random") {
     checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
     checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)
+
+    checkDoubleEvaluation(
+      new Rand(Literal.create(null, NullType)), 0.8446490682263027 +- 0.001)
+    checkDoubleEvaluation(
+      new Randn(Literal.create(null, NullType)), 1.1164209726833079 +- 0.001)
   }
 
   test("SPARK-9127 codegen with long seed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f5bc8785d5a2c..b9fccb3f4a4df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1728,4 +1728,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
+
+  test("SPARK-17854: rand allows null as input seed") {
+    checkAnswer(testData.selectExpr("rand(NULL)"), testData.selectExpr("rand(0)"))
+    checkAnswer(testData.selectExpr("randn(NULL)"), testData.selectExpr("randn(0)"))
+  }
 }

From 182548b963eaf0a9ae480b6e35131c823498e99b Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 12 Oct 2016 21:23:56 +0900
Subject: [PATCH 02/12] Use ExpectsInputTypes and allow LongType and IntegerType

---
 .../expressions/randomExpressions.scala | 48 +++++++++----------
 .../catalyst/expressions/RandomSuite.scala | 6 +--
 .../org/apache/spark/sql/DataFrameSuite.scala | 5 +-
 3 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index df364ba3e5372..0d802cd22ee98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, DoubleType, NullType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
  *
  * Since this expression is stateful, it cannot be a case object.
  */
-abstract class RDG extends LeafExpression with Nondeterministic {
-
-  protected def seed: Long
-
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
    * reset every time we serialize and deserialize and initialize it.
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
     rng = new XORShiftRandom(seed + partitionIndex)
   }
 
+  @transient protected lazy val seed: Long = child match {
+    case Literal(s, IntegerType) => s.asInstanceOf[Int]
+    case Literal(s, LongType) => s.asInstanceOf[Long]
+    case _ => throw new AnalysisException(
+      s"Input argument to $prettyName must be an integer/long literal.")
+  }
+
   override def nullable: Boolean = false
 
   override def dataType: DataType = DoubleType
 
-  // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
-  override def sql: String = s"$prettyName($seed)"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
 }
 
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
@@ -66,17 +68,9 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
        0.8446490682263027
   """)
 // scalastyle:on line.size.limit
-case class Rand(seed: Long) extends RDG {
+case class Rand(child: Expression) extends RDG {
   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
-  def this() = this(Utils.random.nextLong())
-
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case Literal(null, NullType) => 0
-    case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
-  })
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
@@ -88,6 +82,11 @@ case class Rand(seed: Long) extends RDG {
   }
 }
 
+object Rand {
+  def apply(seed: Long): Rand = Rand(Literal(seed))
+  def apply(): Rand = Rand(Literal(Utils.random.nextLong()))
+}
+
 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -100,17 +99,9 @@ object Rand {
        1.1164209726833079
   """)
 // scalastyle:on line.size.limit
-case class Randn(seed: Long) extends RDG {
+case class Randn(child: Expression) extends RDG {
   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
-  def this() = this(Utils.random.nextLong())
-
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case Literal(null, NullType) => 0
-    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
-  })
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
@@ -121,3 +112,8 @@ case class Randn(seed: Long) extends RDG {
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }
 }
+
+object Randn {
+  def apply(seed: Long): Randn = Randn(Literal(seed))
+  def apply(): Randn = Randn(Literal(Utils.random.nextLong()))
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index a7c14f13f43b0..752c9d5449ee2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.NullType
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -29,9 +29,9 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)
 
     checkDoubleEvaluation(
-      new Rand(Literal.create(null, NullType)), 0.8446490682263027 +- 0.001)
+      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
     checkDoubleEvaluation(
-      new Randn(Literal.create(null, NullType)), 1.1164209726833079 +- 0.001)
+      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
   }
 
   test("SPARK-9127 codegen with long seed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b9fccb3f4a4df..695f2b26763af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1635,6 +1635,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+<<<<<<< 5213bd60f4be0795e23362f555dcdcf1a1d060cd
   private def verifyNullabilityInFilterExec(
       df: DataFrame,
       expr: String,
@@ -1729,8 +1730,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
 
-  test("SPARK-17854: rand allows null as input seed") {
+  test("SPARK-17854: rand/randn allows null and long as input seed") {
     checkAnswer(testData.selectExpr("rand(NULL)"), testData.selectExpr("rand(0)"))
+    checkAnswer(testData.selectExpr("rand(0L)"), testData.selectExpr("rand(0)"))
     checkAnswer(testData.selectExpr("randn(NULL)"), testData.selectExpr("randn(0)"))
+    checkAnswer(testData.selectExpr("randn(0L)"), testData.selectExpr("randn(0)"))
   }
 }

From 245346032a721c2c8b46e54e83c3620c7196b46b Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 12 Oct 2016 23:31:11 +0900
Subject: [PATCH 03/12] Override constructor

---
 .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 0d802cd22ee98..235755408cc87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -69,6 +69,9 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
   """)
 // scalastyle:on line.size.limit
 case class Rand(child: Expression) extends RDG {
+
+  def this() = this(Literal(Utils.random.nextLong()))
+
   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -84,7 +87,6 @@ case class Rand(child: Expression) extends RDG {
 
 object Rand {
   def apply(seed: Long): Rand = Rand(Literal(seed))
-  def apply(): Rand = Rand(Literal(Utils.random.nextLong()))
 }
 
 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
@@ -100,6 +102,9 @@ object Rand {
   """)
 // scalastyle:on line.size.limit
 case class Randn(child: Expression) extends RDG {
+
+  def this() = this(Literal(Utils.random.nextLong()))
+
   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -115,5 +120,4 @@ case class Randn(child: Expression) extends RDG {
 
 object Randn {
   def apply(seed: Long): Randn = Randn(Literal(seed))
-  def apply(): Randn = Randn(Literal(Utils.random.nextLong()))
 }

From fec5f42c00b38a176f6a4e149135dfb017d1ea9b Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sun, 16 Oct 2016 15:57:32 +0900
Subject: [PATCH 04/12] Add test cases for constant folding and improve documentation

---
 .../spark/sql/catalyst/expressions/randomExpressions.scala | 5 +++++
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++
 2 files changed, 11 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 235755408cc87..2daa081b76870 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -59,6 +59,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
+<<<<<<< 245346032a721c2c8b46e54e83c3620c7196b46b
   usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) uniformly distributed values in [0, 1).",
   extended = """
     Examples:
@@ -67,6 +68,10 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
       > SELECT _FUNC_(0);
        0.8446490682263027
   """)
+=======
+  usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1].",
+  extended = "> SELECT _FUNC_();\n 0.9629742951434543\n> SELECT _FUNC_(0);\n 0.8446490682263027\n> SELECT _FUNC_(NULL);\n 0.8446490682263027")
+>>>>>>> Add test cases for constant folding and improve documentation
 // scalastyle:on line.size.limit
 case class Rand(child: Expression) extends RDG {
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 695f2b26763af..46d4a21218f25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1735,5 +1735,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(testData.selectExpr("rand(0L)"), testData.selectExpr("rand(0)"))
     checkAnswer(testData.selectExpr("randn(NULL)"), testData.selectExpr("randn(0)"))
     checkAnswer(testData.selectExpr("randn(0L)"), testData.selectExpr("randn(0)"))
+    checkAnswer(testData.selectExpr("rand(cast(NULL AS INT))"), testData.selectExpr("rand(0)"))
+    checkAnswer(testData.selectExpr("rand(cast(3 / 7 AS INT))"), testData.selectExpr("rand(0)"))
+    checkAnswer(
+      testData.selectExpr("randn(cast(NULL AS LONG))"), testData.selectExpr("randn(0L)"))
+    checkAnswer(
+      testData.selectExpr("randn(cast(3L / 12L AS LONG))"), testData.selectExpr("randn(0L)"))
   }
 }

From 7bc0a192c3887cbd8806dae9eeb5f73fe731b80a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sun, 16 Oct 2016 16:31:10 +0900
Subject: [PATCH 05/12] Add some more cases for exceptions

---
 .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 46d4a21218f25..b6bc33ea50a21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1741,5 +1741,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData.selectExpr("randn(cast(NULL AS LONG))"), testData.selectExpr("randn(0L)"))
     checkAnswer(
       testData.selectExpr("randn(cast(3L / 12L AS LONG))"), testData.selectExpr("randn(0L)"))
+
+    val eOne = intercept[AnalysisException] {
+      testData.selectExpr("rand(key)").collect()
+    }
+    assert(eOne.message.contains("Input argument to rand must be an integer/long literal."))
+
+    val eTwo = intercept[AnalysisException] {
+      testData.selectExpr("randn(key)").collect()
+    }
+    assert(eTwo.message.contains("Input argument to randn must be an integer/long literal."))
   }
 }

From 9c560945a1e59d252584e383e5df470f918e8016 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 17 Oct 2016 10:01:42 +0900
Subject: [PATCH 06/12] Improve documentation

---
 .../spark/sql/catalyst/expressions/randomExpressions.scala | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 2daa081b76870..235755408cc87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -59,7 +59,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-<<<<<<< 245346032a721c2c8b46e54e83c3620c7196b46b
   usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) uniformly distributed values in [0, 1).",
   extended = """
     Examples:
@@ -68,10 +67,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
       > SELECT _FUNC_(0);
        0.8446490682263027
   """)
-=======
-  usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1].",
-  extended = "> SELECT _FUNC_();\n 0.9629742951434543\n> SELECT _FUNC_(0);\n 0.8446490682263027\n> SELECT _FUNC_(NULL);\n 0.8446490682263027")
->>>>>>> Add test cases for constant folding and improve documentation
 // scalastyle:on line.size.limit
 case class Rand(child: Expression) extends RDG {
 

From 30179d8279603cac12638e16915fd238ce82a310 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 17 Oct 2016 10:03:16 +0900
Subject: [PATCH 07/12] Improve exception message and documentation

---
 .../spark/sql/catalyst/expressions/randomExpressions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 235755408cc87..0855f682c612d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -46,7 +46,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
     case Literal(s, IntegerType) => s.asInstanceOf[Int]
     case Literal(s, LongType) => s.asInstanceOf[Long]
     case _ => throw new AnalysisException(
-      s"Input argument to $prettyName must be an integer/long literal.")
+      s"Input argument to $prettyName must be an integer/long/NULL literal.")
   }
 
   override def nullable: Boolean = false

From 3283d3a06398601a0960cf3cb28835fc6b861b1e Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 17 Oct 2016 10:04:21 +0900
Subject: [PATCH 08/12] Fix the test too

---
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b6bc33ea50a21..3bfbb5c2b861f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1745,11 +1745,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val eOne = intercept[AnalysisException] {
       testData.selectExpr("rand(key)").collect()
     }
-    assert(eOne.message.contains("Input argument to rand must be an integer/long literal."))
+    assert(eOne.message.contains("Input argument to rand must be an integer/long/NULL literal."))
 
     val eTwo = intercept[AnalysisException] {
       testData.selectExpr("randn(key)").collect()
     }
-    assert(eTwo.message.contains("Input argument to randn must be an integer/long literal."))
+    assert(eTwo.message.contains("Input argument to randn must be an integer/long/NULL literal."))
   }
 }

From a523302f7b5f70611f0ae15435e19ada2d88ec8d Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 4 Nov 2016 14:30:49 +0900
Subject: [PATCH 09/12] Add examples for null as an argument

---
 .../spark/sql/catalyst/expressions/randomExpressions.scala | 6 +++++-
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1 -
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 0855f682c612d..f12fea67614fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -46,7 +46,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
     case Literal(s, IntegerType) => s.asInstanceOf[Int]
     case Literal(s, LongType) => s.asInstanceOf[Long]
     case _ => throw new AnalysisException(
-      s"Input argument to $prettyName must be an integer/long/NULL literal.")
+      s"Input argument to $prettyName must be an integer, long or null literal.")
   }
 
   override def nullable: Boolean = false
@@ -66,6 +66,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
        0.9629742951434543
       > SELECT _FUNC_(0);
        0.8446490682263027
+      > SELECT _FUNC_(null);
+       0.8446490682263027
   """)
 // scalastyle:on line.size.limit
 case class Rand(child: Expression) extends RDG {
@@ -99,6 +101,8 @@ object Rand {
        -0.3254147983080288
       > SELECT _FUNC_(0);
        1.1164209726833079
+      > SELECT _FUNC_(null);
+       1.1164209726833079
   """)
 // scalastyle:on line.size.limit
 case class Randn(child: Expression) extends RDG {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 3bfbb5c2b861f..0402d76ce08a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1635,7 +1635,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-<<<<<<< 5213bd60f4be0795e23362f555dcdcf1a1d060cd
   private def verifyNullabilityInFilterExec(
       df: DataFrame,
       expr: String,

From b432355241f6a7a3a457c9e0671c803b7757e169 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 4 Nov 2016 14:36:14 +0900
Subject: [PATCH 10/12] Fix the tests accordingly

---
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0402d76ce08a8..8d68e7db47e41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1744,11 +1744,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val eOne = intercept[AnalysisException] {
       testData.selectExpr("rand(key)").collect()
     }
-    assert(eOne.message.contains("Input argument to rand must be an integer/long/NULL literal."))
+    assert(
+      eOne.message.contains("Input argument to rand must be an integer, long or null literal."))
 
     val eTwo = intercept[AnalysisException] {
       testData.selectExpr("randn(key)").collect()
    }
-    assert(eTwo.message.contains("Input argument to randn must be an integer/long/NULL literal."))
+    assert(
+      eTwo.message.contains("Input argument to randn must be an integer, long or null literal."))
   }
 }

From 160ea54c8b041904b0e98be0b56bda6529f91d93 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 4 Nov 2016 16:07:01 +0900
Subject: [PATCH 11/12] Move the tests into the SQL query test suite

---
 .../resources/sql-tests/inputs/random.sql | 17 ++++
 .../sql-tests/results/random.sql.out | 84 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 25 ------
 3 files changed, 101 insertions(+), 25 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/random.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/random.sql.out

diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql
new file mode 100644
index 0000000000000..a1aae7b8759dc
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql
@@ -0,0 +1,17 @@
+-- rand with the seed 0
+SELECT rand(0);
+SELECT rand(cast(3 / 7 AS int));
+SELECT rand(NULL);
+SELECT rand(cast(NULL AS int));
+
+-- rand unsupported data type
+SELECT rand(1.0);
+
+-- randn with the seed 0
+SELECT randn(0L);
+SELECT randn(cast(3 / 7 AS long));
+SELECT randn(NULL);
+SELECT randn(cast(NULL AS long));
+
+-- randn unsupported data type
+SELECT rand('1')
diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out
new file mode 100644
index 0000000000000..bca67320fe7bb
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out
@@ -0,0 +1,84 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+SELECT rand(0)
+-- !query 0 schema
+struct
+-- !query 0 output
+0.8446490682263027
+
+
+-- !query 1
+SELECT rand(cast(3 / 7 AS int))
+-- !query 1 schema
+struct
+-- !query 1 output
+0.8446490682263027
+
+
+-- !query 2
+SELECT rand(NULL)
+-- !query 2 schema
+struct
+-- !query 2 output
+0.8446490682263027
+
+
+-- !query 3
+SELECT rand(cast(NULL AS int))
+-- !query 3 schema
+struct
+-- !query 3 output
+0.8446490682263027
+
+
+-- !query 4
+SELECT rand(1.0)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7
+
+
+-- !query 5
+SELECT randn(0L)
+-- !query 5 schema
+struct
+-- !query 5 output
+1.1164209726833079
+
+
+-- !query 6
+SELECT randn(cast(3 / 7 AS long))
+-- !query 6 schema
+struct
+-- !query 6 output
+1.1164209726833079
+
+
+-- !query 7
+SELECT randn(NULL)
+-- !query 7 schema
+struct
+-- !query 7 output
+1.1164209726833079
+
+
+-- !query 8
+SELECT randn(cast(NULL AS long))
+-- !query 8 schema
+struct
+-- !query 8 output
+1.1164209726833079
+
+
+-- !query 9
+SELECT rand('1')
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8d68e7db47e41..f5bc8785d5a2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1728,29 +1728,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
-
-  test("SPARK-17854: rand/randn allows null and long as input seed") {
-    checkAnswer(testData.selectExpr("rand(NULL)"), testData.selectExpr("rand(0)"))
-    checkAnswer(testData.selectExpr("rand(0L)"), testData.selectExpr("rand(0)"))
-    checkAnswer(testData.selectExpr("randn(NULL)"), testData.selectExpr("randn(0)"))
-    checkAnswer(testData.selectExpr("randn(0L)"), testData.selectExpr("randn(0)"))
-    checkAnswer(testData.selectExpr("rand(cast(NULL AS INT))"), testData.selectExpr("rand(0)"))
-    checkAnswer(testData.selectExpr("rand(cast(3 / 7 AS INT))"), testData.selectExpr("rand(0)"))
-    checkAnswer(
-      testData.selectExpr("randn(cast(NULL AS LONG))"), testData.selectExpr("randn(0L)"))
-    checkAnswer(
-      testData.selectExpr("randn(cast(3L / 12L AS LONG))"), testData.selectExpr("randn(0L)"))
-
-    val eOne = intercept[AnalysisException] {
-      testData.selectExpr("rand(key)").collect()
-    }
-    assert(
-      eOne.message.contains("Input argument to rand must be an integer, long or null literal."))
-
-    val eTwo = intercept[AnalysisException] {
-      testData.selectExpr("randn(key)").collect()
-    }
-    assert(
-      eTwo.message.contains("Input argument to randn must be an integer, long or null literal."))
-  }
 }

From 9b9a49f80c216e00c526a9e07e9f764c1409bd10 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 5 Nov 2016 14:46:23 +0900
Subject: [PATCH 12/12] Specify the data type for literals

---
 .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index f12fea67614fb..1d7a3c7356075 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -72,7 +72,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
 // scalastyle:on line.size.limit
 case class Rand(child: Expression) extends RDG {
 
-  def this() = this(Literal(Utils.random.nextLong()))
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
@@ -88,7 +88,7 @@ case class Rand(child: Expression) extends RDG {
 }
 
 object Rand {
-  def apply(seed: Long): Rand = Rand(Literal(seed))
+  def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
 }
 
 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
@@ -107,7 +107,7 @@ object Rand {
 // scalastyle:on line.size.limit
 case class Randn(child: Expression) extends RDG {
 
-  def this() = this(Literal(Utils.random.nextLong()))
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
@@ -123,5 +123,5 @@ case class Randn(child: Expression) extends RDG {
 }
 
 object Randn {
-  def apply(seed: Long): Randn = Randn(Literal(seed))
+  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
 }
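
The behaviour the series converges on -- rand/randn accept an integer, long, or null literal seed, and a null seed behaves like 0, so rand(NULL), rand(0) and rand(0L) produce the same column (see the golden results in random.sql.out above) -- can be illustrated with a small, dependency-free Scala sketch. The names below (SeedResolution, IntLit, LongLit, OtherLit) are illustrative stand-ins, not part of the patches or of Spark's API; only the pattern-matching shape mirrors the lazy seed resolution that PATCH 02 adds to RDG.

// Illustrative sketch only: models how a literal child expression is turned into a Long seed.
object SeedResolution {
  sealed trait Lit
  case class IntLit(value: java.lang.Integer) extends Lit   // stands in for Literal(_, IntegerType)
  case class LongLit(value: java.lang.Long) extends Lit     // stands in for Literal(_, LongType)
  case class OtherLit(value: Any) extends Lit                // any other (non-foldable or mistyped) argument

  def seed(child: Lit): Long = child match {
    // A null int/long literal collapses to 0, which is why rand(NULL) matches rand(0).
    case IntLit(v)  => if (v == null) 0L else v.longValue()
    case LongLit(v) => if (v == null) 0L else v.longValue()
    case OtherLit(v) => throw new IllegalArgumentException(
      s"Input argument to rand/randn must be an integer, long or null literal: $v")
  }

  def main(args: Array[String]): Unit = {
    assert(seed(IntLit(null)) == 0L)   // SELECT rand(NULL)
    assert(seed(LongLit(0L)) == 0L)    // SELECT rand(0L)
    assert(seed(IntLit(0)) == 0L)      // SELECT rand(0)
  }
}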