-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-31670][SQL] Trim unnecessary Struct field alias in Aggregate/GroupingSets #28490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
27c495b
4c0b04c
282648d
c4ff823
6d1b60e
e28b084
1ee0542
0af3166
5f0562c
7ecc8ad
53fd03a
cf818cf
cf31ab4
f846539
d63613f
ef6c87f
82f3876
d0f89af
3ebec5f
72dc305
d3ffbbd
f17dd53
891fd1b
281096a
51cea07
84e65af
9411887
e6fb91f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -510,10 +510,12 @@ class Analyzer( | |
| // collect all the found AggregateExpression, so we can check an expression is part of | ||
| // any AggregateExpression or not. | ||
| val aggsBuffer = ArrayBuffer[Expression]() | ||
|
|
||
| // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. | ||
| def isPartOfAggregation(e: Expression): Boolean = { | ||
| aggsBuffer.exists(a => a.find(_ eq e).isDefined) | ||
| } | ||
|
|
||
| replaceGroupingFunc(_, groupByExprs, gid).transformDown { | ||
| // AggregateExpression should be computed on the unmodified value of its argument | ||
| // expressions, so we should not replace any references to grouping expression | ||
|
|
@@ -1259,6 +1261,11 @@ class Analyzer( | |
| attr.withExprId(exprId) | ||
| } | ||
|
|
||
| private def dedupStructField(attr: Alias, structFieldMap: Map[String, Attribute]) = { | ||
|
||
| val exprId = structFieldMap.getOrElse(attr.child.sql, attr).exprId | ||
| Alias(attr.child, attr.name)(exprId, attr.qualifier, attr.explicitMetadata) | ||
| } | ||
|
|
||
| /** | ||
| * The outer plan may have been de-duplicated and the function below updates the | ||
| * outer references to refer to the de-duplicated attributes. | ||
|
|
@@ -1479,11 +1486,70 @@ class Analyzer( | |
| // Skip the having clause here, this will be handled in ResolveAggregateFunctions. | ||
| case h: UnresolvedHaving => h | ||
|
|
||
| case p: LogicalPlan if needResolveStructField(p) => | ||
| logTrace(s"Attempting to resolve ${p.simpleString(SQLConf.get.maxToStringFields)}") | ||
| val resolved = p.mapExpressions(resolveExpressionTopDown(_, p)) | ||
| val structFieldMap = new mutable.HashMap[String, Alias] | ||
|
||
| resolved.transformExpressions { | ||
| case a @ Alias(struct: GetStructField, _) => | ||
| if (structFieldMap.contains(struct.sql)) { | ||
| val exprId = structFieldMap.getOrElse(struct.sql, a).exprId | ||
| Alias(a.child, a.name)(exprId, a.qualifier, a.explicitMetadata) | ||
| } else { | ||
| structFieldMap.put(struct.sql, a) | ||
| a | ||
| } | ||
| case e => e | ||
| } | ||
|
|
||
| case q: LogicalPlan => | ||
| logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you write this handling in an independent pattern like this?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. w/ some code cleanup;
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Done |
||
| q.mapExpressions(resolveExpressionTopDown(_, q)) | ||
| } | ||
|
|
||
| def needResolveStructField(plan: LogicalPlan): Boolean = { | ||
|
||
| plan match { | ||
| case UnresolvedHaving(havingCondition, a: Aggregate) | ||
| if containSameStructFields(a.groupingExpressions.flatMap(_.references), | ||
| a.aggregateExpressions.flatMap(_.references), | ||
| Some(havingCondition.references.toSeq)) => true | ||
| case Aggregate(groupingExpressions, aggregateExpressions, _) | ||
| if containSameStructFields(groupingExpressions.flatMap(_.references), | ||
| aggregateExpressions.flatMap(_.references)) => true | ||
| case GroupingSets(selectedGroupByExprs, groupByExprs, _, aggregations) | ||
| if containSameStructFields(groupByExprs.flatMap(_.references), | ||
| aggregations.flatMap(_.references), | ||
| Some(selectedGroupByExprs.flatMap(_.flatMap(_.references)))) => true | ||
| case _ => false | ||
| } | ||
| } | ||
|
|
||
| def containSameStructFields( | ||
|
||
| grpExprs: Seq[Attribute], | ||
|
||
| aggExprs: Seq[Attribute], | ||
| extra: Option[Seq[Attribute]] = None): Boolean = { | ||
|
|
||
| def isStructField(attr: Attribute): Boolean = { | ||
| attr.isInstanceOf[UnresolvedAttribute] && | ||
| attr.asInstanceOf[UnresolvedAttribute].nameParts.size == 2 | ||
| } | ||
|
|
||
| val grpAttrs = grpExprs.filter(isStructField) | ||
| .map(_.asInstanceOf[UnresolvedAttribute].name) | ||
| val aggAttrs = aggExprs.filter(isStructField) | ||
| .map(_.asInstanceOf[UnresolvedAttribute].name) | ||
| val havingAttrs = extra.getOrElse(Seq.empty[Attribute]).filter(isStructField) | ||
| .map(_.asInstanceOf[UnresolvedAttribute].name) | ||
|
|
||
| if (extra.isDefined) { | ||
| grpAttrs.exists(aggAttrs.contains) | ||
| } else { | ||
| grpAttrs.exists(aggAttrs.contains) || | ||
| grpAttrs.exists(havingAttrs.contains) || | ||
| aggAttrs.exists(havingAttrs.contains) | ||
| } | ||
| } | ||
|
|
||
| def resolveAssignments( | ||
| assignments: Seq[Assignment], | ||
| mergeInto: MergeIntoTable, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3496,6 +3496,88 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
| checkIfSeedExistsInExplain(df2) | ||
| } | ||
|
|
||
| test("SPARK-31670: Struct Field in groupByExpr with CUBE") { | ||
|
||
| withTable("t") { | ||
| sql( | ||
| """CREATE TABLE t( | ||
| |a STRING, | ||
| |b INT, | ||
| |c ARRAY<STRUCT<row_id:INT,json_string:STRING>>, | ||
| |d ARRAY<ARRAY<STRING>>, | ||
| |e ARRAY<MAP<STRING, INT>>) | ||
| |USING ORC""".stripMargin) | ||
|
||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) x AS each | ||
| |GROUP BY a, each.json_string | ||
| |WITH CUBE | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| |""".stripMargin), Nil) | ||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, get_json_object(each.json_string, '$.i'), SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) X AS each | ||
| |GROUP BY a, get_json_object(each.json_string, '$.i') | ||
| |WITH CUBE | ||
| |""".stripMargin), Nil) | ||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string AS json_string, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) x AS each | ||
|
||
| |GROUP BY a, each.json_string | ||
| |WITH CUBE | ||
| |""".stripMargin), Nil) | ||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string as js, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) X AS each | ||
| |GROUP BY a, each.json_string | ||
| |WITH CUBE | ||
| |""".stripMargin), Nil) | ||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string as js, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) X AS each | ||
| |GROUP BY a, each.json_string | ||
| |WITH ROLLUP | ||
| |""".stripMargin), Nil) | ||
|
|
||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) X AS each | ||
| |GROUP BY a, each.json_string | ||
| |GROUPING sets((a),(a, each.json_string)) | ||
| |""".stripMargin).explain(true) | ||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |SELECT a, each.json_string, SUM(b) | ||
| |FROM t | ||
| |LATERAL VIEW EXPLODE(c) X AS each | ||
| |GROUP BY a, each.json_string | ||
| |GROUPING sets((a),(a, each.json_string)) | ||
| |""".stripMargin), Nil) | ||
|
||
| } | ||
| } | ||
|
|
||
| test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") { | ||
| checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1))) | ||
| checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: plz revert the unnecessary changes. (Unrelated changes might lead to revert/backport failures sometimes...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry..., forgot to check diff.