Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
Expand Down Expand Up @@ -350,28 +351,83 @@ class Analyzer(
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
* Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
*/
object ResolveReferences extends Rule[LogicalPlan] {
object ResolveStar extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
val expanded = p.projectList.flatMap {
case s: Star => s.expand(p.child, resolver)
case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
Alias(child = expandStarExpression(a.child, p.child), a.name)(
isGenerated = a.isGenerated) :: Nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will lose qualifier here, how about a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, a good catch!

case o => o :: Nil
}
Project(projectList = expanded, p.child)
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
val expanded = a.aggregateExpressions.flatMap {
case s: Star => s.expand(a.child, resolver)
case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
case g: Generate if containsStar(g.generator.children) =>
failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized the error message is not clear enough, Generate is not always "explode"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a test for this error message?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I moved this from another rule. I will check the coverage of test cases. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a test case: https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala#L181-L182

How about changing the message to Invalid usage of '*' in explode/json_tuple/UDTF? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explode/json_tuple/UDTF LGTM

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Let me change it now.

}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

/**
* Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
* rooted at each expression.
* Expands the matching attribute.*'s in `child`'s output.
*/
def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
exprs.flatMap {
case s: Star => s.expand(child, resolver)
case e =>
e.transformDown {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
} :: Nil
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a method:

private def mayContainsStar(expr: Expression): Boolean = expr.isInstnaceOf[UnresolvedFunction] || expr.isInstnaceOf[CreateStruct]...

then we can simplify this to:

expr.transformUp {
  case e if mayContainsStar(e) =>
    e.copy(children = ...)
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great idea! : )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it, but copy is unable to use here. When the type is Expression (abstract type), we are unable to use the copy function to change the children. In addition, withNewChildren requires the same number of children. Do you have any idea how to fix it? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see, I don't have a better idea, let's just keep it this way.

failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
Expand Down Expand Up @@ -432,48 +488,6 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
case s: Star => s.expand(child, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
Alias(child = f.copy(children = newChildren), name)(
isGenerated = a.isGenerated) :: Nil
case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)

case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)

// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)

// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
Expand Down Expand Up @@ -561,12 +575,6 @@ class Analyzer(
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
Expand Down Expand Up @@ -833,8 +841,6 @@ class Analyzer(
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
val f = udf((a: String) => a)
val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.select(struct($"a").as("s")).select(f($"s.a")).collect()
df.select(struct($"*").as("s")).select(f($"s.a")).collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed?

}

test("UDF on named_struct") {
Expand All @@ -42,6 +43,7 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
val f = udf((a: String) => a)
val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
df.select(array($"*").as("s")).select(f(expr("s[0]"))).collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed?

}

test("SPARK-12477 accessing null element in array field") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
}

test("verify star in functions fail with a good error") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why put this test case in DatasetSuite?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from the example in the original JIRA. Let me move it. Thanks!

val ds = Seq(("a", 1, "c"), ("b", 2, "d")).map(a => (a._1, a._3))
val e = intercept[AnalysisException] {
ds.toDF().groupBy($"_1").agg(sum($"*") as "sumOccurances")
}
assert(e.getMessage.contains("Invalid usage of '*' in expression 'sum'"), e.getMessage)
}

test("runtime nullability check") {
val schema = StructType(Seq(
StructField("f", StructType(Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
""".stripMargin).select($"r.*"),
Row(3, 2) :: Nil)

assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should write a new test case to test * in CreateStruct and CreateArray, not just put in existing ones.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do. Thanks!


// With GROUP BY
checkAnswer(sql(
"""
Expand Down