Skip to content

Commit d817fc7

Browse files
committed
split plan and expression pruning
1 parent ed374fe commit d817fc7

10 files changed

Lines changed: 44 additions & 40 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans._
3939
import org.apache.spark.sql.catalyst.plans.logical._
4040
import org.apache.spark.sql.catalyst.rules._
4141
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
42-
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
42+
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, TreeNodeRef}
4343
import org.apache.spark.sql.catalyst.trees.TreePattern._
4444
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
4545
import org.apache.spark.sql.connector.catalog._
@@ -2297,8 +2297,8 @@ class Analyzer(override val catalogManager: CatalogManager)
22972297
* outer plan to get evaluated.
22982298
*/
22992299
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
2300-
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
2301-
EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
2300+
plan.transformAllExpressionsWithPruning(AlwaysProcess.fn,
2301+
_.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
23022302
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
23032303
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
23042304
case e @ Exists(sub, _, exprId) if !sub.resolved =>
@@ -3128,7 +3128,7 @@ class Analyzer(override val catalogManager: CatalogManager)
31283128
*/
31293129
object ResolveWindowFrame extends Rule[LogicalPlan] {
31303130
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
3131-
_.containsPattern(WINDOW_EXPRESSION), ruleId) {
3131+
AlwaysProcess.fn, _.containsPattern(WINDOW_EXPRESSION), ruleId) {
31323132
case WindowExpression(wf: FrameLessOffsetWindowFunction,
31333133
WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != f =>
31343134
throw QueryCompilationErrors.cannotSpecifyWindowFrameError(wf.prettyName)
@@ -3154,7 +3154,7 @@ class Analyzer(override val catalogManager: CatalogManager)
31543154
*/
31553155
object ResolveWindowOrder extends Rule[LogicalPlan] {
31563156
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
3157-
_.containsPattern(WINDOW_EXPRESSION), ruleId) {
3157+
AlwaysProcess.fn, _.containsPattern(WINDOW_EXPRESSION), ruleId) {
31583158
case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
31593159
throw QueryCompilationErrors.windowFunctionWithWindowFrameNotOrderedError(wf)
31603160
case WindowExpression(rank: RankLike, spec) if spec.resolved =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
2727
import org.apache.spark.sql.catalyst.plans._
2828
import org.apache.spark.sql.catalyst.plans.logical._
2929
import org.apache.spark.sql.catalyst.rules._
30+
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
3031
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
3132
import org.apache.spark.sql.connector.catalog.CatalogManager
3233
import org.apache.spark.sql.internal.SQLConf
@@ -283,7 +284,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
283284
}
284285
}
285286
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
286-
_.containsPattern(PLAN_EXPRESSION), ruleId) {
287+
AlwaysProcess.fn, _.containsPattern(PLAN_EXPRESSION), ruleId) {
287288
case s: SubqueryExpression =>
288289
val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
289290
// At this point we have an optimized subquery plan that we are going to attach

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ object PullOutGroupingExpressions extends Rule[LogicalPlan] {
5757
case o => o
5858
}
5959
if (complexGroupingExpressionMap.nonEmpty) {
60-
val newAggregateExpressions = a.aggregateExpressions.map(_.transformWithPruning(tpb => {
61-
val e = tpb.asInstanceOf[Expression]
62-
!(AggregateExpression.isAggregate(e) || e.foldable)
63-
}) {
60+
val newAggregateExpressions = a.aggregateExpressions.map(_.transformWithPruning(
61+
e => !(AggregateExpression.isAggregate(e) || e.foldable)) {
6462
case e if complexGroupingExpressionMap.contains(e.canonicalized) =>
6563
complexGroupingExpressionMap.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
6664
}.asInstanceOf[NamedExpression])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
724724
}
725725

726726
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
727-
_.containsPattern(LIKE_FAMLIY), ruleId) {
727+
AlwaysProcess.fn, _.containsPattern(LIKE_FAMLIY), ruleId) {
728728
case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
729729
if (pattern == null) {
730730
// If pattern is null, return null value directly, since "col like null" == null.
@@ -933,7 +933,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
933933
*/
934934
object SimplifyCasts extends Rule[LogicalPlan] {
935935
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
936-
_.containsPattern(CAST), ruleId) {
936+
AlwaysProcess.fn, _.containsPattern(CAST), ruleId) {
937937
case Cast(e, dataType, _) if e.dataType == dataType => e
938938
case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
939939
case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
@@ -950,7 +950,7 @@ object SimplifyCasts extends Rule[LogicalPlan] {
950950
*/
951951
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
952952
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
953-
_.containsPattern(UNARY_POSITIVE), ruleId) {
953+
AlwaysProcess.fn, _.containsPattern(UNARY_POSITIVE), ruleId) {
954954
case UnaryPositive(child) => child
955955
}
956956
}
@@ -1006,7 +1006,7 @@ object CombineConcats extends Rule[LogicalPlan] {
10061006
}
10071007

10081008
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
1009-
_.containsPattern(CONCAT), ruleId) {
1009+
AlwaysProcess.fn, _.containsPattern(CONCAT), ruleId) {
10101010
case concat: Concat if hasNestedConcats(concat) =>
10111011
flattenConcats(concat)
10121012
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.rules.RuleId
2626
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
2727
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag}
28-
import org.apache.spark.sql.catalyst.trees.TreePatternBits
2928
import org.apache.spark.sql.internal.SQLConf
3029
import org.apache.spark.sql.types.{DataType, StructType}
3130
import org.apache.spark.util.collection.BitSet
@@ -116,7 +115,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
116115
* subtree. Do not pass it if the rule is not purely functional and reads a
117116
* varying initial state for different invocations.
118117
*/
119-
def transformExpressionsWithPruning(cond: TreePatternBits => Boolean,
118+
def transformExpressionsWithPruning(cond: Expression => Boolean,
120119
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression])
121120
: this.type = {
122121
transformExpressionsDownWithPruning(cond, ruleId)(rule)
@@ -145,7 +144,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
145144
* subtree. Do not pass it if the rule is not purely functional and reads a
146145
* varying initial state for different invocations.
147146
*/
148-
def transformExpressionsDownWithPruning(cond: TreePatternBits => Boolean,
147+
def transformExpressionsDownWithPruning(cond: Expression => Boolean,
149148
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression])
150149
: this.type = {
151150
mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule))
@@ -174,7 +173,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
174173
* subtree. Do not pass it if the rule is not purely functional and reads a
175174
* varying initial state for different invocations.
176175
*/
177-
def transformExpressionsUpWithPruning(cond: TreePatternBits => Boolean,
176+
def transformExpressionsUpWithPruning(cond: Expression => Boolean,
178177
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression])
179178
: this.type = {
180179
mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule))
@@ -220,19 +219,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
220219
* and all its children. Note that this method skips expressions inside subqueries.
221220
*/
222221
def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
223-
transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
222+
transformAllExpressionsWithPruning(AlwaysProcess.fn, AlwaysProcess.fn, UnknownRuleId)(rule)
224223
}
225224

226225
/**
227226
* Returns the result of running [[transformExpressionsWithPruning]] on this node
228227
* and all its children. Note that this method skips expressions inside subqueries.
229228
*/
230-
def transformAllExpressionsWithPruning(cond: TreePatternBits => Boolean,
229+
def transformAllExpressionsWithPruning(cond: PlanType => Boolean,
230+
exprCond: Expression => Boolean,
231231
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression])
232232
: this.type = {
233233
transformWithPruning(cond, ruleId) {
234234
case q: QueryPlan[_] =>
235-
q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType]
235+
q.transformExpressionsWithPruning(exprCond, ruleId)(rule).asInstanceOf[PlanType]
236236
}.asInstanceOf[this.type]
237237
}
238238

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expre
2121
import org.apache.spark.sql.catalyst.plans.QueryPlan
2222
import org.apache.spark.sql.catalyst.rules.RuleId
2323
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
24-
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreePatternBits}
24+
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
2525
import org.apache.spark.util.Utils
2626

2727

@@ -92,7 +92,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
9292
* subtree. Do not pass it if the rule is not purely functional and reads a
9393
* varying initial state for different invocations.
9494
*/
95-
def resolveOperatorsWithPruning(cond: TreePatternBits => Boolean,
95+
def resolveOperatorsWithPruning(cond: LogicalPlan => Boolean,
9696
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[LogicalPlan, LogicalPlan])
9797
: LogicalPlan = {
9898
resolveOperatorsDownWithPruning(cond, ruleId)(rule)
@@ -126,7 +126,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
126126
* subtree. Do not pass it if the rule is not purely functional and reads a
127127
* varying initial state for different invocations.
128128
*/
129-
def resolveOperatorsUpWithPruning(cond: TreePatternBits => Boolean,
129+
def resolveOperatorsUpWithPruning(cond: LogicalPlan => Boolean,
130130
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[LogicalPlan, LogicalPlan])
131131
: LogicalPlan = {
132132
if (!analyzed && cond.apply(self) && !isRuleIneffective(ruleId)) {
@@ -160,7 +160,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
160160
}
161161

162162
/** Similar to [[resolveOperatorsUpWithPruning]], but does it top-down. */
163-
def resolveOperatorsDownWithPruning(cond: TreePatternBits => Boolean,
163+
def resolveOperatorsDownWithPruning(cond: LogicalPlan => Boolean,
164164
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[LogicalPlan, LogicalPlan])
165165
: LogicalPlan = {
166166
if (!analyzed && cond.apply(self) && !isRuleIneffective(ruleId)) {
@@ -221,7 +221,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
221221
* been analyzed.
222222
*/
223223
def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = {
224-
resolveExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(r)
224+
resolveExpressionsWithPruning(AlwaysProcess.fn, AlwaysProcess.fn, UnknownRuleId)(r)
225225
}
226226

227227
/**
@@ -238,10 +238,10 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
238238
* subtree. Do not pass it if the rule is not purely functional and reads a
239239
* varying initial state for different invocations.
240240
*/
241-
def resolveExpressionsWithPruning(cond: TreePatternBits => Boolean,
241+
def resolveExpressionsWithPruning(cond: LogicalPlan => Boolean, exprCond: Expression => Boolean,
242242
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): LogicalPlan = {
243243
resolveOperatorsWithPruning(cond, ruleId) {
244-
case p => p.transformExpressionsWithPruning(cond, ruleId)(rule)
244+
case p => p.transformExpressionsWithPruning(exprCond, ruleId)(rule)
245245
}
246246
}
247247

@@ -259,7 +259,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
259259
* the scope of a [[resolveOperatorsDown()]] call.
260260
* @see [[org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning()]].
261261
*/
262-
override def transformDownWithPruning(cond: TreePatternBits => Boolean,
262+
override def transformDownWithPruning(cond: LogicalPlan => Boolean,
263263
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[LogicalPlan, LogicalPlan])
264264
: LogicalPlan = {
265265
assertNotAnalysisRule()
@@ -271,7 +271,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
271271
*
272272
* @see [[org.apache.spark.sql.catalyst.trees.TreeNode.transformUpWithPruning()]]
273273
*/
274-
override def transformUpWithPruning(cond: TreePatternBits => Boolean,
274+
override def transformUpWithPruning(cond: LogicalPlan => Boolean,
275275
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[LogicalPlan, LogicalPlan])
276276
: LogicalPlan = {
277277
assertNotAnalysisRule()
@@ -283,11 +283,12 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
283283
* @see [[QueryPlan.transformAllExpressionsWithPruning()]]
284284
*/
285285
override def transformAllExpressionsWithPruning(
286-
cond: TreePatternBits => Boolean,
286+
cond: LogicalPlan => Boolean,
287+
exprCond: Expression => Boolean,
287288
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression])
288289
: this.type = {
289290
assertNotAnalysisRule()
290-
super.transformAllExpressionsWithPruning(cond, ruleId)(rule)
291+
super.transformAllExpressionsWithPruning(cond, exprCond, ruleId)(rule)
291292
}
292293

293294
override def clone(): LogicalPlan = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
440440
* subtree. Do not pass it if the rule is not purely functional and reads a
441441
* varying initial state for different invocations.
442442
*/
443-
def transformWithPruning(cond: TreePatternBits => Boolean,
443+
def transformWithPruning(cond: BaseType => Boolean,
444444
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
445445
: BaseType = {
446446
transformDownWithPruning(cond, ruleId)(rule)
@@ -470,7 +470,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
470470
* subtree. Do not pass it if the rule is not purely functional and reads a
471471
* varying initial state for different invocations.
472472
*/
473-
def transformDownWithPruning(cond: TreePatternBits => Boolean,
473+
def transformDownWithPruning(cond: BaseType => Boolean,
474474
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
475475
: BaseType = {
476476
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
@@ -522,7 +522,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
522522
* subtree. Do not pass it if the rule is not purely functional and reads a
523523
* varying initial state for different invocations.
524524
*/
525-
def transformUpWithPruning(cond: TreePatternBits => Boolean,
525+
def transformUpWithPruning(cond: BaseType => Boolean,
526526
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
527527
: BaseType = {
528528
if (!cond.apply(this) || isRuleIneffective(ruleId)) {

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.adaptive
2020
import org.apache.spark.sql.catalyst.expressions
2121
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
2222
import org.apache.spark.sql.catalyst.rules.Rule
23-
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
24-
SCALAR_SUBQUERY}
23+
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
2525
import org.apache.spark.sql.execution
2626
import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
2727

@@ -30,6 +30,7 @@ case class PlanAdaptiveSubqueries(
3030

3131
def apply(plan: SparkPlan): SparkPlan = {
3232
plan.transformAllExpressionsWithPruning(
33+
AlwaysProcess.fn,
3334
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
3435
case expressions.ScalarSubquery(_, _, exprId) =>
3536
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
2424
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
2525
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
2626
import org.apache.spark.sql.catalyst.rules.Rule
27+
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
2728
import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY
2829
import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec}
2930
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -50,7 +51,8 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession)
5051
return plan
5152
}
5253

53-
plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
54+
plan.transformAllExpressionsWithPruning(AlwaysProcess.fn,
55+
_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
5456
case DynamicPruningSubquery(
5557
value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
5658
val sparkPlan = QueryExecution.createSparkPlan(

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{expressions, InternalRow}
2626
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression}
2727
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2828
import org.apache.spark.sql.catalyst.rules.Rule
29-
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
29+
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, LeafLike, UnaryLike}
3030
import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY}
3131
import org.apache.spark.sql.internal.SQLConf
3232
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
@@ -177,7 +177,8 @@ case class InSubqueryExec(
177177
*/
178178
case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
179179
def apply(plan: SparkPlan): SparkPlan = {
180-
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
180+
plan.transformAllExpressionsWithPruning(AlwaysProcess.fn,
181+
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
181182
case subquery: expressions.ScalarSubquery =>
182183
val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
183184
ScalarSubquery(

0 commit comments

Comments
 (0)