Skip to content

Commit 13d882a

Browse files
committed
Limit operation within whole stage codegen should not consume all the inputs
1 parent 71c24aa commit 13d882a

9 files changed

Lines changed: 170 additions & 108 deletions

File tree

sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,6 @@ public void append(InternalRow row) {
7373
currentRows.add(row);
7474
}
7575

76-
/**
77-
* Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
78-
*
79-
* If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
80-
* This interface is mainly used to limit the number of input rows.
81-
*/
82-
public boolean stopEarly() {
83-
return false;
84-
}
85-
8676
/**
8777
* Returns whether `processNext()` should stop processing next row from `input` or not.
8878
*

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
136136
|if ($batch == null) {
137137
| $nextBatchFuncName();
138138
|}
139-
|while ($batch != null) {
139+
|while ($batch != null$keepProducingDataCond) {
140140
| int $numRows = $batch.numRows();
141141
| int $localEnd = $numRows - $idx;
142142
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
166166
}
167167
val inputRow = if (needsUnsafeRowConversion) null else row
168168
s"""
169-
|while ($input.hasNext()) {
169+
|while ($input.hasNext()$keepProducingDataCond) {
170170
| InternalRow $row = (InternalRow) $input.next();
171171
| $numOutputRows.add(1);
172172
| ${consume(ctx, outputVars, inputRow).trim}

sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ case class SortExec(
132132
// a stop check before sorting.
133133
override def needStopCheck: Boolean = false
134134

135+
// Sort operator always consumes all the input rows before outputting any result, so its upstream
136+
// operators can keep producing data, even if there is a limit after Sort.
137+
override def conditionsOfKeepProducingData: Seq[String] = Nil
138+
135139
override protected def doProduce(ctx: CodegenContext): String = {
136140
val needToSort =
137141
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
@@ -172,7 +176,7 @@ case class SortExec(
172176
| $needToSort = false;
173177
| }
174178
|
175-
| while ($sortedIterator.hasNext()) {
179+
| while ($sortedIterator.hasNext()$keepProducingDataCond) {
176180
| UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
177181
| ${consume(ctx, null, outputRow)}
178182
| if (shouldStop()) return;

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,16 @@ trait CodegenSupport extends SparkPlan {
345345
* don't require shouldStop() in the loop of producing rows.
346346
*/
347347
def needStopCheck: Boolean = parent.needStopCheck
348+
349+
def conditionsOfKeepProducingData: Seq[String] = parent.conditionsOfKeepProducingData
350+
351+
final protected def keepProducingDataCond: String = {
352+
if (parent.conditionsOfKeepProducingData.isEmpty) {
353+
""
354+
} else {
355+
parent.conditionsOfKeepProducingData.mkString(" && ", " && ", "")
356+
}
357+
}
348358
}
349359

350360

@@ -381,7 +391,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
381391
forceInline = true)
382392
val row = ctx.freshName("row")
383393
s"""
384-
| while ($input.hasNext() && !stopEarly()) {
394+
| while ($input.hasNext()$keepProducingDataCond) {
385395
| InternalRow $row = (InternalRow) $input.next();
386396
| ${consume(ctx, null, row).trim}
387397
| if (shouldStop()) return;
@@ -677,6 +687,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
677687

678688
override def needStopCheck: Boolean = true
679689

690+
override def conditionsOfKeepProducingData: Seq[String] = Nil
691+
680692
override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
681693
}
682694

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ case class HashAggregateExec(
159159
// don't need a stop check before aggregating.
160160
override def needStopCheck: Boolean = false
161161

162+
// Aggregate operator always consumes all the input rows before outputting any result, so its
163+
// upstream operators can keep producing data, even if there is a limit after Aggregate.
164+
override def conditionsOfKeepProducingData: Seq[String] = Nil
165+
162166
protected override def doProduce(ctx: CodegenContext): String = {
163167
if (groupingExpressions.isEmpty) {
164168
doProduceWithoutKeys(ctx)
@@ -705,13 +709,16 @@ case class HashAggregateExec(
705709

706710
def outputFromRegularHashMap: String = {
707711
s"""
708-
|while ($iterTerm.next()) {
712+
|while ($iterTerm.next()$keepProducingDataCond) {
709713
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
710714
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
711715
| $outputFunc($keyTerm, $bufferTerm);
712-
|
713716
| if (shouldStop()) return;
714717
|}
718+
|$iterTerm.close();
719+
|if ($sorterTerm == null) {
720+
| $hashMapTerm.free();
721+
|}
715722
""".stripMargin
716723
}
717724

@@ -728,11 +735,6 @@ case class HashAggregateExec(
728735
// output the result
729736
$outputFromFastHashMap
730737
$outputFromRegularHashMap
731-
732-
$iterTerm.close();
733-
if ($sorterTerm == null) {
734-
$hashMapTerm.free();
735-
}
736738
"""
737739
}
738740

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
378378
val numOutput = metricTerm(ctx, "numOutputRows")
379379

380380
val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
381-
val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
381+
val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
382382

383383
val value = ctx.freshName("value")
384384
val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
@@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
397397
// within a batch, while the code in the outer loop is setting batch parameters and updating
398398
// the metrics.
399399

400-
// Once number == batchEnd, it's time to progress to the next batch.
400+
// Once nextIndex == batchEnd, it's time to progress to the next batch.
401401
val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
402402

403403
// How many values should still be generated by this range operator.
@@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
421421
|
422422
| $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
423423
| if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
424-
| $number = Long.MAX_VALUE;
424+
| $nextIndex = Long.MAX_VALUE;
425425
| } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
426-
| $number = Long.MIN_VALUE;
426+
| $nextIndex = Long.MIN_VALUE;
427427
| } else {
428-
| $number = st.longValue();
428+
| $nextIndex = st.longValue();
429429
| }
430-
| $batchEnd = $number;
430+
| $batchEnd = $nextIndex;
431431
|
432432
| $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
433433
| .multiply(step).add(start);
@@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
440440
| }
441441
|
442442
| $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
443-
| $BigInt.valueOf($number));
443+
| $BigInt.valueOf($nextIndex));
444444
| $numElementsTodo = startToEnd.divide(step).longValue();
445445
| if ($numElementsTodo < 0) {
446446
| $numElementsTodo = 0;
@@ -452,46 +452,68 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
452452

453453
val localIdx = ctx.freshName("localIdx")
454454
val localEnd = ctx.freshName("localEnd")
455-
val range = ctx.freshName("range")
456455
val shouldStop = if (parent.needStopCheck) {
457-
s"if (shouldStop()) { $number = $value + ${step}L; return; }"
456+
s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
458457
} else {
459458
"// shouldStop check is eliminated"
460459
}
460+
461+
// An overview of the Range processing.
462+
//
463+
// For each partition, the Range task needs to produce records from partition start(inclusive)
464+
// to end(exclusive). For better performance, we separate the partition range into batches, and
465+
// use 2 loops to produce data. The outer while loop is used to iterate batches, and the inner
466+
// for loop is used to iterate records inside a batch.
467+
//
468+
// `nextIndex` tracks the index of the next record that is going to be consumed, initialized
469+
// with partition start. `batchEnd` tracks the end index of the current batch, initialized
470+
// with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's true,
471+
// it means the current batch is fully consumed, and we will update `batchEnd` to process the
472+
// next batch. If `batchEnd` reaches partition end, exit the outer loop. finally we enter the
473+
// inner loop. Note that, when we enter inner loop, `nextIndex` must be different from
474+
// `batchEnd`, otherwise the outer loop should already exits.
475+
//
476+
// The inner loop iterates from 0 to `localEnd`, which is calculated by
477+
// `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step` in
478+
// the outer loop, and initialized with `nextIndex`, so `batchEnd - nextIndex` is always
479+
// divisible by `step`. The `nextIndex` is increased by `step` during each iteration, and ends
480+
// up being equal to `batchEnd` when the inner loop finishes.
481+
//
482+
// The inner loop can be interrupted, if the query has produced at least one result row, so that
483+
// we don't buffer too many result rows and waste memory. It's ok to interrupt the inner loop,
484+
// because `nextIndex` will be updated before interrupting.
485+
461486
s"""
462487
| // initialize Range
463488
| if (!$initTerm) {
464489
| $initTerm = true;
465490
| $initRangeFuncName(partitionIndex);
466491
| }
467492
|
468-
| while (true) {
469-
| long $range = $batchEnd - $number;
470-
| if ($range != 0L) {
471-
| int $localEnd = (int)($range / ${step}L);
472-
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
473-
| long $value = ((long)$localIdx * ${step}L) + $number;
474-
| ${consume(ctx, Seq(ev))}
475-
| $shouldStop
493+
| while (true$keepProducingDataCond) {
494+
| if ($nextIndex == $batchEnd) {
495+
| long $nextBatchTodo;
496+
| if ($numElementsTodo > ${batchSize}L) {
497+
| $nextBatchTodo = ${batchSize}L;
498+
| $numElementsTodo -= ${batchSize}L;
499+
| } else {
500+
| $nextBatchTodo = $numElementsTodo;
501+
| $numElementsTodo = 0;
502+
| if ($nextBatchTodo == 0) break;
476503
| }
477-
| $number = $batchEnd;
504+
| $numOutput.add($nextBatchTodo);
505+
| $inputMetrics.incRecordsRead($nextBatchTodo);
506+
| $batchEnd += $nextBatchTodo * ${step}L;
478507
| }
479508
|
480-
| $taskContext.killTaskIfInterrupted();
481-
|
482-
| long $nextBatchTodo;
483-
| if ($numElementsTodo > ${batchSize}L) {
484-
| $nextBatchTodo = ${batchSize}L;
485-
| $numElementsTodo -= ${batchSize}L;
486-
| } else {
487-
| $nextBatchTodo = $numElementsTodo;
488-
| $numElementsTodo = 0;
489-
| if ($nextBatchTodo == 0) break;
509+
| int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
510+
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
511+
| long $value = ((long)$localIdx * ${step}L) + $nextIndex;
512+
| ${consume(ctx, Seq(ev))}
513+
| $shouldStop
490514
| }
491-
| $numOutput.add($nextBatchTodo);
492-
| $inputMetrics.incRecordsRead($nextBatchTodo);
493-
|
494-
| $batchEnd += $nextBatchTodo * ${step}L;
515+
| $nextIndex = $batchEnd;
516+
| $taskContext.killTaskIfInterrupted();
495517
| }
496518
""".stripMargin
497519
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ case class SortMergeJoinExec(
623623
}
624624

625625
s"""
626-
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
626+
|while (findNextInnerJoinRows($leftInput, $rightInput)$keepProducingDataCond) {
627627
| ${leftVarDecl.mkString("\n")}
628628
| ${beforeLoop.trim}
629629
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
4646
}
4747
}
4848

49+
object BaseLimitExec {
50+
private val curId = new java.util.concurrent.atomic.AtomicInteger()
51+
52+
def newLimitCountTerm(): String = {
53+
val id = curId.getAndIncrement()
54+
s"_limit_counter_$id"
55+
}
56+
}
57+
4958
/**
5059
* Helper trait which defines methods that are shared by both
5160
* [[LocalLimitExec]] and [[GlobalLimitExec]].
@@ -66,27 +75,22 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
6675
// to the parent operator.
6776
override def usedInputs: AttributeSet = AttributeSet.empty
6877

78+
private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
79+
80+
override lazy val conditionsOfKeepProducingData: Seq[String] = {
81+
s"$countTerm < $limit" +: super.conditionsOfKeepProducingData
82+
}
83+
6984
protected override def doProduce(ctx: CodegenContext): String = {
7085
child.asInstanceOf[CodegenSupport].produce(ctx, this)
7186
}
7287

7388
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
74-
val stopEarly =
75-
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
76-
77-
ctx.addNewFunction("stopEarly", s"""
78-
@Override
79-
protected boolean stopEarly() {
80-
return $stopEarly;
81-
}
82-
""", inlineToOuterClass = true)
83-
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
89+
ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
8490
s"""
8591
| if ($countTerm < $limit) {
8692
| $countTerm += 1;
8793
| ${consume(ctx, input)}
88-
| } else {
89-
| $stopEarly = true;
9094
| }
9195
""".stripMargin
9296
}

0 commit comments

Comments
 (0)