Skip to content

Commit 18f4130

Browse files
committed
Fix the resolution of grouping__id when users manually specify it in the query.
1 parent 2f9eeb9 commit 18f4130

File tree

1 file changed

+20
-5
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis

1 file changed

+20
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,30 @@ class Analyzer(
216216
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
217217
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
218218
// Ensure all the expressions have been resolved.
219-
case g: GroupingSets if g.expressions.exists(!_.resolved) => g
220-
case x: GroupingSets =>
219+
case g: GroupingSets if g.expressions.exists(!_.resolved) =>
221220
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
222-
221+
// If users manually specify grouping__id in the aggregation expression, resolve it.
222+
val aggExprs = g.aggregations.map(_.transform {
223+
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => gid
224+
}.asInstanceOf[NamedExpression])
225+
g.copy(aggregations = aggExprs, groupByExprs = g.groupByExprs :+ gid)
226+
case x: GroupingSets =>
227+
// Find the grouping ID AttributeReference that has been added above
228+
val (groupingID, groupByExprsWithoutGroupingID) = x.groupByExprs.partition {
229+
case u: AttributeReference => resolver(u.name, VirtualColumn.groupingIdName)
230+
case _ => false
231+
}
232+
// If found, use it; otherwise, create a new one.
233+
val gid = if (groupingID.nonEmpty) {
234+
groupingID.head.asInstanceOf[AttributeReference]
235+
} else {
236+
AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
237+
}
223238
// Expand works by setting grouping expressions to null as determined by the bitmasks. To
224239
// prevent these null values from being used in an aggregate instead of the original value
225240
// we need to create new aliases for all group by expressions that will only be used for
226241
// the intended purpose.
227-
val groupByAliases: Seq[Alias] = x.groupByExprs.map {
242+
val groupByAliases: Seq[Alias] = groupByExprsWithoutGroupingID.map {
228243
case e: NamedExpression => Alias(e, e.name)()
229244
case other => Alias(other, other.toString)()
230245
}
@@ -256,7 +271,7 @@ class Analyzer(
256271
val groupByAttributes = groupByAliases.map(attributeMap(_))
257272

258273
Aggregate(
259-
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
274+
groupByAttributes :+ gid,
260275
aggregations,
261276
Expand(x.bitmasks, groupByAttributes, gid, child))
262277
}

0 commit comments

Comments
 (0)