Skip to content

Commit cd7132e

Browse files
author
Davies Liu
committed
fix ResolveStar
1 parent 6ac25f7 commit cd7132e

2 files changed

Lines changed: 71 additions & 81 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 70 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ class Analyzer(
8080
EliminateUnions),
8181
Batch("Resolution", fixedPoint,
8282
ResolveRelations ::
83-
ResolveStar ::
8483
ResolveReferences ::
8584
ResolveGroupingAnalytics ::
8685
ResolvePivot ::
@@ -373,85 +372,6 @@ class Analyzer(
373372
}
374373
}
375374

376-
/**
377-
* Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
378-
*/
379-
object ResolveStar extends Rule[LogicalPlan] {
380-
381-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
382-
case p: LogicalPlan if !p.childrenResolved => p
383-
// If the projection list contains Stars, expand it.
384-
case p: Project if containsStar(p.projectList) =>
385-
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
386-
// If the aggregate function argument contains Stars, expand it.
387-
case a: Aggregate if containsStar(a.aggregateExpressions) =>
388-
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
389-
// If the script transformation input contains Stars, expand it.
390-
case t: ScriptTransformation if containsStar(t.input) =>
391-
t.copy(
392-
input = t.input.flatMap {
393-
case s: Star => s.expand(t.child, resolver)
394-
case o => o :: Nil
395-
}
396-
)
397-
case g: Generate if containsStar(g.generator.children) =>
398-
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
399-
}
400-
401-
/**
402-
* Build a project list for Project/Aggregate and expand the star if possible
403-
*/
404-
private def buildExpandedProjectList(
405-
exprs: Seq[NamedExpression],
406-
child: LogicalPlan): Seq[NamedExpression] = {
407-
exprs.flatMap {
408-
// Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
409-
case s: Star => s.expand(child, resolver)
410-
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
411-
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
412-
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
413-
case o => o :: Nil
414-
}.map(_.asInstanceOf[NamedExpression])
415-
}
416-
417-
/**
418-
* Returns true if `exprs` contains a [[Star]].
419-
*/
420-
def containsStar(exprs: Seq[Expression]): Boolean =
421-
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
422-
423-
/**
424-
* Expands the matching attribute.*'s in `child`'s output.
425-
*/
426-
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
427-
expr.transformUp {
428-
case f1: UnresolvedFunction if containsStar(f1.children) =>
429-
f1.copy(children = f1.children.flatMap {
430-
case s: Star => s.expand(child, resolver)
431-
case o => o :: Nil
432-
})
433-
case c: CreateStruct if containsStar(c.children) =>
434-
c.copy(children = c.children.flatMap {
435-
case s: Star => s.expand(child, resolver)
436-
case o => o :: Nil
437-
})
438-
case c: CreateArray if containsStar(c.children) =>
439-
c.copy(children = c.children.flatMap {
440-
case s: Star => s.expand(child, resolver)
441-
case o => o :: Nil
442-
})
443-
case p: Murmur3Hash if containsStar(p.children) =>
444-
p.copy(children = p.children.flatMap {
445-
case s: Star => s.expand(child, resolver)
446-
case o => o :: Nil
447-
})
448-
// count(*) has been replaced by count(1)
449-
case o if containsStar(o.children) =>
450-
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
451-
}
452-
}
453-
}
454-
455375
/**
456376
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
457377
* a logical plan node's children.
@@ -518,6 +438,23 @@ class Analyzer(
518438
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
519439
case p: LogicalPlan if !p.childrenResolved => p
520440

441+
// If the projection list contains Stars, expand it.
442+
case p: Project if containsStar(p.projectList) =>
443+
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
444+
// If the aggregate function argument contains Stars, expand it.
445+
case a: Aggregate if containsStar(a.aggregateExpressions) =>
446+
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
447+
// If the script transformation input contains Stars, expand it.
448+
case t: ScriptTransformation if containsStar(t.input) =>
449+
t.copy(
450+
input = t.input.flatMap {
451+
case s: Star => s.expand(t.child, resolver)
452+
case o => o :: Nil
453+
}
454+
)
455+
case g: Generate if containsStar(g.generator.children) =>
456+
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
457+
521458
// To resolve duplicate expression IDs for Join and Intersect
522459
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
523460
j.copy(right = dedupRight(left, right))
@@ -612,6 +549,59 @@ class Analyzer(
612549
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
613550
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
614551
}
552+
553+
/**
554+
* Build a project list for Project/Aggregate and expand the star if possible
555+
*/
556+
private def buildExpandedProjectList(
557+
exprs: Seq[NamedExpression],
558+
child: LogicalPlan): Seq[NamedExpression] = {
559+
exprs.flatMap {
560+
// Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
561+
case s: Star => s.expand(child, resolver)
562+
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
563+
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
564+
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
565+
case o => o :: Nil
566+
}.map(_.asInstanceOf[NamedExpression])
567+
}
568+
569+
/**
570+
* Returns true if `exprs` contains a [[Star]].
571+
*/
572+
def containsStar(exprs: Seq[Expression]): Boolean =
573+
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
574+
575+
/**
576+
* Expands the matching attribute.*'s in `child`'s output.
577+
*/
578+
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
579+
expr.transformUp {
580+
case f1: UnresolvedFunction if containsStar(f1.children) =>
581+
f1.copy(children = f1.children.flatMap {
582+
case s: Star => s.expand(child, resolver)
583+
case o => o :: Nil
584+
})
585+
case c: CreateStruct if containsStar(c.children) =>
586+
c.copy(children = c.children.flatMap {
587+
case s: Star => s.expand(child, resolver)
588+
case o => o :: Nil
589+
})
590+
case c: CreateArray if containsStar(c.children) =>
591+
c.copy(children = c.children.flatMap {
592+
case s: Star => s.expand(child, resolver)
593+
case o => o :: Nil
594+
})
595+
case p: Murmur3Hash if containsStar(p.children) =>
596+
p.copy(children = p.children.flatMap {
597+
case s: Star => s.expand(child, resolver)
598+
case o => o :: Nil
599+
})
600+
// count(*) has been replaced by count(1)
601+
case o if containsStar(o.children) =>
602+
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
603+
}
604+
}
615605
}
616606

617607
protected[sql] def resolveExpression(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AnalysisSuite extends AnalysisTest {
2929
import org.apache.spark.sql.catalyst.analysis.TestRelations._
3030

3131
test("union project *") {
32-
val plan = (1 to 100)
32+
val plan = (1 to 120)
3333
.map(_ => testRelation)
3434
.fold[LogicalPlan](testRelation) { (a, b) =>
3535
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))

0 commit comments

Comments
 (0)