Skip to content

Commit 59e3a56

Browse files
maropucloud-fan
authored and committed
[SPARK-14471][SQL] Aliases in SELECT could be used in GROUP BY
## What changes were proposed in this pull request? This pr added a new rule in `Analyzer` to resolve aliases in `GROUP BY`. The current master throws an exception if `GROUP BY` clauses have aliases in `SELECT`; ``` scala> spark.sql("select a a1, a1 + 1 as b, count(1) from t group by a1") org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; line 1 pos 51; 'Aggregate ['a1], [a#83L AS a1#87L, ('a1 + 1) AS b#88, count(1) AS count(1)#90L] +- SubqueryAlias t +- Project [id#80L AS a#83L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro <[email protected]> Closes #17191 from maropu/SPARK-14471.
1 parent e3c8160 commit 59e3a56

6 files changed

Lines changed: 156 additions & 32 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 46 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@ class Analyzer(
136136
ResolveGroupingAnalytics ::
137137
ResolvePivot ::
138138
ResolveOrdinalInOrderByAndGroupBy ::
139+
ResolveAggAliasInGroupBy ::
139140
ResolveMissingReferences ::
140141
ExtractGenerator ::
141142
ResolveGenerate ::
@@ -172,7 +173,7 @@ class Analyzer(
172173
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
173174
*/
174175
object CTESubstitution extends Rule[LogicalPlan] {
175-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
176+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
176177
case With(child, relations) =>
177178
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
178179
case (resolved, (name, relation)) =>
@@ -200,7 +201,7 @@ class Analyzer(
200201
* Substitute child plan with WindowSpecDefinitions.
201202
*/
202203
object WindowsSubstitution extends Rule[LogicalPlan] {
203-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
204+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
204205
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
205206
case WithWindowDefinition(windowDefinitions, child) =>
206207
child.transform {
@@ -242,7 +243,7 @@ class Analyzer(
242243
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
243244
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
244245

245-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
246+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
246247
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
247248
Aggregate(groups, assignAliases(aggs), child)
248249

@@ -614,7 +615,7 @@ class Analyzer(
614615
case _ => plan
615616
}
616617

617-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
618+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
618619
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
619620
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
620621
case v: View =>
@@ -786,7 +787,7 @@ class Analyzer(
786787
}
787788
}
788789

789-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
790+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
790791
case p: LogicalPlan if !p.childrenResolved => p
791792

792793
// If the projection list contains Stars, expand it.
@@ -844,11 +845,10 @@ class Analyzer(
844845

845846
case q: LogicalPlan =>
846847
logTrace(s"Attempting to resolve ${q.simpleString}")
847-
q transformExpressionsUp {
848+
q.transformExpressionsUp {
848849
case u @ UnresolvedAttribute(nameParts) =>
849-
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
850-
val result =
851-
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
850+
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
851+
val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
852852
logDebug(s"Resolving $u to $result")
853853
result
854854
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -961,7 +961,7 @@ class Analyzer(
961961
* have no effect on the results.
962962
*/
963963
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
964-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
964+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
965965
case p if !p.childrenResolved => p
966966
// Replace the index with the related attribute for ORDER BY,
967967
// which is a 1-base position of the projection list.
@@ -997,6 +997,27 @@ class Analyzer(
997997
}
998998
}
999999

1000+
/**
1001+
* Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
1002+
* This rule is expected to run after [[ResolveReferences]] applied.
1003+
*/
1004+
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
1005+
1006+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
1007+
case agg @ Aggregate(groups, aggs, child)
1008+
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
1009+
groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
1010+
// This is a strict check though, we put this to apply the rule only in alias expressions
1011+
def notResolvableByChild(attrName: String): Boolean =
1012+
!child.output.exists(a => resolver(a.name, attrName))
1013+
agg.copy(groupingExpressions = groups.map {
1014+
case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
1015+
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
1016+
case e => e
1017+
})
1018+
}
1019+
}
1020+
10001021
/**
10011022
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
10021023
* clause. This rule detects such queries and adds the required attributes to the original
@@ -1006,7 +1027,7 @@ class Analyzer(
10061027
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
10071028
*/
10081029
object ResolveMissingReferences extends Rule[LogicalPlan] {
1009-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1030+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
10101031
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
10111032
case sa @ Sort(_, _, child: Aggregate) => sa
10121033

@@ -1130,7 +1151,7 @@ class Analyzer(
11301151
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
11311152
*/
11321153
object ResolveFunctions extends Rule[LogicalPlan] {
1133-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1154+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
11341155
case q: LogicalPlan =>
11351156
q transformExpressions {
11361157
case u if !u.childrenResolved => u // Skip until children are resolved.
@@ -1469,7 +1490,7 @@ class Analyzer(
14691490
/**
14701491
* Resolve and rewrite all subqueries in an operator tree..
14711492
*/
1472-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1493+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
14731494
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
14741495
// its child for resolution.
14751496
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1484,7 +1505,7 @@ class Analyzer(
14841505
* Turns projections that contain aggregate expressions into aggregations.
14851506
*/
14861507
object GlobalAggregates extends Rule[LogicalPlan] {
1487-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1508+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
14881509
case Project(projectList, child) if containsAggregates(projectList) =>
14891510
Aggregate(Nil, projectList, child)
14901511
}
@@ -1510,7 +1531,7 @@ class Analyzer(
15101531
* underlying aggregate operator and then projected away after the original operator.
15111532
*/
15121533
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
1513-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1534+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
15141535
case filter @ Filter(havingCondition,
15151536
aggregate @ Aggregate(grouping, originalAggExprs, child))
15161537
if aggregate.resolved =>
@@ -1682,7 +1703,7 @@ class Analyzer(
16821703
}
16831704
}
16841705

1685-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1706+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
16861707
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
16871708
val nestedGenerator = projectList.find(hasNestedGenerator).get
16881709
throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1740,7 +1761,7 @@ class Analyzer(
17401761
* that wrap the [[Generator]].
17411762
*/
17421763
object ResolveGenerate extends Rule[LogicalPlan] {
1743-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1764+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
17441765
case g: Generate if !g.child.resolved || !g.generator.resolved => g
17451766
case g: Generate if !g.resolved =>
17461767
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -2057,7 +2078,7 @@ class Analyzer(
20572078
* put them into an inner Project and finally project them away at the outer Project.
20582079
*/
20592080
object PullOutNondeterministic extends Rule[LogicalPlan] {
2060-
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2081+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
20612082
case p if !p.resolved => p // Skip unresolved nodes.
20622083
case p: Project => p
20632084
case f: Filter => f
@@ -2102,7 +2123,7 @@ class Analyzer(
21022123
* and we should return null if the input is null.
21032124
*/
21042125
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
2105-
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2126+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
21062127
case p if !p.resolved => p // Skip unresolved nodes.
21072128

21082129
case p => p transformExpressionsUp {
@@ -2167,7 +2188,7 @@ class Analyzer(
21672188
* Then apply a Project on a normal Join to eliminate natural or using join.
21682189
*/
21692190
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
2170-
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2191+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
21712192
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
21722193
if left.resolved && right.resolved && j.duplicateResolved =>
21732194
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@@ -2232,7 +2253,7 @@ class Analyzer(
22322253
* to the given input attributes.
22332254
*/
22342255
object ResolveDeserializer extends Rule[LogicalPlan] {
2235-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2256+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
22362257
case p if !p.childrenResolved => p
22372258
case p if p.resolved => p
22382259

@@ -2318,7 +2339,7 @@ class Analyzer(
23182339
* constructed is an inner class.
23192340
*/
23202341
object ResolveNewInstance extends Rule[LogicalPlan] {
2321-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2342+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
23222343
case p if !p.childrenResolved => p
23232344
case p if p.resolved => p
23242345

@@ -2352,7 +2373,7 @@ class Analyzer(
23522373
"type of the field in the target object")
23532374
}
23542375

2355-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2376+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
23562377
case p if !p.childrenResolved => p
23572378
case p if p.resolved => p
23582379

@@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
24062427
case other => trimAliases(other)
24072428
}
24082429

2409-
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2430+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
24102431
case Project(projectList, child) =>
24112432
val cleanedProjectList =
24122433
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
24742495
* @return the logical plan that will generate the time windows using the Expand operator, with
24752496
* the Filter operator for correctness and Project for usability.
24762497
*/
2477-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2498+
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
24782499
case p: LogicalPlan if p.children.size == 1 =>
24792500
val child = p.children.head
24802501
val windowExpressions =

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -421,6 +421,12 @@ object SQLConf {
421421
.booleanConf
422422
.createWithDefault(true)
423423

424+
val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases")
425+
.doc("When true, aliases in a select list can be used in group by clauses. When false, " +
426+
"an analysis exception is thrown in the case.")
427+
.booleanConf
428+
.createWithDefault(true)
429+
424430
// The output committer class used by data sources. The specified class needs to be a
425431
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
426432
val OUTPUT_COMMITTER_CLASS =
@@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging {
10031009

10041010
def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
10051011

1012+
def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)
1013+
10061014
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
10071015

10081016
def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)

sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
4949
-- group by ordinal followed by having
5050
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
5151

52+
-- mixed cases: group-by ordinals and aliases
53+
select a, a AS k, count(b) from data group by k, 1;
54+
5255
-- turn of group by ordinal
5356
set spark.sql.groupByOrdinal=false;
5457

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,21 @@ FROM testData;
3535

3636
-- Aggregate with foldable input and multiple distinct groups.
3737
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
38+
39+
-- Aliases in SELECT could be used in GROUP BY
40+
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
41+
SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;
42+
43+
-- Aggregate functions cannot be used in GROUP BY
44+
SELECT COUNT(b) AS k FROM testData GROUP BY k;
45+
46+
-- Test data.
47+
CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
48+
(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v);
49+
SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a;
50+
51+
-- turn off group by aliases
52+
set spark.sql.groupByAliases=false;
53+
54+
-- Check analysis exceptions
55+
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;

sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 19
2+
-- Number of queries: 20
33

44

55
-- !query 0
@@ -173,16 +173,26 @@ struct<count(a):bigint,a:int>
173173

174174

175175
-- !query 17
176-
set spark.sql.groupByOrdinal=false
176+
select a, a AS k, count(b) from data group by k, 1
177177
-- !query 17 schema
178-
struct<key:string,value:string>
178+
struct<a:int,k:int,count(b):bigint>
179179
-- !query 17 output
180-
spark.sql.groupByOrdinal false
180+
1 1 2
181+
2 2 2
182+
3 3 2
181183

182184

183185
-- !query 18
184-
select sum(b) from data group by -1
186+
set spark.sql.groupByOrdinal=false
185187
-- !query 18 schema
186-
struct<sum(b):bigint>
188+
struct<key:string,value:string>
187189
-- !query 18 output
190+
spark.sql.groupByOrdinal false
191+
192+
193+
-- !query 19
194+
select sum(b) from data group by -1
195+
-- !query 19 schema
196+
struct<sum(b):bigint>
197+
-- !query 19 output
188198
9

0 commit comments

Comments (0)