Skip to content

Commit 2bba472

Browse files
cloud-fanpengbo
authored andcommitted
[SPARK-27747][SQL] add a logical plan link in the physical plan
It's pretty useful if we can convert a physical plan back to a logical plan, e.g., in apache#24389 This PR introduces a new feature to `TreeNode`, which allows `TreeNode` to carry some extra information via a mutable map, and keep the information when it's copied. The planner leverages this feature to put the logical plan into the physical plan. a test suite that runs all TPCDS queries and checks that some common physical plans contain the corresponding logical plans. Closes apache#24626 from cloud-fan/link. Lead-authored-by: Wenchen Fan <wenchen@databricks.com> Co-authored-by: Peng Bo <bo.peng1019@gmail.com> Signed-off-by: gatorsmile <gatorsmile@gmail.com>
1 parent f67d752 commit 2bba472

6 files changed

Lines changed: 227 additions & 6 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,11 @@ case class OneRowRelation() extends LeafNode {
988988
override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
989989

990990
/** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
991-
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
991+
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
992+
val newCopy = OneRowRelation()
993+
newCopy.tags ++= this.tags
994+
newCopy
995+
}
992996
}
993997

994998
/** A logical plan for `dropDuplicates`. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
1919

2020
import java.util.UUID
2121

22-
import scala.collection.Map
22+
import scala.collection.{mutable, Map}
2323
import scala.reflect.ClassTag
2424

2525
import org.apache.commons.lang3.ClassUtils
@@ -74,13 +74,23 @@ object CurrentOrigin {
7474
}
7575
}
7676

77+
// The name of the tree node tag. This is preferred over using string directly, as we can easily
78+
// find all the defined tags.
79+
case class TreeNodeTagName(name: String)
80+
7781
// scalastyle:off
7882
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
7983
// scalastyle:on
8084
self: BaseType =>
8185

8286
val origin: Origin = CurrentOrigin.get
8387

88+
/**
89+
* A mutable map for holding auxiliary information of this tree node. It will be carried over
90+
* when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
91+
*/
92+
val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
93+
8494
/**
8595
* Returns a Seq of the children of this node.
8696
* Children should not change. Immutability required for containsChild optimization
@@ -264,6 +274,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
264274
if (this fastEquals afterRule) {
265275
mapChildren(_.transformDown(rule))
266276
} else {
277+
// If the transform function replaces this node with a new one, carry over the tags.
278+
afterRule.tags ++= this.tags
267279
afterRule.mapChildren(_.transformDown(rule))
268280
}
269281
}
@@ -277,7 +289,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
277289
*/
278290
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
279291
val afterRuleOnChildren = mapChildren(_.transformUp(rule))
280-
if (this fastEquals afterRuleOnChildren) {
292+
val newNode = if (this fastEquals afterRuleOnChildren) {
281293
CurrentOrigin.withOrigin(origin) {
282294
rule.applyOrElse(this, identity[BaseType])
283295
}
@@ -286,6 +298,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
286298
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
287299
}
288300
}
301+
// If the transform function replaces this node with a new one, carry over the tags.
302+
newNode.tags ++= this.tags
303+
newNode
289304
}
290305

291306
/**
@@ -404,7 +419,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
404419

405420
try {
406421
CurrentOrigin.withOrigin(origin) {
407-
defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
422+
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
423+
res.tags ++= this.tags
424+
res
408425
}
409426
} catch {
410427
case e: java.lang.IllegalArgumentException =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,4 +595,55 @@ class TreeNodeSuite extends SparkFunSuite {
595595
val expected = Coalesce(Stream(Literal(1), Literal(3)))
596596
assert(result === expected)
597597
}
598+
599+
test("tags will be carried over after copy & transform") {
600+
withClue("makeCopy") {
601+
val node = Dummy(None)
602+
node.tags += TreeNodeTagName("test") -> "a"
603+
val copied = node.makeCopy(Array(Some(Literal(1))))
604+
assert(copied.tags(TreeNodeTagName("test")) == "a")
605+
}
606+
607+
def checkTransform(
608+
sameTypeTransform: Expression => Expression,
609+
differentTypeTransform: Expression => Expression): Unit = {
610+
val child = Dummy(None)
611+
child.tags += TreeNodeTagName("test") -> "child"
612+
val node = Dummy(Some(child))
613+
node.tags += TreeNodeTagName("test") -> "parent"
614+
615+
val transformed = sameTypeTransform(node)
616+
// Both the child and parent keep the tags
617+
assert(transformed.tags(TreeNodeTagName("test")) == "parent")
618+
assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
619+
620+
val transformed2 = differentTypeTransform(node)
621+
// Both the child and parent keep the tags, even if we transform the node to a new one of
622+
// different type.
623+
assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
624+
assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
625+
}
626+
627+
withClue("transformDown") {
628+
checkTransform(
629+
sameTypeTransform = _ transformDown {
630+
case Dummy(None) => Dummy(Some(Literal(1)))
631+
},
632+
differentTypeTransform = _ transformDown {
633+
case Dummy(None) => Literal(1)
634+
635+
})
636+
}
637+
638+
withClue("transformUp") {
639+
checkTransform(
640+
sameTypeTransform = _ transformUp {
641+
case Dummy(None) => Dummy(Some(Literal(1)))
642+
},
643+
differentTypeTransform = _ transformUp {
644+
case Dummy(None) => Literal(1)
645+
646+
})
647+
}
648+
}
598649
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
2020
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
2121

2222
import scala.collection.mutable.ArrayBuffer
23-
import scala.concurrent.ExecutionContext
2423

2524
import org.codehaus.commons.compiler.CompileException
2625
import org.codehaus.janino.InternalCompilerException
@@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._
3534
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
3635
import org.apache.spark.sql.catalyst.plans.QueryPlan
3736
import org.apache.spark.sql.catalyst.plans.physical._
37+
import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
3838
import org.apache.spark.sql.execution.metric.SQLMetric
3939
import org.apache.spark.sql.types.DataType
40-
import org.apache.spark.util.ThreadUtils
40+
41+
object SparkPlan {
42+
// a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
43+
// when converting a logical plan to a physical plan.
44+
val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
45+
}
4146

4247
/**
4348
* The base class for physical operators.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
6363
abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6464
self: SparkPlanner =>
6565

66+
override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
67+
super.plan(plan).map { p =>
68+
val logicalPlan = plan match {
69+
case ReturnAnswer(rootPlan) => rootPlan
70+
case _ => plan
71+
}
72+
p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
73+
p
74+
}
75+
}
76+
6677
/**
6778
* Plans special cases of limit operators.
6879
*/
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import scala.reflect.ClassTag
21+
22+
import org.apache.spark.sql.TPCDSQuerySuite
23+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
24+
import org.apache.spark.sql.catalyst.plans.QueryPlan
25+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
26+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
27+
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
28+
import org.apache.spark.sql.execution.datasources.LogicalRelation
29+
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
30+
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
31+
import org.apache.spark.sql.execution.joins._
32+
import org.apache.spark.sql.execution.window.WindowExec
33+
34+
class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
35+
36+
override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
37+
super.checkGeneratedCode(plan)
38+
checkLogicalPlanTag(plan)
39+
}
40+
41+
private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
42+
// TODO: aggregate node without aggregate expressions can also be a final aggregate, but
43+
// currently the aggregate node doesn't have a final/partial flag.
44+
aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
45+
}
46+
47+
// A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
48+
private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
49+
case p: ProjectExec => isScanPlanTree(p.child)
50+
case f: FilterExec => isScanPlanTree(f.child)
51+
case _: LeafExecNode => true
52+
case _ => false
53+
}
54+
55+
private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
56+
plan match {
57+
case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
58+
| _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
59+
assertLogicalPlanType[Join](plan)
60+
61+
// There is no corresponding logical plan for the physical partial aggregate.
62+
case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
63+
assertLogicalPlanType[Aggregate](plan)
64+
case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
65+
assertLogicalPlanType[Aggregate](plan)
66+
case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
67+
assertLogicalPlanType[Aggregate](plan)
68+
69+
case _: WindowExec =>
70+
assertLogicalPlanType[Window](plan)
71+
72+
case _: UnionExec =>
73+
assertLogicalPlanType[Union](plan)
74+
75+
case _: SampleExec =>
76+
assertLogicalPlanType[Sample](plan)
77+
78+
case _: GenerateExec =>
79+
assertLogicalPlanType[Generate](plan)
80+
81+
// The exchange related nodes are created after the planning, they don't have corresponding
82+
// logical plan.
83+
case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
84+
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
85+
86+
// The subquery exec nodes are just wrappers of the actual nodes, they don't have
87+
// corresponding logical plan.
88+
case _: SubqueryExec | _: ReusedSubqueryExec =>
89+
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
90+
91+
case _ if isScanPlanTree(plan) =>
92+
// The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
93+
// so it's not simple to check. Instead, we only check that the origin LogicalPlan
94+
// contains the corresponding leaf node of the SparkPlan.
95+
// a strategy might remove the filter if it's totally pushed down, e.g.:
96+
// logical = Project(Filter(Scan A))
97+
// physical = ProjectExec(ScanExec A)
98+
// we only check that leaf modes match between logical and physical plan.
99+
val logicalLeaves = getLogicalPlan(plan).collectLeaves()
100+
val physicalLeaves = plan.collectLeaves()
101+
assert(logicalLeaves.length == 1)
102+
assert(physicalLeaves.length == 1)
103+
physicalLeaves.head match {
104+
case _: RangeExec => logicalLeaves.head.isInstanceOf[Range]
105+
case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation]
106+
case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation]
107+
case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation]
108+
case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]]
109+
case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation]
110+
case _ =>
111+
}
112+
// Do not need to check the children recursively.
113+
return
114+
115+
case _ =>
116+
}
117+
118+
plan.children.foreach(checkLogicalPlanTag)
119+
plan.subqueries.foreach(checkLogicalPlanTag)
120+
}
121+
122+
private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
123+
assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
124+
node.getClass.getSimpleName + " does not have a logical plan link")
125+
node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
126+
}
127+
128+
private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
129+
val logicalPlan = getLogicalPlan(node)
130+
val expectedCls = implicitly[ClassTag[T]].runtimeClass
131+
assert(expectedCls == logicalPlan.getClass)
132+
}
133+
}

0 commit comments

Comments
 (0)