Skip to content

Commit c31a66f

Browse files
committed
refactor: use auxiliary idMap instead of OP_ID_TAG
1 parent 96365c8 commit c31a66f

File tree

2 files changed

+41
-46
lines changed

2 files changed

+41
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.plans
1919

20+
import java.util.IdentityHashMap
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.spark.sql.AnalysisException
@@ -443,7 +445,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
443445
override def verboseString(maxFields: Int): String = simpleString(maxFields)
444446

445447
override def simpleStringWithNodeId(): String = {
446-
val operatorId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
448+
val operatorId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
449+
.getOrElse("unknown")
447450
s"$nodeName ($operatorId)".trim
448451
}
449452

@@ -463,7 +466,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
463466
}
464467

465468
protected def formattedNodeName: String = {
466-
val opId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
469+
val opId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
470+
.getOrElse("unknown")
467471
val codegenId =
468472
getTagValue(QueryPlan.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("")
469473
s"($opId) $nodeName$codegenId"
@@ -675,9 +679,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
675679
}
676680

677681
object QueryPlan extends PredicateHelper {
678-
val OP_ID_TAG = TreeNodeTag[Int]("operatorId")
679682
val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId")
680683

684+
val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = ThreadLocal.withInitial(() =>
685+
new IdentityHashMap[QueryPlan[_], Int]())
686+
681687
/**
682688
* Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
683689
* with its referenced ordinal from input attributes. It's similar to `BindReferences` but we

sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import java.util.Collections.newSetFromMap
2120
import java.util.IdentityHashMap
22-
import java.util.Set
2321

2422
import scala.collection.mutable.{ArrayBuffer, BitSet}
2523

@@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
3028
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
3129

3230
object ExplainUtils extends AdaptiveSparkPlanHelper {
31+
def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap
32+
3333
/**
3434
* Given a input physical plan, performs the following tasks.
3535
* 1. Computes the whole stage codegen id for current operator and records it in the
@@ -80,34 +80,36 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
8080
* instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared
8181
* plan instance across multi-queries. Add lock for this method to avoid tag race condition.
8282
*/
83-
def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = synchronized {
83+
def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = {
84+
val prevIdMap = localIdMap.get()
8485
try {
85-
// Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
86-
// intentional overwriting of IDs generated in previous AQE iteration
87-
val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap())
86+
// Initialize a reference-unique id map to store generated ids, which also avoids accidental
87+
// overwrites and allows intentional overwriting of IDs generated in a previous AQE iteration
88+
val idMap = new IdentityHashMap[QueryPlan[_], Int]()
89+
localIdMap.set(idMap)
8890
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
8991
// Exchanges as part of SPARK-42753
9092
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
9193

9294
var currentOperatorID = 0
93-
currentOperatorID = generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges,
95+
currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges,
9496
true)
9597

9698
val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
9799
getSubqueries(plan, subqueries)
98100

99101
currentOperatorID = subqueries.foldLeft(currentOperatorID) {
100-
(curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges,
102+
(curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges,
101103
true)
102104
}
103105

104106
// SPARK-42753: Process subtree for a ReusedExchange with unknown child
105107
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
106108
reusedExchanges.foreach{ reused =>
107109
val child = reused.child
108-
if (!operators.contains(child)) {
110+
if (!idMap.containsKey(child)) {
109111
optimizedOutExchanges.append(child)
110-
currentOperatorID = generateOperatorIDs(child, currentOperatorID, operators,
112+
currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap,
111113
reusedExchanges, false)
112114
}
113115
}
@@ -144,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
144146
append("\n")
145147
}
146148
} finally {
147-
removeTags(plan)
149+
localIdMap.set(prevIdMap)
148150
}
149151
}
150152

@@ -159,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
159161
* @param plan Input query plan to process
160162
* @param startOperatorID The start value of operation id. The subsequent operations will be
161163
* assigned higher value.
162-
* @param visited A unique set of operators visited by generateOperatorIds. The set is scoped
163-
* at the callsite function processPlan. It serves two purpose: Firstly, it is
164-
* used to avoid accidentally overwriting existing IDs that were generated in
165-
* the same processPlan call. Secondly, it is used to allow for intentional ID
166-
* overwriting as part of SPARK-42753 where an Adaptively Optimized Out Exchange
167-
* and its subtree may contain IDs that were generated in a previous AQE
168-
* iteration's processPlan call which would result in incorrect IDs.
164+
* @param idMap A reference-unique map from operators visited by generateOperatorIDs to their
165+
* ids. This map is scoped at the callsite function processPlan. It serves three
166+
* purposes:
167+
* Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to
168+
* avoid accidentally overwriting existing IDs that were generated in the same
169+
* processPlan call. Thirdly, it is used to allow for intentional ID overwriting as
170+
* part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree
171+
* may contain IDs that were generated in a previous AQE iteration's processPlan
172+
* call which would result in incorrect IDs.
169173
* @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to
170174
* identify adaptively optimized out exchanges in SPARK-42753.
171175
* @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it
@@ -177,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
177181
private def generateOperatorIDs(
178182
plan: QueryPlan[_],
179183
startOperatorID: Int,
180-
visited: Set[QueryPlan[_]],
184+
idMap: java.util.Map[QueryPlan[_], Int],
181185
reusedExchanges: ArrayBuffer[ReusedExchangeExec],
182186
addReusedExchanges: Boolean): Int = {
183187
var currentOperationID = startOperatorID
@@ -186,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
186190
return currentOperationID
187191
}
188192

189-
def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) {
193+
def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => {
190194
plan match {
191195
case r: ReusedExchangeExec if addReusedExchanges =>
192196
reusedExchanges.append(r)
193197
case _ =>
194198
}
195-
visited.add(plan)
196199
currentOperationID += 1
197-
plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
198-
}
200+
currentOperationID
201+
})
199202

200203
plan.foreachUp {
201204
case _: WholeStageCodegenExec =>
202205
case _: InputAdapter =>
203206
case p: AdaptiveSparkPlanExec =>
204-
currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, visited,
207+
currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap,
205208
reusedExchanges, addReusedExchanges)
206209
if (!p.executedPlan.fastEquals(p.initialPlan)) {
207-
currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, visited,
210+
currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap,
208211
reusedExchanges, addReusedExchanges)
209212
}
210213
setOpId(p)
211214
case p: QueryStageExec =>
212-
currentOperationID = generateOperatorIDs(p.plan, currentOperationID, visited,
215+
currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap,
213216
reusedExchanges, addReusedExchanges)
214217
setOpId(p)
215218
case other: QueryPlan[_] =>
216219
setOpId(other)
217220
currentOperationID = other.innerChildren.foldLeft(currentOperationID) {
218-
(curId, plan) => generateOperatorIDs(plan, curId, visited, reusedExchanges,
221+
(curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges,
219222
addReusedExchanges)
220223
}
221224
}
@@ -241,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
241244
}
242245

243246
def collectOperatorWithID(plan: QueryPlan[_]): Unit = {
244-
plan.getTagValue(QueryPlan.OP_ID_TAG).foreach { id =>
247+
Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id =>
245248
if (collectedOperators.add(id)) operators += plan
246249
}
247250
}
@@ -334,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
334337
* `operationId` tag value.
335338
*/
336339
def getOpId(plan: QueryPlan[_]): String = {
337-
plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
338-
}
339-
340-
def removeTags(plan: QueryPlan[_]): Unit = {
341-
def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
342-
p.unsetTagValue(QueryPlan.OP_ID_TAG)
343-
p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
344-
children.foreach(removeTags)
345-
}
346-
347-
plan foreach {
348-
case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan))
349-
case p: QueryStageExec => remove(p, Seq(p.plan))
350-
case plan: QueryPlan[_] => remove(plan, plan.innerChildren)
351-
}
340+
Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown")
352341
}
353342
}

0 commit comments

Comments
 (0)