Commit 3f6a2bb

JoshRosen authored and hvanhovell committed
[SPARK-17515] CollectLimit.execute() should perform per-partition limits
## What changes were proposed in this pull request?

CollectLimit.execute() incorrectly omits per-partition limits, leading to performance regressions when this code path is hit (which should not happen in normal operation, but can occur in some cases; see #15068 for one example).

## How was this patch tested?

Regression test in SQLQuerySuite that asserts the number of records scanned from the input RDD.

Author: Josh Rosen <[email protected]>

Closes #15070 from JoshRosen/SPARK-17515.
1 parent 46f5c20 commit 3f6a2bb

File tree

2 files changed: 11 additions, 1 deletion


sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala (2 additions, 1 deletion)

@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
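For readers less familiar with the SQL execution internals, the same idea can be sketched with plain RDD operations. This is only an illustrative analog under assumed names (collectLimitSketch is hypothetical, and coalesce stands in for the internal single-partition ShuffledRowRDD), not code from this patch:

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Illustrative sketch of the fix: take at most `limit` rows inside each
// partition first, so the shuffle to a single partition moves at most
// (numPartitions * limit) rows instead of the whole dataset, then apply
// the limit once more after the shuffle to get the final global result.
def collectLimitSketch[T: ClassTag](input: RDD[T], limit: Int): Array[T] = {
  val locallyLimited = input.mapPartitions(_.take(limit))
  locallyLimited
    .coalesce(numPartitions = 1, shuffle = true)
    .mapPartitions(_.take(limit))
    .collect()
}

With a 10-partition input and limit = 1, only about 10 rows reach the shuffle, which is exactly the behavior the regression test below asserts on.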

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (9 additions, 0 deletions)

@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }
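A note on the expected value: spark.range(1, 100, 1, numPartitions = 10) produces 99 rows spread across 10 partitions. With the per-partition limit in place, each partition's task pulls at most take(1) = one row through the map closure before the single-partition shuffle, so the accumulator should observe 10 records rather than all 99; without the fix, every row would have passed through the closure.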
