Skip to content

Commit 71c24aa

Browse files
committed
[SPARK-25602][SQL] SparkPlan.getByteArrayRdd should not consume the input when not necessary
## What changes were proposed in this pull request? In `SparkPlan.getByteArrayRdd`, we should only call `it.hasNext` when the limit is not hit, as `iter.hasNext` may produce one row and buffer it, and cause wrong metrics. ## How was this patch tested? new tests Closes #22621 from cloud-fan/range. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 927e527 commit 71c24aa

2 files changed

Lines changed: 57 additions & 2 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
250250
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
251251
val bos = new ByteArrayOutputStream()
252252
val out = new DataOutputStream(codec.compressedOutputStream(bos))
253-
while (iter.hasNext && (n < 0 || count < n)) {
253+
// `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is
254+
// not hit.
255+
while ((n < 0 || count < n) && iter.hasNext) {
254256
val row = iter.next().asInstanceOf[UnsafeRow]
255257
out.writeInt(row.getSizeInBytes)
256258
row.writeToStream(out, buffer)

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.util.Random
2424
import org.apache.spark.SparkFunSuite
2525
import org.apache.spark.sql._
2626
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
27-
import org.apache.spark.sql.execution.ui.SQLAppStatusStore
27+
import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
2828
import org.apache.spark.sql.functions._
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.test.SharedSQLContext
@@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
517517
test("writing data out metrics with dynamic partition: parquet") {
518518
testMetricsDynamicPartition("parquet", "parquet", "t1")
519519
}
520+
521+
test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") {
522+
def checkFilterAndRangeMetrics(
523+
df: DataFrame,
524+
filterNumOutputs: Int,
525+
rangeNumOutputs: Int): Unit = {
526+
var filter: FilterExec = null
527+
var range: RangeExec = null
528+
val collectFilterAndRange: SparkPlan => Unit = {
529+
case f: FilterExec =>
530+
assert(filter == null, "the query should only have one Filter")
531+
filter = f
532+
case r: RangeExec =>
533+
assert(range == null, "the query should only have one Range")
534+
range = r
535+
case _ =>
536+
}
537+
if (SQLConf.get.wholeStageEnabled) {
538+
df.queryExecution.executedPlan.foreach {
539+
case w: WholeStageCodegenExec =>
540+
w.child.foreach(collectFilterAndRange)
541+
case _ =>
542+
}
543+
} else {
544+
df.queryExecution.executedPlan.foreach(collectFilterAndRange)
545+
}
546+
547+
assert(filter != null && range != null, "the query doesn't have Filter and Range")
548+
assert(filter.metrics("numOutputRows").value == filterNumOutputs)
549+
assert(range.metrics("numOutputRows").value == rangeNumOutputs)
550+
}
551+
552+
val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
553+
val df2 = df.limit(2)
554+
Seq(true, false).foreach { wholeStageEnabled =>
555+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) {
556+
df.collect()
557+
checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000)
558+
559+
df.queryExecution.executedPlan.foreach(_.resetMetrics())
560+
// For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition,
561+
// and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces
562+
// 4 rows, and Range produces 2000 rows.
563+
df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
564+
checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000)
565+
566+
// Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first
567+
// task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch).
568+
df2.collect()
569+
checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000)
570+
}
571+
}
572+
}
520573
}

0 commit comments

Comments
 (0)