Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
Expand Down Expand Up @@ -350,28 +351,83 @@ class Analyzer(
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
* Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
*/
object ResolveReferences extends Rule[LogicalPlan] {
object ResolveStar extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
val expanded = p.projectList.flatMap {
case s: Star => s.expand(p.child, resolver)
case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
Alias(child = expandStarExpression(a.child, p.child), a.name)(
isGenerated = a.isGenerated) :: Nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will lose qualifier here, how about a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, a good catch!

case o => o :: Nil
}
Project(projectList = expanded, p.child)
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
val expanded = a.aggregateExpressions.flatMap {
case s: Star => s.expand(a.child, resolver)
case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
case g: Generate if containsStar(g.generator.children) =>
failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized the error message is not clear enough, Generate is not always "explode"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a test for this error message?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I moved this from another rule. I will check the coverage of test cases. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a test case: https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala#L181-L182

How about changing the message to Invalid usage of '*' in explode/json_tuple/UDTF? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explode/json_tuple/UDTF LGTM

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Let me change it now.

}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

/**
* Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
* rooted at each expression.
* Expands the matching attribute.*'s in `child`'s output.
*/
def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
exprs.flatMap {
case s: Star => s.expand(child, resolver)
case e =>
e.transformDown {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
} :: Nil
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a method:

private def mayContainsStar(expr: Expression): Boolean = expr.isInstnaceOf[UnresolvedFunction] || expr.isInstnaceOf[CreateStruct]...

then we can simplify this to:

expr.transformUp {
  case e if mayContainsStar(e) =>
    e.copy(children = ...)
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great idea! : )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it, but copy is unable to use here. When the type is Expression (abstract type), we are unable to use the copy function to change the children. In addition, withNewChildren requires the same number of children. Do you have any idea how to fix it? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see, I don't have a better idea, let's just keep it this way.

failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
Expand Down Expand Up @@ -432,48 +488,6 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
case s: Star => s.expand(child, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
Alias(child = f.copy(children = newChildren), name)(
isGenerated = a.isGenerated) :: Nil
case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)

case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)

// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)

// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
Expand Down Expand Up @@ -561,12 +575,6 @@ class Analyzer(
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
Expand Down Expand Up @@ -833,8 +841,6 @@ class Analyzer(
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ class AnalysisErrorSuite extends AnalysisTest {
.orderBy('havingCondition.asc),
"cannot resolve" :: "havingCondition" :: Nil)

errorTest(
"unresolved star expansion in max",
testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
"Invalid usage of '*'" :: "in expression 'sum'" :: Nil)

errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
Expand Down
15 changes: 15 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("Star Expansion - CreateStruct and CreateArray") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we put these tests in SQLQuerySuite? It looks like they are mostly testing DF APIs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, let me move them to DataFrameSuite. Thanks!

val structDf = testData2.select("a", "b").as("record")
// CreateStruct and CreateArray in aggregateExpressions
assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1)))

// CreateStruct and CreateArray in project list (unresolved alias)
assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1))

// CreateStruct and CreateArray in project list (alias)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we add another 2 cases: Generate and ScriptTransform?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let me do it.

assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1)))
assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1))
}

test("Common subexpression elimination") {
// TODO: support subexpression elimination in whole stage codegen
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
Expand Down