Skip to content

Commit 7c91b15

Browse files
andygrove authored and tgravescs committed
[SPARK-32332][SQL][3.0] Support columnar exchanges
### What changes were proposed in this pull request? Backports SPARK-32332 to 3.0 branch. ### Why are the changes needed? Plugins cannot replace exchanges with columnar versions when AQE is enabled without this patch. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tests included. Closes #29310 from andygrove/backport-SPARK-32332. Authored-by: Andy Grove <[email protected]> Signed-off-by: Thomas Graves <[email protected]>
1 parent 2a38090 commit 7c91b15

10 files changed

Lines changed: 272 additions & 68 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
100100
// The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
101101
// added by `CoalesceShufflePartitions`. So they must be executed after it.
102102
OptimizeSkewedJoin(conf),
103-
OptimizeLocalShuffleReader(conf),
103+
OptimizeLocalShuffleReader(conf)
104+
)
105+
106+
// A list of physical optimizer rules to be applied right after a new stage is created. The input
107+
// plan to these rules has exchange as its root node.
108+
@transient private val postStageCreationRules = Seq(
104109
ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
105110
CollapseCodegenStages(conf)
106111
)
@@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec(
227232
}
228233

229234
// Run the final plan when there's no more unfinished stages.
230-
currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
235+
currentPhysicalPlan = applyPhysicalRules(
236+
result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
231237
isFinalPlan = true
232238
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
233239
currentPhysicalPlan
@@ -375,10 +381,22 @@ case class AdaptiveSparkPlanExec(
375381
private def newQueryStage(e: Exchange): QueryStageExec = {
376382
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
377383
val queryStage = e match {
378-
case s: ShuffleExchangeExec =>
379-
ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
380-
case b: BroadcastExchangeExec =>
381-
BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
384+
case s: ShuffleExchangeLike =>
385+
val newShuffle = applyPhysicalRules(
386+
s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
387+
if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
388+
throw new IllegalStateException(
389+
"Custom columnar rules cannot transform shuffle node to something else.")
390+
}
391+
ShuffleQueryStageExec(currentStageId, newShuffle)
392+
case b: BroadcastExchangeLike =>
393+
val newBroadcast = applyPhysicalRules(
394+
b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
395+
if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
396+
throw new IllegalStateException(
397+
"Custom columnar rules cannot transform broadcast node to something else.")
398+
}
399+
BroadcastQueryStageExec(currentStageId, newBroadcast)
382400
}
383401
currentStageId += 1
384402
setLogicalLinkForNewQueryStage(queryStage, e)

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
2323
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
2424
import org.apache.spark.sql.execution._
25-
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
25+
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
26+
import org.apache.spark.sql.vectorized.ColumnarBatch
2627

2728

2829
/**
@@ -38,6 +39,8 @@ case class CustomShuffleReaderExec private(
3839
partitionSpecs: Seq[ShufflePartitionSpec],
3940
description: String) extends UnaryExecNode {
4041

42+
override def supportsColumnar: Boolean = child.supportsColumnar
43+
4144
override def output: Seq[Attribute] = child.output
4245
override lazy val outputPartitioning: Partitioning = {
4346
// If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -47,9 +50,9 @@ case class CustomShuffleReaderExec private(
4750
partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
4851
partitionSpecs.length) {
4952
child match {
50-
case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
53+
case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
5154
s.child.outputPartitioning
52-
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
55+
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
5356
s.child.outputPartitioning match {
5457
case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
5558
case other => other
@@ -64,18 +67,24 @@ case class CustomShuffleReaderExec private(
6467

6568
override def stringArgs: Iterator[Any] = Iterator(description)
6669

67-
private var cachedShuffleRDD: RDD[InternalRow] = null
70+
private def shuffleStage = child match {
71+
case stage: ShuffleQueryStageExec => Some(stage)
72+
case _ => None
73+
}
6874

69-
override protected def doExecute(): RDD[InternalRow] = {
70-
if (cachedShuffleRDD == null) {
71-
cachedShuffleRDD = child match {
72-
case stage: ShuffleQueryStageExec =>
73-
new ShuffledRowRDD(
74-
stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
75-
case _ =>
76-
throw new IllegalStateException("operating on canonicalization plan")
77-
}
75+
private lazy val shuffleRDD: RDD[_] = {
76+
shuffleStage.map { stage =>
77+
stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
78+
}.getOrElse {
79+
throw new IllegalStateException("operating on canonicalized plan")
7880
}
79-
cachedShuffleRDD
81+
}
82+
83+
override protected def doExecute(): RDD[InternalRow] = {
84+
shuffleRDD.asInstanceOf[RDD[InternalRow]]
85+
}
86+
87+
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
88+
shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
8089
}
8190
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
7878
private def getPartitionSpecs(
7979
shuffleStage: ShuffleQueryStageExec,
8080
advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
81-
val shuffleDep = shuffleStage.shuffle.shuffleDependency
82-
val numReducers = shuffleDep.partitioner.numPartitions
81+
val numMappers = shuffleStage.shuffle.numMappers
82+
val numReducers = shuffleStage.shuffle.numPartitions
8383
val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
84-
val numMappers = shuffleDep.rdd.getNumPartitions
8584
val splitPoints = if (numMappers == 0) {
8685
Seq.empty
8786
} else {

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable
2121

2222
import org.apache.commons.io.FileUtils
2323

24-
import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
24+
import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.execution._
@@ -197,7 +197,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
197197
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
198198
val reducerId = leftPartSpec.startReducerIndex
199199
val skewSpecs = createSkewPartitionSpecs(
200-
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
200+
left.mapStats.shuffleId, reducerId, leftTargetSize)
201201
if (skewSpecs.isDefined) {
202202
logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
203203
s"${skewSpecs.get.length} parts.")
@@ -212,7 +212,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
212212
val rightParts = if (isRightSkew && !isRightCoalesced) {
213213
val reducerId = rightPartSpec.startReducerIndex
214214
val skewSpecs = createSkewPartitionSpecs(
215-
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
215+
right.mapStats.shuffleId, reducerId, rightTargetSize)
216216
if (skewSpecs.isDefined) {
217217
logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
218218
s"${skewSpecs.get.length} parts.")
@@ -287,15 +287,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
287287
private object ShuffleStage {
288288
def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
289289
case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
290-
val sizes = s.mapStats.get.bytesByPartitionId
290+
val mapStats = s.mapStats.get
291+
val sizes = mapStats.bytesByPartitionId
291292
val partitions = sizes.zipWithIndex.map {
292293
case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
293294
}
294-
Some(ShuffleStageInfo(s, partitions))
295+
Some(ShuffleStageInfo(s, mapStats, partitions))
295296

296297
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _)
297298
if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
298-
val sizes = s.mapStats.get.bytesByPartitionId
299+
val mapStats = s.mapStats.get
300+
val sizes = mapStats.bytesByPartitionId
299301
val partitions = partitionSpecs.map {
300302
case spec @ CoalescedPartitionSpec(start, end) =>
301303
var sum = 0L
@@ -308,14 +310,15 @@ private object ShuffleStage {
308310
case other => throw new IllegalArgumentException(
309311
s"Expect CoalescedPartitionSpec but got $other")
310312
}
311-
Some(ShuffleStageInfo(s, partitions))
313+
Some(ShuffleStageInfo(s, mapStats, partitions))
312314

313315
case _ => None
314316
}
315317
}
316318

317319
private case class ShuffleStageInfo(
318320
shuffleStage: ShuffleQueryStageExec,
321+
mapStats: MapOutputStatistics,
319322
partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
320323

321324
private class SkewDesc {

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
3232
import org.apache.spark.sql.execution._
3333
import org.apache.spark.sql.execution.exchange._
3434
import org.apache.spark.sql.internal.SQLConf
35+
import org.apache.spark.sql.vectorized.ColumnarBatch
3536
import org.apache.spark.util.ThreadUtils
3637

3738
/**
@@ -80,6 +81,11 @@ abstract class QueryStageExec extends LeafExecNode {
8081

8182
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
8283

84+
/**
85+
* Returns the runtime statistics after stage materialization.
86+
*/
87+
def getRuntimeStatistics: Statistics
88+
8389
/**
8490
* Compute the statistics of the query stage if executed, otherwise None.
8591
*/
@@ -107,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {
107113

108114
protected override def doPrepare(): Unit = plan.prepare()
109115
protected override def doExecute(): RDD[InternalRow] = plan.execute()
116+
override def supportsColumnar: Boolean = plan.supportsColumnar
117+
protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
110118
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
111119
override def doCanonicalize(): SparkPlan = plan.canonicalized
112120

@@ -135,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
135143
}
136144

137145
/**
138-
* A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
146+
* A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
139147
*/
140148
case class ShuffleQueryStageExec(
141149
override val id: Int,
142150
override val plan: SparkPlan) extends QueryStageExec {
143151

144152
@transient val shuffle = plan match {
145-
case s: ShuffleExchangeExec => s
146-
case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
153+
case s: ShuffleExchangeLike => s
154+
case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
147155
case _ =>
148156
throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
149157
}
@@ -176,18 +184,20 @@ case class ShuffleQueryStageExec(
176184
val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
177185
Option(stats)
178186
}
187+
188+
override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
179189
}
180190

181191
/**
182-
* A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
192+
* A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
183193
*/
184194
case class BroadcastQueryStageExec(
185195
override val id: Int,
186196
override val plan: SparkPlan) extends QueryStageExec {
187197

188198
@transient val broadcast = plan match {
189-
case b: BroadcastExchangeExec => b
190-
case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
199+
case b: BroadcastExchangeLike => b
200+
case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
191201
case _ =>
192202
throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
193203
}
@@ -224,6 +234,8 @@ case class BroadcastQueryStageExec(
224234
broadcast.relationFuture.cancel(true)
225235
}
226236
}
237+
238+
override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
227239
}
228240

229241
object BroadcastQueryStageExec {

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution.adaptive
1919

2020
import org.apache.spark.sql.execution.SparkPlan
21-
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
21+
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
2222

2323
/**
2424
* A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {
3535

3636
/**
3737
* A simple implementation of [[CostEvaluator]], which counts the number of
38-
* [[ShuffleExchangeExec]] nodes in the plan.
38+
* [[ShuffleExchangeLike]] nodes in the plan.
3939
*/
4040
object SimpleCostEvaluator extends CostEvaluator {
4141

4242
override def evaluateCost(plan: SparkPlan): Cost = {
4343
val cost = plan.collect {
44-
case s: ShuffleExchangeExec => s
44+
case s: ShuffleExchangeLike => s
4545
}.size
4646
SimpleCost(cost)
4747
}

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
2929
import org.apache.spark.rdd.RDD
3030
import org.apache.spark.sql.catalyst.InternalRow
3131
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
32+
import org.apache.spark.sql.catalyst.plans.logical.Statistics
3233
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
3334
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
3435
import org.apache.spark.sql.execution.joins.HashedRelation
@@ -37,16 +38,43 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
3738
import org.apache.spark.unsafe.map.BytesToBytesMap
3839
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
3940

41+
/**
42+
* Common trait for all broadcast exchange implementations to facilitate pattern matching.
43+
*/
44+
trait BroadcastExchangeLike extends Exchange {
45+
46+
/**
47+
* The broadcast job group ID
48+
*/
49+
def runId: UUID = UUID.randomUUID
50+
51+
/**
52+
* The asynchronous job that prepares the broadcast relation.
53+
*/
54+
def relationFuture: Future[broadcast.Broadcast[Any]]
55+
56+
/**
57+
* For registering callbacks on `relationFuture`.
58+
* Note that calling this method may not start the execution of broadcast job.
59+
*/
60+
def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
61+
62+
/**
63+
* Returns the runtime statistics after broadcast materialization.
64+
*/
65+
def runtimeStatistics: Statistics
66+
}
67+
4068
/**
4169
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
4270
* a transformed SparkPlan.
4371
*/
4472
case class BroadcastExchangeExec(
4573
mode: BroadcastMode,
46-
child: SparkPlan) extends Exchange {
74+
child: SparkPlan) extends BroadcastExchangeLike {
4775
import BroadcastExchangeExec._
4876

49-
private[sql] val runId: UUID = UUID.randomUUID
77+
override val runId: UUID = UUID.randomUUID
5078

5179
override lazy val metrics = Map(
5280
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,6 +88,11 @@ case class BroadcastExchangeExec(
6088
BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
6189
}
6290

91+
override def runtimeStatistics: Statistics = {
92+
val dataSize = metrics("dataSize").value
93+
Statistics(dataSize)
94+
}
95+
6396
@transient
6497
private lazy val promise = Promise[broadcast.Broadcast[Any]]()
6598

@@ -68,13 +101,14 @@ case class BroadcastExchangeExec(
68101
* Note that calling this field will not start the execution of broadcast job.
69102
*/
70103
@transient
71-
lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
104+
override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
105+
promise.future
72106

73107
@transient
74108
private val timeout: Long = SQLConf.get.broadcastTimeout
75109

76110
@transient
77-
private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
111+
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
78112
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
79113
sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
80114
try {

0 commit comments

Comments (0)