Skip to content

Commit d57ecc1

Browse files
committed
add a logical plan link in the physical plan
1 parent fd9acf2 commit d57ecc1

6 files changed

Lines changed: 180 additions & 6 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
1919

2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.expressions._
22-
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
22+
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTagName}
2323
import org.apache.spark.sql.internal.SQLConf
2424
import org.apache.spark.sql.types.{DataType, StructType}
2525

@@ -271,6 +271,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
271271
}
272272

273273
object QueryPlan extends PredicateHelper {
274+
val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
275+
274276
/**
275277
* Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
276278
* with its referenced ordinal from input attributes. It's similar to `BindReferences` but we

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,11 @@ case class OneRowRelation() extends LeafNode {
10781078
override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
10791079

10801080
  /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
    // Because this bypasses `TreeNode.makeCopy` (which carries tags over via reflection),
    // the tags must be copied to the new node manually.
    val newCopy = OneRowRelation()
    newCopy.tags ++= this.tags
    newCopy
  }
10821086
}
10831087

10841088
/** A logical plan for `dropDuplicates`. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
1919

2020
import java.util.UUID
2121

22-
import scala.collection.Map
22+
import scala.collection.{mutable, Map}
2323
import scala.reflect.ClassTag
2424

2525
import org.apache.commons.lang3.ClassUtils
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._
3535
import org.apache.spark.sql.catalyst.expressions._
3636
import org.apache.spark.sql.catalyst.plans.JoinType
3737
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
38-
import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
38+
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
3939
import org.apache.spark.sql.catalyst.util.truncatedString
4040
import org.apache.spark.sql.internal.SQLConf
4141
import org.apache.spark.sql.types._
@@ -74,13 +74,24 @@ object CurrentOrigin {
7474
}
7575
}
7676

77+
// The name of a tree node tag. Using this wrapper is preferred over using a raw string key,
// as all defined tags can be found easily by searching for usages of this class.
case class TreeNodeTagName(name: String)
80+
7781
// scalastyle:off
7882
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
7983
// scalastyle:on
8084
self: BaseType =>
8185

8286
val origin: Origin = CurrentOrigin.get
8387

88+
  /**
   * A mutable map for holding auxiliary information of this tree node. It will be carried over
   * when this node is copied via `makeCopy`, or replaced by a same-type node during `transform`.
   * If a user copies the tree node via other ways, like the case-class `copy` method, it is
   * their responsibility to carry over the tags.
   */
  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
94+
8495
/**
8596
* Returns a Seq of the children of this node.
8697
* Children should not change. Immutability required for containsChild optimization
@@ -262,6 +273,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
262273
if (this fastEquals afterRule) {
263274
mapChildren(_.transformDown(rule))
264275
} else {
276+
// If the transform function replaces this node with a new one of the same type, carry over
277+
// the tags.
278+
if (afterRule.getClass == this.getClass) {
279+
afterRule.tags ++= this.tags
280+
}
281+
265282
afterRule.mapChildren(_.transformDown(rule))
266283
}
267284
}
@@ -280,9 +297,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
280297
rule.applyOrElse(this, identity[BaseType])
281298
}
282299
} else {
283-
CurrentOrigin.withOrigin(origin) {
300+
val newNode = CurrentOrigin.withOrigin(origin) {
284301
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
285302
}
303+
// If the transform function replaces this node with a new one of the same type, carry over
304+
// the tags.
305+
if (newNode.getClass == this.getClass) {
306+
newNode.tags ++= this.tags
307+
}
308+
newNode
286309
}
287310
}
288311

@@ -402,7 +425,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
402425

403426
try {
404427
CurrentOrigin.withOrigin(origin) {
405-
defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
428+
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
429+
res.tags ++= this.tags
430+
res
406431
}
407432
} catch {
408433
case e: java.lang.IllegalArgumentException =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,4 +620,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
620620
assert(planString.startsWith("Truncated plan of"))
621621
}
622622
}
623+
624+
  // TODO(review): this test body is empty and therefore verifies nothing. It should copy a
  // tagged node (via `makeCopy` and via the `transform` methods) and assert that the `tags`
  // map is present on the resulting node.
  test("tags will be carried over after copy") {

  }
623627
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
6363
abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6464
self: SparkPlanner =>
6565

66+
override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
67+
super.plan(plan).map { p =>
68+
val logicalPlan = plan match {
69+
case ReturnAnswer(rootPlan) => rootPlan
70+
case _ => plan
71+
}
72+
p.tags += QueryPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
73+
p
74+
}
75+
}
76+
6677
/**
6778
* Plans special cases of limit operators.
6879
*/
sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import scala.reflect.ClassTag
21+
22+
import org.apache.spark.sql.TPCDSQuerySuite
23+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
24+
import org.apache.spark.sql.catalyst.plans.QueryPlan
25+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
26+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
27+
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
28+
import org.apache.spark.sql.execution.datasources.LogicalRelation
29+
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
30+
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
31+
import org.apache.spark.sql.execution.joins._
32+
import org.apache.spark.sql.execution.window.WindowExec
33+
34+
/**
 * Verifies that every physical node produced by the planner carries the
 * `QueryPlan.LOGICAL_PLAN_TAG_NAME` tag, pointing at a logical plan of the expected type.
 * It piggy-backs on `checkGeneratedCode`, which the TPCDS query suite calls for every plan.
 */
class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {

  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
    super.checkGeneratedCode(plan)
    checkLogicalPlanTag(plan)
  }

  // Whether the given aggregate expressions describe a final (or complete) aggregate.
  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
    // TODO: aggregate node without aggregate expressions can also be a final aggregate, but
    // currently the aggregate node doesn't have a final/partial flag.
    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
  }

  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
    case p: ProjectExec => isScanPlanTree(p.child)
    case f: FilterExec => isScanPlanTree(f.child)
    case _: LeafExecNode => true
    case _ => false
  }

  // Recursively checks the logical plan link of `plan` and its children/subqueries.
  // NOTE: the case order matters — e.g. `ReusedExchangeExec` is a leaf node and must be
  // handled by the exchange case before the `isScanPlanTree` guard can see it.
  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
    plan match {
      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
          | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
        assertLogicalPlanType[Join](plan)

      // There is no corresponding logical plan for the physical partial aggregate.
      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)

      case _: WindowExec =>
        assertLogicalPlanType[Window](plan)

      case _: UnionExec =>
        assertLogicalPlanType[Union](plan)

      case _: SampleExec =>
        assertLogicalPlanType[Sample](plan)

      case _: GenerateExec =>
        assertLogicalPlanType[Generate](plan)

      // The exchange related nodes are created after the planning, they don't have corresponding
      // logical plan.
      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))

      case _ if isScanPlanTree(plan) =>
        // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
        // so it's not simple to check. Instead, we only check that the origin LogicalPlan
        // contains the corresponding leaf node of the SparkPlan.
        // a strategy might remove the filter if it's totally pushed down, e.g.:
        //   logical = Project(Filter(Scan A))
        //   physical = ProjectExec(ScanExec A)
        // we only check that leaf nodes match between logical and physical plan.
        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
        val physicalLeaves = plan.collectLeaves()
        assert(logicalLeaves.length == 1)
        assert(physicalLeaves.length == 1)
        // BUG FIX: the `isInstanceOf` results were previously computed and discarded (the match
        // result was unused), so a mismatched leaf could never fail the test. Wrap each check
        // in `assert` so it actually takes effect.
        physicalLeaves.head match {
          case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range])
          case _: DataSourceScanExec => assert(logicalLeaves.head.isInstanceOf[LogicalRelation])
          case _: InMemoryTableScanExec =>
            assert(logicalLeaves.head.isInstanceOf[InMemoryRelation])
          case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation])
          case _: ExternalRDDScanExec[_] => assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]])
          case _: BatchScanExec => assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation])
          case _ => // Other leaf nodes are not checked.
        }
        // Do not need to check the children recursively.
        return

      case _ =>
    }

    plan.children.foreach(checkLogicalPlanTag)
    plan.subqueries.foreach(checkLogicalPlanTag)
  }

  // Returns the logical plan linked to `node`, failing with a descriptive message if absent.
  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
    assert(node.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME),
      node.getClass.getSimpleName + " does not have a logical plan link")
    node.tags(QueryPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
  }

  // Asserts that the logical plan linked to `node` is exactly of runtime class `T`.
  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
    val logicalPlan = getLogicalPlan(node)
    val expectedCls = implicitly[ClassTag[T]].runtimeClass
    assert(expectedCls == logicalPlan.getClass)
  }
}

0 commit comments

Comments
 (0)