Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ class SimpleTestOptimizer extends Optimizer(
new SimpleCatalystConf(caseSensitiveAnalysis = true)),
new SimpleCatalystConf(caseSensitiveAnalysis = true))

/**
* Pushes projects down beneath Sample to enable column pruning with sampling.
*/
object PushProjectThroughSample extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down projection into sample
case proj @ Project(projectList, Sample(lb, up, replace, seed, child)) =>
if (!replace || !projectList.exists(_.find(!_.deterministic).nonEmpty)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second condition looks complicated. Just projectList.forall(_.deterministic)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, thanks! I'll fix this.

Sample(lb, up, replace, seed, Project(projectList, child))()
} else {
proj
}
}
}

/**
* Removes the Project only conducting Alias of its child node.
* It is created mainly for removing extra Project added in EliminateSerialization rule,
Expand Down
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1534,15 +1534,15 @@ class Dataset[T] private[sql](
* Returns a new Dataset by sampling a fraction of rows.
*
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
* @param fraction Fraction of rows to generate and the range is 0.0 <= `fraction` <= 1.0.
* @param seed Seed for sampling.
*
* @group typedrel
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
require(fraction >= 0,
s"Fraction must be nonnegative, but got ${fraction}")
require(fraction >= 0 && fraction <= 1.0,
s"Fraction range must be 0.0 <= `fraction` <= 1.0, but got ${fraction}")
Copy link
Member

@HyukjinKwon HyukjinKwon Aug 23, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @maropu, I just wonder whether this fix is still needed, if only to be consistent regardless of whether withReplacement is true or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HyukjinKwon oh, you're right, my bad... thanks! Since the original PR is unrelated to this bug, I'll file a new JIRA ticket and open a separate PR for it soon.


withTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
Expand All @@ -1553,7 +1553,7 @@ class Dataset[T] private[sql](
* Returns a new Dataset by sampling a fraction of rows, using a random seed.
*
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
* @param fraction Fraction of rows to generate and the range is 0.0 <= `fraction` <= 1.0.
*
* @group typedrel
* @since 1.6.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ case class SampleExec(
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
ctx.copyResult = true
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
s"$initSampler();")

Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1578,4 +1578,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.createDataFrame(rdd, StructType(schemas), false)
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}

// Samples a small DataFrame with replacement and checks that a non-deterministic
// column added on top of the sampled rows never yields a repeated value.
test("cannot push projects down beneath sample when amplifying data") {
  val sampled = Seq((1, 0), (2, 0), (3, 0))
    .toDF("a", "b")
    .sample(withReplacement = true, fraction = 0.99)

  // Appends `c` to the sampled rows, keeps only that column, and asserts that
  // de-duplicating the collected values does not shrink the result.
  def checkQuery(c: Column): Unit = {
    val values = sampled.withColumn("c", c).select($"c").collect
    assert(values.size == values.distinct.size)
  }

  checkQuery(monotonically_increasing_id())
  checkQuery(rand())
}

// Dataset.sample validates `fraction` with `require`, so an out-of-range value is
// rejected with an IllegalArgumentException whose message reports the offending value.
test("sampling fraction must be between 0.0 and 1.0") {
  // Samples with the given (invalid) fraction and checks the validation error.
  def checkException(v: Double): Unit = {
    val e = intercept[IllegalArgumentException] {
      // Use the parameter under test rather than a hard-coded fraction, so both the
      // too-large and the negative case are actually exercised.
      Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b").sample(true, v)
    }
    // `require` prefixes its message with "requirement failed: ", hence `contains`.
    assert(e.getMessage.contains(
      s"Fraction range must be 0.0 <= `fraction` <= 1.0, but got $v"))
  }
  checkException(1.1)
  checkException(-1.2)
}
}