|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.sql.execution |
| 19 | + |
| 20 | +import scala.reflect.ClassTag |
| 21 | + |
| 22 | +import org.apache.spark.sql.TPCDSQuerySuite |
| 23 | +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} |
| 24 | +import org.apache.spark.sql.catalyst.plans.QueryPlan |
| 25 | +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} |
| 26 | +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} |
| 27 | +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} |
| 28 | +import org.apache.spark.sql.execution.datasources.LogicalRelation |
| 29 | +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} |
| 30 | +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} |
| 31 | +import org.apache.spark.sql.execution.joins._ |
| 32 | +import org.apache.spark.sql.execution.window.WindowExec |
| 33 | + |
/**
 * Verifies that every interesting physical node produced while running the TPCDS queries
 * carries a logical-plan tag ([[QueryPlan.LOGICAL_PLAN_TAG_NAME]]) pointing at the logical
 * node of the expected type. Exchange nodes, which are created after planning, must NOT
 * carry the tag.
 */
class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {

  // Hook into the TPCDS suite's per-plan callback so every generated plan is also
  // checked for correct logical-plan tagging.
  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
    super.checkGeneratedCode(plan)
    checkLogicalPlanTag(plan)
  }

  // A physical aggregate maps back to a logical Aggregate only when it is the final
  // (or complete) aggregation; partial aggregates have no logical counterpart.
  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
    // TODO: aggregate node without aggregate expressions can also be a final aggregate, but
    // currently the aggregate node doesn't have a final/partial flag.
    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
  }

  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
    case p: ProjectExec => isScanPlanTree(p.child)
    case f: FilterExec => isScanPlanTree(f.child)
    case _: LeafExecNode => true
    case _ => false
  }

  /**
   * Recursively asserts that `plan` (and its children and subqueries) carries a logical-plan
   * tag of the type corresponding to the physical node's kind.
   */
  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
    plan match {
      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
          | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
        assertLogicalPlanType[Join](plan)

      // There is no corresponding logical plan for the physical partial aggregate.
      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)

      case _: WindowExec =>
        assertLogicalPlanType[Window](plan)

      case _: UnionExec =>
        assertLogicalPlanType[Union](plan)

      case _: SampleExec =>
        assertLogicalPlanType[Sample](plan)

      case _: GenerateExec =>
        assertLogicalPlanType[Generate](plan)

      // The exchange related nodes are created after the planning, they don't have corresponding
      // logical plan.
      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))

      case _ if isScanPlanTree(plan) =>
        // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
        // so it's not simple to check. Instead, we only check that the origin LogicalPlan
        // contains the corresponding leaf node of the SparkPlan.
        // a strategy might remove the filter if it's totally pushed down, e.g.:
        //   logical = Project(Filter(Scan A))
        //   physical = ProjectExec(ScanExec A)
        // we only check that leaf nodes match between logical and physical plan.
        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
        val physicalLeaves = plan.collectLeaves()
        assert(logicalLeaves.length == 1)
        assert(physicalLeaves.length == 1)
        // BUG FIX: the original match arms evaluated bare `isInstanceOf` expressions and
        // discarded the boolean result, so the leaf-type checks never actually ran.
        // Wrap each check in assert() so a mismatch fails the test.
        physicalLeaves.head match {
          case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range])
          case _: DataSourceScanExec => assert(logicalLeaves.head.isInstanceOf[LogicalRelation])
          case _: InMemoryTableScanExec => assert(logicalLeaves.head.isInstanceOf[InMemoryRelation])
          case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation])
          case _: ExternalRDDScanExec[_] => assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]])
          case _: BatchScanExec => assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation])
          case _ =>
        }
        // Do not need to check the children recursively.
        return

      case _ =>
    }

    plan.children.foreach(checkLogicalPlanTag)
    plan.subqueries.foreach(checkLogicalPlanTag)
  }

  // Fetches the logical plan attached to `node` via the tag, failing with a readable
  // message when the tag is missing.
  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
    assert(node.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME),
      node.getClass.getSimpleName + " does not have a logical plan link")
    node.tags(QueryPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
  }

  // Asserts that the logical plan attached to `node` is exactly of type T (not a subtype).
  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
    val logicalPlan = getLogicalPlan(node)
    val expectedCls = implicitly[ClassTag[T]].runtimeClass
    assert(expectedCls == logicalPlan.getClass)
  }
}
0 commit comments