-
Notifications
You must be signed in to change notification settings. Fork 29k
[SPARK-17854][SQL] rand/randn allows null/long as input seed #15432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
5213bd6
182548b
2453460
fec5f42
7bc0a19
9c56094
30179d8
3283d3a
a523302
b432355
160ea54
9b9a49f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,10 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import org.apache.spark.TaskContext | ||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} | ||
| import org.apache.spark.sql.types.{DataType, DoubleType} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.Utils | ||
| import org.apache.spark.util.random.XORShiftRandom | ||
|
|
||
|
|
@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom | |
| * | ||
| * Since this expression is stateful, it cannot be a case object. | ||
| */ | ||
| abstract class RDG extends LeafExpression with Nondeterministic { | ||
|
|
||
| protected def seed: Long | ||
|
|
||
| abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { | ||
| /** | ||
| * Record ID within each partition. By being transient, the Random Number Generator is | ||
| * reset every time we serialize and deserialize and initialize it. | ||
|
|
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic { | |
| rng = new XORShiftRandom(seed + partitionIndex) | ||
| } | ||
|
|
||
| @transient protected lazy val seed: Long = child match { | ||
| case Literal(s, IntegerType) => s.asInstanceOf[Int] | ||
| case Literal(s, LongType) => s.asInstanceOf[Long] | ||
| case _ => throw new AnalysisException( | ||
| s"Input argument to $prettyName must be an integer, long or null literal.") | ||
| } | ||
|
|
||
| override def nullable: Boolean = false | ||
|
|
||
| override def dataType: DataType = DoubleType | ||
|
|
||
| // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. | ||
| override def sql: String = s"$prettyName($seed)" | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) | ||
| } | ||
|
|
||
| /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ | ||
|
|
@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic { | |
| 0.9629742951434543 | ||
| > SELECT _FUNC_(0); | ||
| 0.8446490682263027 | ||
| > SELECT _FUNC_(null); | ||
| 0.8446490682263027 | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class Rand(seed: Long) extends RDG { | ||
| override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() | ||
| case class Rand(child: Expression) extends RDG { | ||
|
|
||
| def this() = this(Utils.random.nextLong()) | ||
| def this() = this(Literal(Utils.random.nextLong())) | ||
|
|
||
| def this(seed: Expression) = this(seed match { | ||
| case IntegerLiteral(s) => s | ||
| case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") | ||
| }) | ||
| override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val rngTerm = ctx.freshName("rng") | ||
|
|
@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG { | |
| } | ||
| } | ||
|
|
||
| object Rand { | ||
| def apply(seed: Long): Rand = Rand(Literal(seed)) | ||
|
||
| } | ||
|
|
||
| /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
|
|
@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG { | |
| -0.3254147983080288 | ||
| > SELECT _FUNC_(0); | ||
| 1.1164209726833079 | ||
| > SELECT _FUNC_(null); | ||
| 1.1164209726833079 | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class Randn(seed: Long) extends RDG { | ||
| override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() | ||
| case class Randn(child: Expression) extends RDG { | ||
|
|
||
| def this() = this(Utils.random.nextLong()) | ||
| def this() = this(Literal(Utils.random.nextLong())) | ||
|
||
|
|
||
| def this(seed: Expression) = this(seed match { | ||
| case IntegerLiteral(s) => s | ||
| case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") | ||
| }) | ||
| override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val rngTerm = ctx.freshName("rng") | ||
|
|
@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG { | |
| final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") | ||
| } | ||
| } | ||
|
|
||
| object Randn { | ||
| def apply(seed: Long): Randn = Randn(Literal(seed)) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1728,4 +1728,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { | |
| val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) | ||
| assert(df.filter($"array1" === $"array2").count() == 1) | ||
| } | ||
|
|
||
| test("SPARK-17854: rand/randn allows null and long as input seed") { | ||
|
||
| checkAnswer(testData.selectExpr("rand(NULL)"), testData.selectExpr("rand(0)")) | ||
| checkAnswer(testData.selectExpr("rand(0L)"), testData.selectExpr("rand(0)")) | ||
| checkAnswer(testData.selectExpr("randn(NULL)"), testData.selectExpr("randn(0)")) | ||
| checkAnswer(testData.selectExpr("randn(0L)"), testData.selectExpr("randn(0)")) | ||
| checkAnswer(testData.selectExpr("rand(cast(NULL AS INT))"), testData.selectExpr("rand(0)")) | ||
| checkAnswer(testData.selectExpr("rand(cast(3 / 7 AS INT))"), testData.selectExpr("rand(0)")) | ||
| checkAnswer( | ||
| testData.selectExpr("randn(cast(NULL AS LONG))"), testData.selectExpr("randn(0L)")) | ||
| checkAnswer( | ||
| testData.selectExpr("randn(cast(3L / 12L AS LONG))"), testData.selectExpr("randn(0L)")) | ||
|
|
||
| val eOne = intercept[AnalysisException] { | ||
| testData.selectExpr("rand(key)").collect() | ||
| } | ||
| assert( | ||
| eOne.message.contains("Input argument to rand must be an integer, long or null literal.")) | ||
|
|
||
| val eTwo = intercept[AnalysisException] { | ||
| testData.selectExpr("randn(key)").collect() | ||
| } | ||
| assert( | ||
| eTwo.message.contains("Input argument to randn must be an integer, long or null literal.")) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same here?