
Commit 6603d9f

Davies Liu (davies) authored and committed
[SPARK-13919] [SQL] fix column pruning through filter
## What changes were proposed in this pull request?

This PR fixes the conflict between ColumnPruning and PushPredicatesThroughProject: ColumnPruning tries to insert a Project before a Filter, while PushPredicatesThroughProject moves the Filter back before the Project, so the two rules keep undoing each other. This is fixed by removing the Project before the Filter whenever that Project only does column pruning. RuleExecutor will now fail the test if a batch reaches its maximum number of iterations.

Closes #11745

## How was this patch tested?

Existing tests. One test case is still failing and is disabled for now; it will be fixed by https://issues.apache.org/jira/browse/SPARK-14137.

Author: Davies Liu <davies@databricks.com>

Closes #11828 from davies/fail_rule.
1 parent 55a6057 · commit 6603d9f

7 files changed: 124 additions & 105 deletions
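Note: to make the oscillation described above concrete, here is a minimal sketch using the Catalyst test DSL. The `ConflictDemo` executor and the three-column relation are illustrative, not part of this commit; before this fix, a fixed-point batch running both rules could bounce the Filter and the pruning Project past each other until it hit `maxIterations`, and after it the plan below converges.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, PushPredicateThroughProject}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// Illustrative executor mirroring the batch added to ColumnPruningSuite below.
object ConflictDemo extends RuleExecutor[LogicalPlan] {
  val batches = Batch("demo", FixedPoint(100),
    PushPredicateThroughProject,
    ColumnPruning) :: Nil
}

object ConflictDemoApp extends App {
  val input = LocalRelation('a.int, 'b.string, 'c.double)
  // SELECT a FROM input WHERE c > 0.0: ColumnPruning wants Project(Seq(a, c), input)
  // beneath the Filter, while PushPredicateThroughProject pushes the Filter back
  // beneath that Project, recreating the shape ColumnPruning started from.
  val query = input.where('c > Literal(0.0)).select('a).analyze
  println(ConflictDemo.execute(query))
}
```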


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 76 additions & 86 deletions

```diff
@@ -80,7 +80,6 @@ class Analyzer(
       EliminateUnions),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
-      ResolveStar ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
@@ -374,91 +373,6 @@ class Analyzer(
     }
   }
 
-  /**
-   * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
-   */
-  object ResolveStar extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case p: LogicalPlan if !p.childrenResolved => p
-      // If the projection list contains Stars, expand it.
-      case p: Project if containsStar(p.projectList) =>
-        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
-          failAnalysis(
-            "Group by position: star is not allowed to use in the select list " +
-              "when using ordinals in group by")
-        } else {
-          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
-        }
-      // If the script transformation input contains Stars, expand it.
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child, resolver)
-            case o => o :: Nil
-          }
-        )
-      case g: Generate if containsStar(g.generator.children) =>
-        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
-    }
-
-    /**
-     * Build a project list for Project/Aggregate and expand the star if possible
-     */
-    private def buildExpandedProjectList(
-        exprs: Seq[NamedExpression],
-        child: LogicalPlan): Seq[NamedExpression] = {
-      exprs.flatMap {
-        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
-        case s: Star => s.expand(child, resolver)
-        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
-        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
-        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
-        case o => o :: Nil
-      }.map(_.asInstanceOf[NamedExpression])
-    }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
-
-    /**
-     * Expands the matching attribute.*'s in `child`'s output.
-     */
-    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
-      expr.transformUp {
-        case f1: UnresolvedFunction if containsStar(f1.children) =>
-          f1.copy(children = f1.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case c: CreateStruct if containsStar(c.children) =>
-          c.copy(children = c.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case c: CreateArray if containsStar(c.children) =>
-          c.copy(children = c.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case p: Murmur3Hash if containsStar(p.children) =>
-          p.copy(children = p.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        // count(*) has been replaced by count(1)
-        case o if containsStar(o.children) =>
-          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
-      }
-    }
-  }
-
   /**
    * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
    * a logical plan node's children.
@@ -525,6 +439,29 @@
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
+      // If the projection list contains Stars, expand it.
+      case p: Project if containsStar(p.projectList) =>
+        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+          failAnalysis(
+            "Group by position: star is not allowed to use in the select list " +
+              "when using ordinals in group by")
+        } else {
+          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+        }
+      // If the script transformation input contains Stars, expand it.
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child, resolver)
+            case o => o :: Nil
+          }
+        )
+      case g: Generate if containsStar(g.generator.children) =>
+        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
@@ -619,6 +556,59 @@
     def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
       AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
     }
+
+    /**
+     * Build a project list for Project/Aggregate and expand the star if possible
+     */
+    private def buildExpandedProjectList(
+        exprs: Seq[NamedExpression],
+        child: LogicalPlan): Seq[NamedExpression] = {
+      exprs.flatMap {
+        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
+        case s: Star => s.expand(child, resolver)
+        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
+        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
+        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
+        case o => o :: Nil
+      }.map(_.asInstanceOf[NamedExpression])
+    }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
+    /**
+     * Expands the matching attribute.*'s in `child`'s output.
+     */
+    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+      expr.transformUp {
+        case f1: UnresolvedFunction if containsStar(f1.children) =>
+          f1.copy(children = f1.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateStruct if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateArray if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case p: Murmur3Hash if containsStar(p.children) =>
+          p.copy(children = p.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        // count(*) has been replaced by count(1)
+        case o if containsStar(o.children) =>
+          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
+      }
+    }
   }
 
   protected[sql] def resolveExpression(
```
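Note: the star-expansion logic above is unchanged; it only moves from the deleted `ResolveStar` rule into `ResolveReferences`. A hedged sketch of what that expansion does, using the Catalyst test DSL and `SimpleAnalyzer` (the demo object is illustrative, not part of this commit):

```scala
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedStar}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object StarDemo extends App {
  val t = LocalRelation('a.int, 'b.string)
  // Project(UnresolvedStar :: Nil, t) is what SELECT * FROM t parses to;
  // ResolveReferences now expands the star itself, producing something like
  // Project [a#0, b#1] over the relation.
  println(SimpleAnalyzer.execute(t.select(UnresolvedStar(None))))
}
```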

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 19 additions & 9 deletions

```diff
@@ -306,21 +306,21 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 }
 
 /**
- * Attempts to eliminate the reading of unneeded columns from the query plan using the following
- * transformations:
+ * Attempts to eliminate the reading of unneeded columns from the query plan.
  *
- *  - Inserting Projections beneath the following operators:
- *   - Aggregate
- *   - Generate
- *   - Project <- Join
- *   - LeftSemiJoin
+ * Since adding Project before Filter conflicts with PushPredicatesThroughProject, this rule will
+ * remove the Project p2 in the following pattern:
+ *
+ *   p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet)
+ *
+ * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
     output1.size == output2.size &&
       output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -399,7 +399,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     } else {
       p
     }
-  }
+  })
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
@@ -408,6 +408,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
     } else {
       c
     }
+
+  /**
+   * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject,
+   * so remove it.
+   */
+  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
+      if p2.outputSet.subsetOf(child.outputSet) =>
+        p1.copy(child = f.copy(child = child))
+  }
 }
 
 /**
```
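Note: a small sketch of the new `removeProjectBeforeFilter` cleanup in isolation (the test-DSL relation is illustrative). The inner Project only prunes columns, so the rule drops it and lets the outer Project do the pruning:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object RemoveProjectDemo extends App {
  val input = LocalRelation('a.int, 'b.string, 'c.double)
  // Project(a, Filter(c > 0.0, Project(Seq(a, c), input))): the inner Project's
  // output is a subset of input's output, so one pass of ColumnPruning rewrites
  // this to Project(a, Filter(c > 0.0, input)).
  val plan = input.select('a, 'c).where('c > Literal(0.0)).select('a).analyze
  println(ColumnPruning(plan))
}
```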

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala

Lines changed: 8 additions & 1 deletion

```diff
@@ -22,8 +22,10 @@ import scala.collection.JavaConverters._
 import com.google.common.util.concurrent.AtomicLongMap
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.util.Utils
 
 object RuleExecutor {
   protected val timeMap = AtomicLongMap.create[String]()
@@ -98,7 +100,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
         if (iteration > batch.strategy.maxIterations) {
           // Only log if this is a rule that is supposed to run more than once.
           if (iteration != 2) {
-            logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+            val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
+            if (Utils.isTesting) {
+              throw new TreeNodeException(curPlan, message, null)
+            } else {
+              logWarning(message)
+            }
           }
           continue = false
         }
```
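Note: a hedged sketch of the new fail-fast behavior. `Utils.isTesting` checks, as far as I can tell, the `SPARK_TESTING` environment variable or the `spark.testing` system property; the `NeverConverges` executor below is illustrative. A batch whose rules keep firing now throws under test instead of silently logging:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

object NeverConverges extends RuleExecutor[Expression] {
  object Decrement extends Rule[Expression] {
    // Always fires, so the batch can never reach a fixed point.
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) => Literal(i - 1)
    }
  }
  val batches = Batch("demo", FixedPoint(5), Decrement) :: Nil
}

// With spark.testing set, this now throws a TreeNodeException whose message is
// "Max iterations (5) reached for batch demo"; in production it logs a warning
// and returns the partially transformed expression.
// NeverConverges.execute(Literal(100))
```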

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -29,7 +29,7 @@ class AnalysisSuite extends AnalysisTest {
   import org.apache.spark.sql.catalyst.analysis.TestRelations._
 
   test("union project *") {
-    val plan = (1 to 100)
+    val plan = (1 to 120)
       .map(_ => testRelation)
       .fold[LogicalPlan](testRelation) { (a, b) =>
         a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

Lines changed: 11 additions & 6 deletions

```diff
@@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
+      PushPredicateThroughProject,
      ColumnPruning,
      CollapseProject) :: Nil
   }
@@ -133,12 +134,16 @@ class ColumnPruningSuite extends PlanTest {
 
   test("Column pruning on Filter") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val plan1 = Filter('a > 1, input).analyze
+    comparePlans(Optimize.execute(plan1), plan1)
     val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
-    val expected =
-      Project('a :: Nil,
-        Filter('c > Literal(0.0),
-          Project(Seq('a, 'c), input))).analyze
-    comparePlans(Optimize.execute(query), expected)
+    comparePlans(Optimize.execute(query), query)
+    val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze
+    val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze
+    comparePlans(Optimize.execute(plan2), expected2)
+    val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze
+    val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze
+    comparePlans(Optimize.execute(plan3), expected3)
   }
 
   test("Column pruning on except/intersect/distinct") {
@@ -297,7 +302,7 @@ class ColumnPruningSuite extends PlanTest {
           SortOrder('b, Ascending) :: Nil,
           UnspecifiedFrame)).as('window) :: Nil,
         'a :: Nil, 'b.asc :: Nil)
-      .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+      .where('window > 1).select('a, 'c).analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala

Lines changed: 6 additions & 1 deletion

```diff
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 
 class RuleExecutorSuite extends SparkFunSuite {
@@ -49,6 +51,9 @@ class RuleExecutorSuite extends SparkFunSuite {
       val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
     }
 
-    assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
+    val message = intercept[TreeNodeException[LogicalPlan]] {
+      ToFixedPoint.execute(Literal(100))
+    }.getMessage
+    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
   }
 }
```

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 3 additions & 1 deletion

```diff
@@ -341,6 +341,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_round_3",
     "view_cast",
 
+    // enable this after fixing SPARK-14137
+    "union20",
+
     // These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
     // generates different View Expanded Text.
     "alter_view_as_select",
@@ -1043,7 +1046,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "union18",
     "union19",
     "union2",
-    "union20",
     "union22",
     "union23",
     "union24",
```
