diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f7aa6da0a5bdc..8a4fca750d7e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -147,6 +147,21 @@ class SimpleTestOptimizer extends Optimizer(
     new SimpleCatalystConf(caseSensitiveAnalysis = true)),
   new SimpleCatalystConf(caseSensitiveAnalysis = true))
 
+/**
+ * Pushes a Project down beneath Sample to enable column pruning before sampling.
+ * This is safe only when the sample is without replacement or every project
+ * expression is deterministic; otherwise rows duplicated by a with-replacement
+ * sample would share one non-deterministic value instead of each getting a fresh one.
+ */
+object PushProjectThroughSample extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // Push down projection into sample
+    case Project(projectList, Sample(lb, up, replace, seed, child))
+        if !replace || projectList.forall(_.deterministic) =>
+      Sample(lb, up, replace, seed, Project(projectList, child))()
+  }
+}
+
 /**
  * Removes the Project only conducting Alias of its child node.
  * It is created mainly for removing extra Project added in EliminateSerialization rule,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6da99ce0dd683..434d9496a1fbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1534,15 +1534,15 @@ class Dataset[T] private[sql](
    * Returns a new Dataset by sampling a fraction of rows.
    *
    * @param withReplacement Sample with replacement or not.
-   * @param fraction Fraction of rows to generate.
+   * @param fraction Fraction of rows to generate, range [0.0, 1.0].
    * @param seed Seed for sampling.
    *
    * @group typedrel
    * @since 1.6.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
-    require(fraction >= 0,
-      s"Fraction must be nonnegative, but got ${fraction}")
+    require(fraction >= 0.0 && fraction <= 1.0,
+      s"Fraction must be in the range [0.0, 1.0], but got ${fraction}")
 
     withTypedPlan {
       Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
@@ -1553,7 +1553,7 @@ class Dataset[T] private[sql](
    * Returns a new Dataset by sampling a fraction of rows, using a random seed.
    *
    * @param withReplacement Sample with replacement or not.
-   * @param fraction Fraction of rows to generate.
+   * @param fraction Fraction of rows to generate, range [0.0, 1.0].
    *
    * @group typedrel
    * @since 1.6.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index ad8a71689895b..17708a195f6c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,9 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
+      // The Poisson sampler may return the same row instance several times, so
+      // downstream operators that buffer rows have to copy each result first.
+      ctx.copyResult = true
       ctx.addMutableState(s"$samplerClass", sampler,
         s"$initSampler();")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f3180379c2..983fbd68d6812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1578,4 +1578,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("cannot push projects down beneath sample when amplifying data") {
+    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+    val amplifyDf = df.sample(withReplacement = true, fraction = 0.99)
+    def checkQuery(c: Column): Unit = {
+      // The non-deterministic column is computed above the sample, so every
+      // output row, duplicates included, must receive a distinct value.
+      val d = amplifyDf.withColumn("c", c).select($"c").collect()
+      assert(d.length == d.distinct.length)
+    }
+    checkQuery(monotonically_increasing_id())
+    checkQuery(rand())
+  }
+
+  test("sampling fraction must be in the range [0.0, 1.0]") {
+    def checkException(v: Double): Unit = {
+      val e = intercept[IllegalArgumentException] {
+        Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b").sample(true, v)
+      }
+      assert(e.getMessage.contains("Fraction must be in the range [0.0, 1.0]"))
+    }
+    checkException(1.1)
+    checkException(-1.2)
+  }
 }
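
Below is a quick standalone check of the behavior this patch pins down, mirroring the new DataFrameSuite tests outside the test harness. It is a minimal sketch, assuming a local Spark 2.x build; the object name, master URL, and app name are illustrative and not part of the patch.

// Minimal sketch (assumptions: local Spark 2.x build; names are illustrative only).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

object SampleProjectCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("sample-project-check")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")

    // With-replacement sampling may emit the same input row several times.
    // Since the optimizer no longer pushes a non-deterministic projection
    // beneath the Sample, each output row gets a distinct generated id.
    val ids = df.sample(withReplacement = true, fraction = 0.99)
      .withColumn("id", monotonically_increasing_id())
      .select($"id")
      .collect()
    assert(ids.length == ids.distinct.length)

    // The tightened precondition rejects out-of-range fractions up front.
    try {
      df.sample(withReplacement = true, fraction = 2.0)
      assert(false, "expected an IllegalArgumentException for fraction = 2.0")
    } catch {
      case e: IllegalArgumentException =>
        println(s"rejected as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}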