[SPARK-17854][SQL] rand/randn allows null/long as input seed #15432
Changes from 11 commits
Diff of the Catalyst random expressions (package org.apache.spark.sql.catalyst.expressions); within each hunk, lines from the previous version appear immediately before the lines that replace them.

@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
 *
 * Since this expression is stateful, it cannot be a case object.
 */
abstract class RDG extends LeafExpression with Nondeterministic {
  protected def seed: Long
abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
  /**
   * Record ID within each partition. By being transient, the Random Number Generator is
   * reset every time we serialize and deserialize and initialize it.

@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
    rng = new XORShiftRandom(seed + partitionIndex)
  }

  @transient protected lazy val seed: Long = child match {
    case Literal(s, IntegerType) => s.asInstanceOf[Int]
    case Literal(s, LongType) => s.asInstanceOf[Long]
    case _ => throw new AnalysisException(
      s"Input argument to $prettyName must be an integer, long or null literal.")
  }

  override def nullable: Boolean = false

  override def dataType: DataType = DoubleType

  // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
  override def sql: String = s"$prettyName($seed)"
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */

@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
       0.9629742951434543
      > SELECT _FUNC_(0);
       0.8446490682263027
      > SELECT _FUNC_(null);
       0.8446490682263027
  """)
// scalastyle:on line.size.limit
case class Rand(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
case class Rand(child: Expression) extends RDG {

  def this() = this(Utils.random.nextLong())
  def this() = this(Literal(Utils.random.nextLong()))

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
  })
  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")

@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
  }
}

object Rand {
  def apply(seed: Long): Rand = Rand(Literal(seed))
}

/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(

@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
      -0.3254147983080288
      > SELECT _FUNC_(0);
       1.1164209726833079
      > SELECT _FUNC_(null);
       1.1164209726833079
  """)
// scalastyle:on line.size.limit
case class Randn(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
case class Randn(child: Expression) extends RDG {

  def this() = this(Utils.random.nextLong())
  def this() = this(Literal(Utils.random.nextLong()))

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
  })
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")

@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
    final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
  }
}

object Randn {
  def apply(seed: Long): Randn = Randn(Literal(seed))
}
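A note on the `seed` pattern match above: a bare `NULL` seed ends up behaving as seed 0. The analyzer casts `NULL` to the expected integer type (the golden file below shows the schema as `rand(CAST(NULL AS INT))`), that cast is constant-folded back into a literal before the lazily evaluated `seed` is read, and casting a null `Any` to `Int` in Scala unboxes to 0. That is why `rand(NULL)` in the tests below produces the same value as `rand(0)`. A minimal standalone sketch of just that resolution step (the `resolveSeed` helper is illustrative and not part of the PR):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{IntegerType, LongType}

// Illustrative helper mirroring the lazy `seed` resolution in RDG above.
def resolveSeed(child: Literal): Long = child match {
  case Literal(s, IntegerType) => s.asInstanceOf[Int]  // a null value unboxes to 0
  case Literal(s, LongType) => s.asInstanceOf[Long]    // a null value unboxes to 0L
  case other => sys.error(s"unsupported seed literal: $other")
}

resolveSeed(Literal(42))                        // 42
resolveSeed(Literal(7L))                        // 7
resolveSeed(Literal.create(null, IntegerType))  // 0, hence rand(NULL) == rand(0)
```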
New SQL test input file for the random expressions:

@@ -0,0 +1,17 @@
-- rand with the seed 0
SELECT rand(0);
SELECT rand(cast(3 / 7 AS int));
SELECT rand(NULL);
SELECT rand(cast(NULL AS int));

-- rand unsupported data type
SELECT rand(1.0);

-- randn with the seed 0
SELECT randn(0L);
SELECT randn(cast(3 / 7 AS long));
SELECT randn(NULL);
SELECT randn(cast(NULL AS long));

-- randn unsupported data type
SELECT rand('1')
New golden result file (generated by SQLQueryTestSuite):

@@ -0,0 +1,84 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 10


-- !query 0
SELECT rand(0)
-- !query 0 schema
struct<rand(0):double>
-- !query 0 output
0.8446490682263027


-- !query 1
SELECT rand(cast(3 / 7 AS int))
-- !query 1 schema
struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
-- !query 1 output
0.8446490682263027


-- !query 2
SELECT rand(NULL)
-- !query 2 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 2 output
0.8446490682263027


-- !query 3
SELECT rand(cast(NULL AS int))
-- !query 3 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 3 output
0.8446490682263027


-- !query 4
SELECT rand(1.0)
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7


-- !query 5
SELECT randn(0L)
-- !query 5 schema
struct<randn(0):double>
-- !query 5 output
1.1164209726833079


-- !query 6
SELECT randn(cast(3 / 7 AS long))
-- !query 6 schema
struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
-- !query 6 output
1.1164209726833079


-- !query 7
SELECT randn(NULL)
-- !query 7 schema
struct<randn(CAST(NULL AS INT)):double>
-- !query 7 output
1.1164209726833079


-- !query 8
SELECT randn(cast(NULL AS long))
-- !query 8 schema
struct<randn(CAST(NULL AS BIGINT)):double>
-- !query 8 output
1.1164209726833079


-- !query 9
SELECT rand('1')
-- !query 9 schema
struct<>
-- !query 9 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7
Review comment: The same here?
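For context on the two failing queries above: the "(int or bigint)" wording comes from the `inputTypes = Seq(TypeCollection(IntegerType, LongType))` declared on RDG, which the analyzer's standard input-type checks enforce before the seed literal is ever inspected. A rough interactive sketch of the resulting behaviour, assuming a local SparkSession (the session setup is illustrative and not part of the PR):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object RandSeedDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("rand-seed-demo").getOrCreate()

    // Integer, long and null literal seeds are all accepted after this change.
    spark.sql("SELECT rand(0), randn(0L), rand(NULL)").show()

    // A decimal or string seed is rejected at analysis time with the
    // "requires (int or bigint) type" message captured in the golden file above.
    try {
      spark.sql("SELECT rand(1.0)").show()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    spark.stop()
  }
}
```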