Skip to content

Commit fee739f

Browse files
gatorsmiledavies
authored andcommitted
[SPARK-13221] [SQL] Fixing GroupingSets when Aggregate Functions Containing GroupBy Columns
Using GroupingSets will generate a wrong result when Aggregate Functions containing GroupBy columns. This PR is to fix it. Since the code changes are very small. Maybe we also can merge it to 1.6 For example, the following query returns a wrong result: ```scala sql("select course, sum(earnings) as sum from courseSales group by course, earnings" + " grouping sets((), (course), (course, earnings))" + " order by course, sum").show() ``` Before the fix, the results are like ``` [null,null] [Java,null] [Java,20000.0] [Java,30000.0] [dotNET,null] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] ``` After the fix, the results become correct: ``` [null,113000.0] [Java,20000.0] [Java,30000.0] [Java,50000.0] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] [dotNET,63000.0] ``` UPDATE: This PR also deprecated the external column: GROUPING__ID. Author: gatorsmile <[email protected]> Closes #11100 from gatorsmile/groupingSets.
1 parent e4675c2 commit fee739f

10 files changed

Lines changed: 155 additions & 107 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,23 @@ class Analyzer(
209209
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
210210
}
211211

212+
private def hasGroupingId(expr: Seq[Expression]): Boolean = {
213+
expr.exists(_.collectFirst {
214+
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
215+
}.isDefined)
216+
}
217+
212218
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
213219
case a if !a.childrenResolved => a // be sure all of the children are resolved.
214220
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
215221
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
216222
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
217223
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
218-
case x: GroupingSets =>
224+
case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
225+
failAnalysis(
226+
s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
227+
// Ensure all the expressions have been resolved.
228+
case x: GroupingSets if x.expressions.forall(_.resolved) =>
219229
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
220230

221231
// Expand works by setting grouping expressions to null as determined by the bitmasks. To

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2040,6 +2040,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
20402040
)
20412041
}
20422042

2043+
test("grouping sets when aggregate functions containing groupBy columns") {
2044+
checkAnswer(
2045+
sql("select course, sum(earnings) as sum from courseSales group by course, earnings " +
2046+
"grouping sets((), (course), (course, earnings)) " +
2047+
"order by course, sum"),
2048+
Row(null, 113000.0) ::
2049+
Row("Java", 20000.0) ::
2050+
Row("Java", 30000.0) ::
2051+
Row("Java", 50000.0) ::
2052+
Row("dotNET", 5000.0) ::
2053+
Row("dotNET", 10000.0) ::
2054+
Row("dotNET", 48000.0) ::
2055+
Row("dotNET", 63000.0) :: Nil
2056+
)
2057+
2058+
checkAnswer(
2059+
sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " +
2060+
"group by course, earnings grouping sets((), (course), (course, earnings)) " +
2061+
"order by course, sum"),
2062+
Row(null, 113000.0, 3) ::
2063+
Row("Java", 20000.0, 0) ::
2064+
Row("Java", 30000.0, 0) ::
2065+
Row("Java", 50000.0, 1) ::
2066+
Row("dotNET", 5000.0, 0) ::
2067+
Row("dotNET", 10000.0, 0) ::
2068+
Row("dotNET", 48000.0, 0) ::
2069+
Row("dotNET", 63000.0, 1) :: Nil
2070+
)
2071+
}
2072+
20432073
test("cube") {
20442074
checkAnswer(
20452075
sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
@@ -2103,6 +2133,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
21032133
sql("select course, year, grouping_id(course, year) from courseSales group by course, year")
21042134
}
21052135
assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup")
2136+
error = intercept[AnalysisException] {
2137+
sql("select course, year, grouping__id from courseSales group by cube(course, year)")
2138+
}
2139+
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
21062140
}
21072141

21082142
test("SPARK-13056: Null in map value causes NPE") {

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864

Lines changed: 0 additions & 6 deletions
This file was deleted.

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896

Lines changed: 0 additions & 10 deletions
This file was deleted.

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c

Lines changed: 0 additions & 10 deletions
This file was deleted.

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a

Lines changed: 0 additions & 6 deletions
This file was deleted.

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89

Lines changed: 0 additions & 10 deletions
This file was deleted.

sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce

Lines changed: 0 additions & 10 deletions
This file was deleted.

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -123,60 +123,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
123123
assertBroadcastNestedLoopJoin(spark_10484_4)
124124
}
125125

126-
createQueryTest("SPARK-8976 Wrong Result for Rollup #1",
127-
"""
128-
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP
129-
""".stripMargin)
130-
131-
createQueryTest("SPARK-8976 Wrong Result for Rollup #2",
132-
"""
133-
SELECT
134-
count(*) AS cnt,
135-
key % 5 as k1,
136-
key-5 as k2,
137-
GROUPING__ID as k3
138-
FROM src group by key%5, key-5
139-
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
140-
""".stripMargin)
141-
142-
createQueryTest("SPARK-8976 Wrong Result for Rollup #3",
143-
"""
144-
SELECT
145-
count(*) AS cnt,
146-
key % 5 as k1,
147-
key-5 as k2,
148-
GROUPING__ID as k3
149-
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
150-
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
151-
""".stripMargin)
152-
153-
createQueryTest("SPARK-8976 Wrong Result for CUBE #1",
154-
"""
155-
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE
156-
""".stripMargin)
157-
158-
createQueryTest("SPARK-8976 Wrong Result for CUBE #2",
159-
"""
160-
SELECT
161-
count(*) AS cnt,
162-
key % 5 as k1,
163-
key-5 as k2,
164-
GROUPING__ID as k3
165-
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
166-
WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
167-
""".stripMargin)
168-
169-
createQueryTest("SPARK-8976 Wrong Result for GroupingSet",
170-
"""
171-
SELECT
172-
count(*) AS cnt,
173-
key % 5 as k1,
174-
key-5 as k2,
175-
GROUPING__ID as k3
176-
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
177-
GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
178-
""".stripMargin)
179-
180126
createQueryTest("insert table with generator with column name",
181127
"""
182128
| CREATE TABLE gen_tmp (key Int);

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,116 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
15511551
}
15521552
}
15531553

1554+
test("SPARK-8976 Wrong Result for Rollup #1") {
1555+
checkAnswer(sql(
1556+
"SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"),
1557+
Seq(
1558+
(113, 3, 0),
1559+
(91, 0, 0),
1560+
(500, null, 1),
1561+
(84, 1, 0),
1562+
(105, 2, 0),
1563+
(107, 4, 0)
1564+
).map(i => Row(i._1, i._2, i._3)))
1565+
}
1566+
1567+
test("SPARK-8976 Wrong Result for Rollup #2") {
1568+
checkAnswer(sql(
1569+
"""
1570+
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
1571+
|FROM src GROUP BY key%5, key-5
1572+
|WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
1573+
""".stripMargin),
1574+
Seq(
1575+
(1, 0, 5, 0),
1576+
(1, 0, 15, 0),
1577+
(1, 0, 25, 0),
1578+
(1, 0, 60, 0),
1579+
(1, 0, 75, 0),
1580+
(1, 0, 80, 0),
1581+
(1, 0, 100, 0),
1582+
(1, 0, 140, 0),
1583+
(1, 0, 145, 0),
1584+
(1, 0, 150, 0)
1585+
).map(i => Row(i._1, i._2, i._3, i._4)))
1586+
}
1587+
1588+
test("SPARK-8976 Wrong Result for Rollup #3") {
1589+
checkAnswer(sql(
1590+
"""
1591+
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
1592+
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
1593+
|WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
1594+
""".stripMargin),
1595+
Seq(
1596+
(1, 0, 5, 0),
1597+
(1, 0, 15, 0),
1598+
(1, 0, 25, 0),
1599+
(1, 0, 60, 0),
1600+
(1, 0, 75, 0),
1601+
(1, 0, 80, 0),
1602+
(1, 0, 100, 0),
1603+
(1, 0, 140, 0),
1604+
(1, 0, 145, 0),
1605+
(1, 0, 150, 0)
1606+
).map(i => Row(i._1, i._2, i._3, i._4)))
1607+
}
1608+
1609+
test("SPARK-8976 Wrong Result for CUBE #1") {
1610+
checkAnswer(sql(
1611+
"SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"),
1612+
Seq(
1613+
(113, 3, 0),
1614+
(91, 0, 0),
1615+
(500, null, 1),
1616+
(84, 1, 0),
1617+
(105, 2, 0),
1618+
(107, 4, 0)
1619+
).map(i => Row(i._1, i._2, i._3)))
1620+
}
1621+
1622+
test("SPARK-8976 Wrong Result for CUBE #2") {
1623+
checkAnswer(sql(
1624+
"""
1625+
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
1626+
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
1627+
|WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
1628+
""".stripMargin),
1629+
Seq(
1630+
(1, null, -3, 2),
1631+
(1, null, -1, 2),
1632+
(1, null, 3, 2),
1633+
(1, null, 4, 2),
1634+
(1, null, 5, 2),
1635+
(1, null, 6, 2),
1636+
(1, null, 12, 2),
1637+
(1, null, 14, 2),
1638+
(1, null, 15, 2),
1639+
(1, null, 22, 2)
1640+
).map(i => Row(i._1, i._2, i._3, i._4)))
1641+
}
1642+
1643+
test("SPARK-8976 Wrong Result for GroupingSet") {
1644+
checkAnswer(sql(
1645+
"""
1646+
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
1647+
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
1648+
|GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
1649+
""".stripMargin),
1650+
Seq(
1651+
(1, null, -3, 2),
1652+
(1, null, -1, 2),
1653+
(1, null, 3, 2),
1654+
(1, null, 4, 2),
1655+
(1, null, 5, 2),
1656+
(1, null, 6, 2),
1657+
(1, null, 12, 2),
1658+
(1, null, 14, 2),
1659+
(1, null, 15, 2),
1660+
(1, null, 22, 2)
1661+
).map(i => Row(i._1, i._2, i._3, i._4)))
1662+
}
1663+
15541664
test("SPARK-10562: partition by column with mixed case name") {
15551665
withTable("tbl10562") {
15561666
val df = Seq(2012 -> "a").toDF("Year", "val")

0 commit comments

Comments
 (0)