@@ -17,11 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
abstract class RDG extends LeafExpression with Nondeterministic {

protected def seed: Long

abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
rng = new XORShiftRandom(seed + partitionIndex)
}

@transient protected lazy val seed: Long = child match {
case Literal(s, IntegerType) => s.asInstanceOf[Int]
case Literal(s, LongType) => s.asInstanceOf[Long]
case _ => throw new AnalysisException(
s"Input argument to $prettyName must be an integer, long or null literal.")
}

override def nullable: Boolean = false

override def dataType: DataType = DoubleType

// NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
override def sql: String = s"$prettyName($seed)"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}
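
An aside on the null cases exercised in the SQL tests near the end of this diff: when the child is Literal(null, IntegerType) or Literal(null, LongType), the pattern match above still takes the typed branches, and casting a boxed null to a JVM value type unboxes it to that type's default value. A minimal, Spark-free sketch of that behavior (the object name is illustrative, not part of the patch):

object NullSeedSketch {
  def main(args: Array[String]): Unit = {
    val nullValue: Any = null
    // Scala unboxes a null reference to the value type's default instead of throwing.
    val intSeed: Int   = nullValue.asInstanceOf[Int]   // 0
    val longSeed: Long = nullValue.asInstanceOf[Long]  // 0L
    println((intSeed, longSeed))                       // prints (0,0)
    // Hence SELECT rand(NULL) and SELECT rand(0) produce the same value in the
    // expected results further down.
  }
}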

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
0.9629742951434543
> SELECT _FUNC_(0);
0.8446490682263027
> SELECT _FUNC_(null);
0.8446490682263027
""")
// scalastyle:on line.size.limit
case class Rand(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
case class Rand(child: Expression) extends RDG {

def this() = this(Utils.random.nextLong())
def this() = this(Literal(Utils.random.nextLong()))

Member: The same here?

def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
}
}

object Rand {
def apply(seed: Long): Rand = Rand(Literal(seed))
}

Member: The same here?
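
For orientation, a REPL-style sketch (not part of the patch) of the construction paths this change leaves in place for Rand; Randn below has the same shape:

import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.types.IntegerType

val r1 = new Rand()                      // no seed given: a random Long wrapped in a Literal
val r2 = Rand(30L)                       // companion apply: the Long becomes Literal(30L)
val r3 = Rand(Literal(30, IntegerType))  // explicit integer-literal child expression
// All three resolve their seed through RDG's lazy `seed` pattern match shown earlier.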

/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
-0.3254147983080288
> SELECT _FUNC_(0);
1.1164209726833079
> SELECT _FUNC_(null);
1.1164209726833079
""")
// scalastyle:on line.size.limit
case class Randn(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
case class Randn(child: Expression) extends RDG {

def this() = this(Utils.random.nextLong())
def this() = this(Literal(Utils.random.nextLong()))

Member: Why not specify the data type here?

Member Author: The compiler seems to complain if we specify the return type in def this.

Member Author: Oh, do you mean the type for the literal, for example as below?

Literal(Utils.random.nextLong(), LongType)

If you think it is beneficial because it at least avoids doing the type dispatch once, I will fix it here. Also, I can sweep the usages in functions.scala in another PR.

Member: Yeah, I think we should explicitly specify the type if possible. This is my personal preference.

Not sure whether it is worth a new PR to change all of them in functions.scala.

Member Author: Yes, that makes sense. I will fix them here first.

def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
})
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
}
}

object Randn {
def apply(seed: Long): Randn = Randn(Literal(seed))
}
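
To make the review thread above concrete: Catalyst's Literal offers a one-argument factory that infers the DataType from the runtime class of the value, and a two-argument form that states the type explicitly, which is what the reviewer asks for in the auxiliary constructors. A REPL-style sketch; treat the exact dispatch behavior as an assumption about the Catalyst API rather than something shown in this diff:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType

val inferred = Literal(42L)           // DataType derived from the value at runtime (LongType here)
val explicit = Literal(42L, LongType) // DataType stated up front, skipping that dispatch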
@@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

test("random") {
checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

checkDoubleEvaluation(
new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
checkDoubleEvaluation(
new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
}

test("SPARK-9127 codegen with long seed") {
17 changes: 17 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/random.sql
@@ -0,0 +1,17 @@
-- rand with the seed 0
SELECT rand(0);
SELECT rand(cast(3 / 7 AS int));
SELECT rand(NULL);
SELECT rand(cast(NULL AS int));

-- rand unsupported data type
SELECT rand(1.0);

-- randn with the seed 0
SELECT randn(0L);
SELECT randn(cast(3 / 7 AS long));
SELECT randn(NULL);
SELECT randn(cast(NULL AS long));

-- randn unsupported data type
SELECT rand('1')
84 changes: 84 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/random.sql.out
@@ -0,0 +1,84 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 10


-- !query 0
SELECT rand(0)
-- !query 0 schema
struct<rand(0):double>
-- !query 0 output
0.8446490682263027


-- !query 1
SELECT rand(cast(3 / 7 AS int))
-- !query 1 schema
struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
-- !query 1 output
0.8446490682263027


-- !query 2
SELECT rand(NULL)
-- !query 2 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 2 output
0.8446490682263027


-- !query 3
SELECT rand(cast(NULL AS int))
-- !query 3 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 3 output
0.8446490682263027


-- !query 4
SELECT rand(1.0)
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7


-- !query 5
SELECT randn(0L)
-- !query 5 schema
struct<randn(0):double>
-- !query 5 output
1.1164209726833079


-- !query 6
SELECT randn(cast(3 / 7 AS long))
-- !query 6 schema
struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
-- !query 6 output
1.1164209726833079


-- !query 7
SELECT randn(NULL)
-- !query 7 schema
struct<randn(CAST(NULL AS INT)):double>
-- !query 7 output
1.1164209726833079


-- !query 8
SELECT randn(cast(NULL AS long))
-- !query 8 schema
struct<randn(CAST(NULL AS BIGINT)):double>
-- !query 8 output
1.1164209726833079


-- !query 9
SELECT rand('1')
-- !query 9 schema
struct<>
-- !query 9 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7
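
The two AnalysisException results above come from the ExpectsInputTypes contract added to RDG (inputTypes restricts the seed to IntegerType or LongType), not from the seed pattern match itself. A rough sketch of exercising that check directly, assuming Catalyst's standard checkInputDataTypes hook; the object name is illustrative and the printed messages are expectations, not captured output:

import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}

object SeedTypeCheckSketch {
  def main(args: Array[String]): Unit = {
    // A non-integer seed, in the spirit of SELECT rand(1.0) above (a double literal
    // stands in for the SQL decimal, purely for illustration).
    val bad  = Rand(Literal(1.0))
    val good = Rand(Literal(0))

    println(bad.checkInputDataTypes())   // expected: a TypeCheckFailure mentioning (int or bigint)
    println(good.checkInputDataTypes())  // expected: TypeCheckSuccess
  }
}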