Skip to content

Commit 175e429

Browse files
maryannxuecloud-fan
andcommitted
[SPARK-37670][SQL] Support predicate pushdown and column pruning for de-duped CTEs
### What changes were proposed in this pull request? This PR adds predicate push-down and column pruning to CTEs that are not inlined as well as fixes a few potential correctness issues: 1) Replace (previously not inlined) CTE refs with Repartition operations at the end of logical plan optimization so that WithCTE is not carried over to physical plan. As a result, we can simplify the logic of physical planning, as well as avoid a correctness issue where the logical link of a physical plan node can point to `WithCTE` and lead to unexpected behaviors in AQE, e.g., class cast exceptions in DPP. 2) Pull (not inlined) CTE defs from subqueries up to the main query level, in order to avoid creating copies of the same CTE def during predicate push-downs and other transformations. 3) Make CTE IDs more deterministic by starting from 0 for each query. ### Why are the changes needed? Improve de-duped CTEs' performance with predicate pushdown and column pruning; fixes de-duped CTEs' correctness issues. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added UTs. Closes apache#34929 from maryannxue/cte-followup. Lead-authored-by: Maryann Xue <[email protected]> Co-authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 4b38188 commit 175e429

20 files changed

Lines changed: 962 additions & 308 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
6969
if (cteDefs.isEmpty) {
7070
substituted
7171
} else if (substituted eq lastSubstituted.get) {
72-
WithCTE(substituted, cteDefs.toSeq)
72+
WithCTE(substituted, cteDefs.sortBy(_.id).toSeq)
7373
} else {
7474
var done = false
7575
substituted.resolveOperatorsWithPruning(_ => !done) {
7676
case p if p eq lastSubstituted.get =>
7777
done = true
78-
WithCTE(p, cteDefs.toSeq)
78+
WithCTE(p, cteDefs.sortBy(_.id).toSeq)
7979
}
8080
}
8181
}
@@ -203,6 +203,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
203203
cteDefs: mutable.ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = {
204204
val resolvedCTERelations = new mutable.ArrayBuffer[(String, CTERelationDef)](relations.size)
205205
for ((name, relation) <- relations) {
206+
val lastCTEDefCount = cteDefs.length
206207
val innerCTEResolved = if (isLegacy) {
207208
// In legacy mode, outer CTE relations take precedence. Here we don't resolve the inner
208209
// `With` nodes, later we will substitute `UnresolvedRelation`s with outer CTE relations.
@@ -211,8 +212,33 @@ object CTESubstitution extends Rule[LogicalPlan] {
211212
} else {
212213
// A CTE definition might contain an inner CTE that has a higher priority, so traverse and
213214
// substitute CTE defined in `relation` first.
215+
// NOTE: we must call `traverseAndSubstituteCTE` before `substituteCTE`, as the relations
216+
// in the inner CTE have higher priority over the relations in the outer CTE when resolving
217+
// inner CTE relations. For example:
218+
// WITH t1 AS (SELECT 1)
219+
// t2 AS (
220+
// WITH t1 AS (SELECT 2)
221+
// WITH t3 AS (SELECT * FROM t1)
222+
// )
223+
// t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`.
214224
traverseAndSubstituteCTE(relation, isCommand, cteDefs)._1
215225
}
226+
227+
if (cteDefs.length > lastCTEDefCount) {
228+
// We have added more CTE relations to the `cteDefs` from the inner CTE, and these relations
229+
// should also be substituted with `resolvedCTERelations` as inner CTE relation can refer to
230+
// outer CTE relation. For example:
231+
// WITH t1 AS (SELECT 1)
232+
// t2 AS (
233+
// WITH t3 AS (SELECT * FROM t1)
234+
// )
235+
for (i <- lastCTEDefCount until cteDefs.length) {
236+
val substituted =
237+
substituteCTE(cteDefs(i).child, isLegacy || isCommand, resolvedCTERelations.toSeq)
238+
cteDefs(i) = cteDefs(i).copy(child = substituted)
239+
}
240+
}
241+
216242
// CTE definition can reference a previous one
217243
val substituted =
218244
substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations.toSeq)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, PercentileCont}
25-
import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery}
25+
import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
2626
import org.apache.spark.sql.catalyst.plans._
2727
import org.apache.spark.sql.catalyst.plans.logical._
2828
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -93,8 +93,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
9393

9494
def checkAnalysis(plan: LogicalPlan): Unit = {
9595
// We transform up and order the rules so as to catch the first possible failure instead
96-
// of the result of cascading resolution failures.
97-
plan.foreachUp {
96+
// of the result of cascading resolution failures. Inline all CTEs in the plan to help check
97+
// query plan structures in subqueries.
98+
val inlineCTE = InlineCTE(alwaysInline = true)
99+
inlineCTE(plan).foreachUp {
98100

99101
case p if p.analyzed => // Skip already analyzed sub-plans
100102

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
2828

2929
/**
3030
* Inlines CTE definitions into corresponding references if either of the conditions satisfies:
31-
* 1. The CTE definition does not contain any non-deterministic expressions. If this CTE
32-
* definition references another CTE definition that has non-deterministic expressions, it
33-
* is still OK to inline the current CTE definition.
31+
* 1. The CTE definition does not contain any non-deterministic expressions or contains attribute
32+
* references to an outer query. If this CTE definition references another CTE definition that
33+
* has non-deterministic expressions, it is still OK to inline the current CTE definition.
3434
* 2. The CTE definition is only referenced once throughout the main query and all the subqueries.
3535
*
36-
* In addition, due to the complexity of correlated subqueries, all CTE references in correlated
37-
* subqueries are inlined regardless of the conditions above.
36+
* CTE definitions that appear in subqueries and are not inlined will be pulled up to the main
37+
* query level.
38+
*
39+
* @param alwaysInline if true, inline all CTEs in the query plan.
3840
*/
39-
object InlineCTE extends Rule[LogicalPlan] {
41+
case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
42+
4043
override def apply(plan: LogicalPlan): LogicalPlan = {
4144
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
4245
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
4346
buildCTEMap(plan, cteMap)
44-
inlineCTE(plan, cteMap, forceInline = false)
47+
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
48+
val inlined = inlineCTE(plan, cteMap, notInlined)
49+
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
50+
// WithCTE as top node here.
51+
if (notInlined.isEmpty) {
52+
inlined
53+
} else {
54+
WithCTE(inlined, notInlined.toSeq)
55+
}
4556
} else {
4657
plan
4758
}
4859
}
4960

50-
private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = {
61+
private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = alwaysInline || {
5162
// We do not need to check enclosed `CTERelationRef`s for `deterministic` or `OuterReference`,
5263
// because:
5364
// 1) It is fine to inline a CTE if it references another CTE that is non-deterministic;
@@ -93,25 +104,24 @@ object InlineCTE extends Rule[LogicalPlan] {
93104
private def inlineCTE(
94105
plan: LogicalPlan,
95106
cteMap: mutable.HashMap[Long, (CTERelationDef, Int)],
96-
forceInline: Boolean): LogicalPlan = {
97-
val (stripped, notInlined) = plan match {
107+
notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
108+
plan match {
98109
case WithCTE(child, cteDefs) =>
99-
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
100110
cteDefs.foreach { cteDef =>
101111
val (cte, refCount) = cteMap(cteDef.id)
102112
if (refCount > 0) {
103-
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, forceInline))
113+
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
104114
cteMap.update(cteDef.id, (inlined, refCount))
105-
if (!forceInline && !shouldInline(inlined, refCount)) {
115+
if (!shouldInline(inlined, refCount)) {
106116
notInlined.append(inlined)
107117
}
108118
}
109119
}
110-
(inlineCTE(child, cteMap, forceInline), notInlined.toSeq)
120+
inlineCTE(child, cteMap, notInlined)
111121

112122
case ref: CTERelationRef =>
113123
val (cteDef, refCount) = cteMap(ref.cteId)
114-
val newRef = if (forceInline || shouldInline(cteDef, refCount)) {
124+
if (shouldInline(cteDef, refCount)) {
115125
if (ref.outputSet == cteDef.outputSet) {
116126
cteDef.child
117127
} else {
@@ -125,24 +135,16 @@ object InlineCTE extends Rule[LogicalPlan] {
125135
} else {
126136
ref
127137
}
128-
(newRef, Seq.empty)
129138

130139
case _ if plan.containsPattern(CTE) =>
131-
val newPlan = plan
132-
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, forceInline)))
140+
plan
141+
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
133142
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
134143
case e: SubqueryExpression =>
135-
e.withNewPlan(inlineCTE(e.plan, cteMap, forceInline = e.isCorrelated))
144+
e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
136145
}
137-
(newPlan, Seq.empty)
138146

139-
case _ => (plan, Seq.empty)
140-
}
141-
142-
if (notInlined.isEmpty) {
143-
stripped
144-
} else {
145-
WithCTE(stripped, notInlined)
147+
case _ => plan
146148
}
147149
}
148150
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
128128
OptimizeUpdateFields,
129129
SimplifyExtractValueOps,
130130
OptimizeCsvJsonExprs,
131-
CombineConcats) ++
131+
CombineConcats,
132+
PushdownPredicatesAndPruneColumnsForCTEDef) ++
132133
extendedOperatorOptimizationRules
133134

134135
val operatorOptimizationBatch: Seq[Batch] = {
@@ -147,22 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
147148
}
148149

149150
val batches = (
150-
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
151-
// in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
152-
// However, because we also use the analyzer to canonicalized queries (for view definition),
153-
// we do not eliminate subqueries or compute current time in the analyzer.
154-
Batch("Finish Analysis", Once,
155-
EliminateResolvedHint,
156-
EliminateSubqueryAliases,
157-
EliminateView,
158-
InlineCTE,
159-
ReplaceExpressions,
160-
RewriteNonCorrelatedExists,
161-
PullOutGroupingExpressions,
162-
ComputeCurrentTime,
163-
ReplaceCurrentLike(catalogManager),
164-
SpecialDatetimeValues,
165-
RewriteAsOfJoin) ::
151+
Batch("Finish Analysis", Once, FinishAnalysis) ::
166152
//////////////////////////////////////////////////////////////////////////////////////////
167153
// Optimizer rules start here
168154
//////////////////////////////////////////////////////////////////////////////////////////
@@ -172,6 +158,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
172158
// extra operators between two adjacent Union operators.
173159
// - Call CombineUnions again in Batch("Operator Optimizations"),
174160
// since the other rules might make two separate Unions operators adjacent.
161+
Batch("Inline CTE", Once,
162+
InlineCTE()) ::
175163
Batch("Union", Once,
176164
RemoveNoopOperators,
177165
CombineUnions,
@@ -208,6 +196,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
208196
RemoveLiteralFromGroupExpressions,
209197
RemoveRepetitionFromGroupExpressions) :: Nil ++
210198
operatorOptimizationBatch) :+
199+
Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
211200
// This batch rewrites plans after the operator optimization and
212201
// before any batches that depend on stats.
213202
Batch("Pre CBO Rules", Once, preCBORules: _*) :+
@@ -266,14 +255,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
266255
* (defaultBatches - (excludedRules - nonExcludableRules)).
267256
*/
268257
def nonExcludableRules: Seq[String] =
269-
EliminateDistinct.ruleName ::
270-
EliminateResolvedHint.ruleName ::
271-
EliminateSubqueryAliases.ruleName ::
272-
EliminateView.ruleName ::
273-
ReplaceExpressions.ruleName ::
274-
ComputeCurrentTime.ruleName ::
275-
SpecialDatetimeValues.ruleName ::
276-
ReplaceCurrentLike(catalogManager).ruleName ::
258+
FinishAnalysis.ruleName ::
277259
RewriteDistinctAggregates.ruleName ::
278260
ReplaceDeduplicateWithAggregate.ruleName ::
279261
ReplaceIntersectWithSemiJoin.ruleName ::
@@ -287,10 +269,38 @@ abstract class Optimizer(catalogManager: CatalogManager)
287269
RewritePredicateSubquery.ruleName ::
288270
NormalizeFloatingNumbers.ruleName ::
289271
ReplaceUpdateFieldsExpression.ruleName ::
290-
PullOutGroupingExpressions.ruleName ::
291-
RewriteAsOfJoin.ruleName ::
292272
RewriteLateralSubquery.ruleName :: Nil
293273

274+
/**
275+
* Apply finish-analysis rules for the entire plan including all subqueries.
276+
*/
277+
object FinishAnalysis extends Rule[LogicalPlan] {
278+
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
279+
// in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
280+
// However, because we also use the analyzer to canonicalized queries (for view definition),
281+
// we do not eliminate subqueries or compute current time in the analyzer.
282+
private val rules = Seq(
283+
EliminateResolvedHint,
284+
EliminateSubqueryAliases,
285+
EliminateView,
286+
ReplaceExpressions,
287+
RewriteNonCorrelatedExists,
288+
PullOutGroupingExpressions,
289+
ComputeCurrentTime,
290+
ReplaceCurrentLike(catalogManager),
291+
SpecialDatetimeValues,
292+
RewriteAsOfJoin)
293+
294+
override def apply(plan: LogicalPlan): LogicalPlan = {
295+
rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
296+
.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
297+
case s: SubqueryExpression =>
298+
val Subquery(newPlan, _) = apply(Subquery.fromExpression(s))
299+
s.withNewPlan(newPlan)
300+
}
301+
}
302+
}
303+
294304
/**
295305
* Optimize all the subqueries inside expression.
296306
*/

0 commit comments

Comments
 (0)