-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-13320] [SQL] Support Star in CreateStruct/CreateArray and Error Handling when DataFrame/DataSet Functions using Star #11208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
3b2b448
ac71f39
8d809bc
2c72edf
6b2d609
e47f141
4b65af3
e060dea
99f5312
ba3fe7c
0fce075
50abeec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,7 @@ class Analyzer( | |
| EliminateUnions), | ||
| Batch("Resolution", fixedPoint, | ||
| ResolveRelations :: | ||
| ResolveStar :: | ||
| ResolveReferences :: | ||
| ResolveGroupingAnalytics :: | ||
| ResolvePivot :: | ||
|
|
@@ -350,28 +351,83 @@ class Analyzer( | |
| } | ||
|
|
||
| /** | ||
| * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from | ||
| * a logical plan node's children. | ||
| * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output. | ||
| */ | ||
| object ResolveReferences extends Rule[LogicalPlan] { | ||
| object ResolveStar extends Rule[LogicalPlan] { | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case p: LogicalPlan if !p.childrenResolved => p | ||
|
|
||
| // If the projection list contains Stars, expand it. | ||
| case p: Project if containsStar(p.projectList) => | ||
| val expanded = p.projectList.flatMap { | ||
| case s: Star => s.expand(p.child, resolver) | ||
| case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => | ||
| UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil | ||
| case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => | ||
| Alias(child = expandStarExpression(a.child, p.child), a.name)( | ||
| isGenerated = a.isGenerated) :: Nil | ||
| case o => o :: Nil | ||
| } | ||
| Project(projectList = expanded, p.child) | ||
| // If the aggregate function argument contains Stars, expand it. | ||
| case a: Aggregate if containsStar(a.aggregateExpressions) => | ||
| val expanded = a.aggregateExpressions.flatMap { | ||
| case s: Star => s.expand(a.child, resolver) | ||
| case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil | ||
| case o => o :: Nil | ||
| }.map(_.asInstanceOf[NamedExpression]) | ||
| a.copy(aggregateExpressions = expanded) | ||
| // If the script transformation input contains Stars, expand it. | ||
| case t: ScriptTransformation if containsStar(t.input) => | ||
| t.copy( | ||
| input = t.input.flatMap { | ||
| case s: Star => s.expand(t.child, resolver) | ||
| case o => o :: Nil | ||
| } | ||
| ) | ||
| case g: Generate if containsStar(g.generator.children) => | ||
| failAnalysis("Cannot explode *, explode can only be applied on a specific column.") | ||
|
||
| } | ||
|
|
||
| /** | ||
| * Returns true if `exprs` contains a [[Star]]. | ||
| */ | ||
| def containsStar(exprs: Seq[Expression]): Boolean = | ||
| exprs.exists(_.collect { case _: Star => true }.nonEmpty) | ||
|
|
||
| /** | ||
| * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree | ||
| * rooted at each expression. | ||
| * Expands the matching attribute.*'s in `child`'s output. | ||
| */ | ||
| def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { | ||
| exprs.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case e => | ||
| e.transformDown { | ||
| case f1: UnresolvedFunction if containsStar(f1.children) => | ||
| f1.copy(children = f1.children.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| }) | ||
| } :: Nil | ||
| def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { | ||
| expr.transformUp { | ||
| case f1: UnresolvedFunction if containsStar(f1.children) => | ||
| f1.copy(children = f1.children.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| }) | ||
| case c: CreateStruct if containsStar(c.children) => | ||
| c.copy(children = c.children.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| }) | ||
| case c: CreateArray if containsStar(c.children) => | ||
| c.copy(children = c.children.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| }) | ||
| // count(*) has been replaced by count(1) | ||
| case o if containsStar(o.children) => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can have a method: then we can simplify this to:
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That is a great idea! :)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tried it, but
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, I see. I don't have a better idea, so let's just keep it this way. |
||
| failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from | ||
| * a logical plan node's children. | ||
| */ | ||
| object ResolveReferences extends Rule[LogicalPlan] { | ||
| /** | ||
| * Generate a new logical plan for the right child with different expression IDs | ||
| * for all conflicting attributes. | ||
|
|
@@ -432,48 +488,6 @@ class Analyzer( | |
| def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case p: LogicalPlan if !p.childrenResolved => p | ||
|
|
||
| // If the projection list contains Stars, expand it. | ||
| case p @ Project(projectList, child) if containsStar(projectList) => | ||
| Project( | ||
| projectList.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => | ||
| val newChildren = expandStarExpressions(args, child) | ||
| UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil | ||
| case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => | ||
| val newChildren = expandStarExpressions(args, child) | ||
| Alias(child = f.copy(children = newChildren), name)( | ||
| isGenerated = a.isGenerated) :: Nil | ||
| case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => | ||
| val expandedArgs = args.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| } | ||
| UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil | ||
| case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) => | ||
| val expandedArgs = args.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| } | ||
| UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil | ||
| case o => o :: Nil | ||
| }, | ||
| child) | ||
|
|
||
| case t: ScriptTransformation if containsStar(t.input) => | ||
| t.copy( | ||
| input = t.input.flatMap { | ||
| case s: Star => s.expand(t.child, resolver) | ||
| case o => o :: Nil | ||
| } | ||
| ) | ||
|
|
||
| // If the aggregate function argument contains Stars, expand it. | ||
| case a: Aggregate if containsStar(a.aggregateExpressions) => | ||
| val expanded = expandStarExpressions(a.aggregateExpressions, a.child) | ||
| .map(_.asInstanceOf[NamedExpression]) | ||
| a.copy(aggregateExpressions = expanded) | ||
|
|
||
| // To resolve duplicate expression IDs for Join and Intersect | ||
| case j @ Join(left, right, _, _) if !j.duplicateResolved => | ||
| j.copy(right = dedupRight(left, right)) | ||
|
|
@@ -561,12 +575,6 @@ class Analyzer( | |
| def findAliases(projectList: Seq[NamedExpression]): AttributeSet = { | ||
| AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) | ||
| } | ||
|
|
||
| /** | ||
| * Returns true if `exprs` contains a [[Star]]. | ||
| */ | ||
| def containsStar(exprs: Seq[Expression]): Boolean = | ||
| exprs.exists(_.collect { case _: Star => true }.nonEmpty) | ||
| } | ||
|
|
||
| private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = { | ||
|
|
@@ -833,8 +841,6 @@ class Analyzer( | |
| */ | ||
| object ResolveGenerate extends Rule[LogicalPlan] { | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case g: Generate if ResolveReferences.containsStar(g.generator.children) => | ||
| failAnalysis("Cannot explode *, explode can only be applied on a specific column.") | ||
| case p: Generate if !p.child.resolved || !p.generator.resolved => p | ||
| case g: Generate if !g.resolved => | ||
| g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1938,6 +1938,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | |
| } | ||
| } | ||
|
|
||
| test("Star Expansion - CreateStruct and CreateArray") { | ||
|
||
| val structDf = testData2.select("a", "b").as("record") | ||
| // CreateStruct and CreateArray in aggregateExpressions | ||
| assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) | ||
| assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) | ||
|
|
||
| // CreateStruct and CreateArray in project list (unresolved alias) | ||
| assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) | ||
| assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) | ||
|
|
||
| // CreateStruct and CreateArray in project list (alias) | ||
|
||
| assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) | ||
| assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) | ||
| } | ||
|
|
||
| test("Common subexpression elimination") { | ||
| // TODO: support subexpression elimination in whole stage codegen | ||
| withSQLConf("spark.sql.codegen.wholeStage" -> "false") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will lose the qualifier here. How about
`a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, a good catch!