@@ -216,15 +216,30 @@ class Analyzer(
216216 case Aggregate (Seq (r @ Rollup (groupByExprs)), aggregateExpressions, child) =>
217217 GroupingSets (bitmasks(r), groupByExprs, child, aggregateExpressions)
218218 // Ensure all the expressions have been resolved.
219- case g : GroupingSets if g.expressions.exists(! _.resolved) => g
220- case x : GroupingSets =>
219+ case g : GroupingSets if g.expressions.exists(! _.resolved) =>
221220 val gid = AttributeReference (VirtualColumn .groupingIdName, IntegerType , false )()
222-
221+ // If users manually specify grouping__id in the aggregation expression, resolve it.
222+ val aggExprs = g.aggregations.map(_.transform {
223+ case u : UnresolvedAttribute if resolver(u.name, VirtualColumn .groupingIdName) => gid
224+ }.asInstanceOf [NamedExpression ])
225+ g.copy(aggregations = aggExprs, groupByExprs = g.groupByExprs :+ gid)
226+ case x : GroupingSets =>
227+ // Find the grouping ID AttributeReference that has been added above
228+ val (groupingID, groupByExprsWithoutGroupingID) = x.groupByExprs.partition {
229+ case u : AttributeReference => resolver(u.name, VirtualColumn .groupingIdName)
230+ case _ => false
231+ }
232+ // If found, use it; otherwise, create a new one.
233+ val gid = if (groupingID.nonEmpty) {
234+ groupingID.head.asInstanceOf [AttributeReference ]
235+ } else {
236+ AttributeReference (VirtualColumn .groupingIdName, IntegerType , false )()
237+ }
223238 // Expand works by setting grouping expressions to null as determined by the bitmasks. To
224239 // prevent these null values from being used in an aggregate instead of the original value
225240 // we need to create new aliases for all group by expressions that will only be used for
226241 // the intended purpose.
227- val groupByAliases : Seq [Alias ] = x.groupByExprs .map {
242+ val groupByAliases : Seq [Alias ] = groupByExprsWithoutGroupingID .map {
228243 case e : NamedExpression => Alias (e, e.name)()
229244 case other => Alias (other, other.toString)()
230245 }
@@ -256,7 +271,7 @@ class Analyzer(
256271 val groupByAttributes = groupByAliases.map(attributeMap(_))
257272
258273 Aggregate (
259- groupByAttributes :+ VirtualColumn .groupingIdAttribute ,
274+ groupByAttributes :+ gid ,
260275 aggregations,
261276 Expand (x.bitmasks, groupByAttributes, gid, child))
262277 }
0 commit comments