
Commit 78cf93b (1 parent: 428dd1b)

Revert "[SPARK-12978][SQL] Skip unnecessary final group-by when input data already clustered with group-by keys"

This reverts commit 2b0cc4e.

File tree

8 files changed: +224, -257 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 12 additions & 5 deletions
@@ -259,17 +259,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         val aggregateOperator =
-          if (functionsWithDistinct.isEmpty) {
+          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+            if (functionsWithDistinct.nonEmpty) {
+              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+                "aggregate functions which don't support partial aggregation.")
+            } else {
+              aggregate.AggUtils.planAggregateWithoutPartial(
+                groupingExpressions,
+                aggregateExpressions,
+                resultExpressions,
+                planLater(child))
+            }
+          } else if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
-            if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-                "aggregate functions which don't support partial aggregation.")
-            }
             aggregate.AggUtils.planAggregateWithOneDistinct(
               groupingExpressions,
               functionsWithDistinct,
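
For readers of the revert, the restored planner branching amounts to a three-way choice: non-partial aggregate functions are planned without a map-side stage (and cannot be combined with distinct), plain aggregates get the standard partial-plus-final plan, and distinct aggregates go through the one-distinct planner. The following standalone Scala sketch models that decision order; AggFunction and the planner names returned as strings are illustrative stand-ins, not Spark's internal API.

// Standalone model of the restored branching in SparkStrategies; not Spark's planner API.
case class AggFunction(name: String, supportsPartial: Boolean, isDistinct: Boolean)

def chooseAggregatePlanner(functions: Seq[AggFunction]): String = {
  val functionsWithDistinct = functions.filter(_.isDistinct)
  if (functions.exists(!_.supportsPartial)) {
    if (functionsWithDistinct.nonEmpty) {
      sys.error("Distinct columns cannot exist in Aggregate operator containing " +
        "aggregate functions which don't support partial aggregation.")
    } else {
      "planAggregateWithoutPartial"   // single-stage plan, no map-side partial aggregation
    }
  } else if (functionsWithDistinct.isEmpty) {
    "planAggregateWithoutDistinct"    // standard partial + final aggregation
  } else {
    "planAggregateWithOneDistinct"    // distinct aggregation, planned as multiple stages
  }
}

// e.g. chooseAggregatePlanner(Seq(AggFunction("count", supportsPartial = true, isDistinct = false)))
// returns "planAggregateWithoutDistinct"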

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 135 additions & 115 deletions
Large diffs are not rendered by default.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala

Lines changed: 0 additions & 56 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 21 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -41,7 +42,11 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -55,6 +60,21 @@
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
 
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
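
The revert moves output, producedAttributes, and requiredChildDistribution back into HashAggregateExec itself. A minimal standalone sketch of how the optional distribution expressions map to a required child distribution; the Distribution types below are simplified stand-ins for Spark's physical-plan classes, not the real ones.

// Simplified stand-ins illustrating the restored requiredChildDistribution logic.
sealed trait Distribution
case object AllTuples extends Distribution                                  // all rows in one partition (global aggregate)
case class ClusteredDistribution(exprs: Seq[String]) extends Distribution   // rows with equal keys co-located
case object UnspecifiedDistribution extends Distribution                    // no requirement (e.g. map-side partial stage)

def requiredChildDistribution(requiredExprs: Option[Seq[String]]): List[Distribution] =
  requiredExprs match {
    case Some(exprs) if exprs.isEmpty => AllTuples :: Nil       // grouping on nothing: everything in one place
    case Some(exprs)                  => ClusteredDistribution(exprs) :: Nil
    case None                         => UnspecifiedDistribution :: Nil
  }

// requiredChildDistribution(Some(Nil))        => List(AllTuples)
// requiredChildDistribution(Some(Seq("key"))) => List(ClusteredDistribution(List("key")))
// requiredChildDistribution(None)             => List(UnspecifiedDistribution)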

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 22 additions & 2 deletions
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils
 
@@ -37,11 +38,30 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec {
+  extends UnaryExecNode {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
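
SortAggregateExec gets the same distribution logic back, and (unchanged by this commit) it also requires its input sorted ascending on the grouping keys. A small sketch of that ordering requirement, with SortOrder as a simplified stand-in for Spark's expression class.

// Sketch of the sort-based aggregate's ordering requirement.
case class SortOrder(expression: String, ascending: Boolean = true)

def requiredChildOrdering(groupingExpressions: Seq[String]): Seq[Seq[SortOrder]] =
  groupingExpressions.map(SortOrder(_)) :: Nil

// requiredChildOrdering(Seq("key", "value"))
// => List(List(SortOrder(key,true), SortOrder(value,true)))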

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 12 additions & 26 deletions
@@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.AggUtils
-import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -153,30 +151,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    assert(requiredChildDistributions.length == operator.children.length)
-    assert(requiredChildOrderings.length == operator.children.length)
+    var children: Seq[SparkPlan] = operator.children
+    assert(requiredChildDistributions.length == children.length)
+    assert(requiredChildOrderings.length == children.length)
 
-    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
-      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
-
-    var (parent, children) = operator match {
-      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
-        // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
-        // aggregation and a shuffle are added as children.
-        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
-      case _ =>
-        // Ensure that the operator's children satisfy their output distribution requirements:
-        val childrenWithDist = operator.children.zip(requiredChildDistributions)
-        val newChildren = childrenWithDist.map {
-          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-            child
-          case (child, BroadcastDistribution(mode)) =>
-            BroadcastExchangeExec(mode, child)
-          case (child, distribution) =>
-            createShuffleExchange(distribution, child)
-        }
-        (operator, newChildren)
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map {
+      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+        child
+      case (child, BroadcastDistribution(mode)) =>
+        BroadcastExchangeExec(mode, child)
+      case (child, distribution) =>
+        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
     }
 
     // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -269,7 +255,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       }
     }
 
-    parent.withNewChildren(children)
+    operator.withNewChildren(children)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
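
With the PartialAggregate special case removed, EnsureRequirements goes back to a uniform rule: a child whose output partitioning already satisfies its required distribution is kept, a broadcast requirement gets a broadcast exchange, and everything else gets a shuffle exchange. A simplified standalone model of that rule; all types here are illustrative stand-ins, not Spark's physical-plan classes.

// Simplified model of the restored EnsureRequirements behavior.
sealed trait Dist
case object BroadcastDist extends Dist
case class ClusteredDist(keys: Seq[String]) extends Dist

case class Plan(name: String, outputDist: Option[Dist] = None, children: Seq[Plan] = Nil) {
  def satisfies(required: Dist): Boolean = outputDist.contains(required)
}

def ensureDistribution(operator: Plan, requiredChildDistributions: Seq[Dist]): Plan = {
  val newChildren = operator.children.zip(requiredChildDistributions).map {
    case (child, dist) if child.satisfies(dist) => child                                        // already satisfied, no exchange
    case (child, BroadcastDist)                 => Plan("BroadcastExchange", Some(BroadcastDist), Seq(child))
    case (child, dist)                          => Plan("ShuffleExchange", Some(dist), Seq(child))
  }
  operator.copy(children = newChildren)
}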

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 8 deletions
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   /**
-   * Verifies that there is a single Aggregation for `df`
+   * Verifies that there is no Exchange between the Aggregations for `df`
    */
-  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
+  private def verifyNonExchangingAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
+        atFirstAgg = !atFirstAgg
+      case _ =>
        if (atFirstAgg) {
-          fail("Should not have back to back Aggregates")
+          fail("Should not have operators between the two aggregations")
        }
-        atFirstAgg = true
-      case _ =>
     }
   }
 
@@ -1292,10 +1292,9 @@
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingSingleAgg(df3)
-    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
-    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
 
     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")
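
The updated test checks the user-facing effect: repartitioning by the grouping key up front means the aggregation should not need a further Exchange between its partial and final stages. A hedged usage sketch of the same scenario, assuming a local SparkSession; the session and sample data below are illustrative, not the suite's fixtures.

// Usage sketch: inspect the physical plan for an aggregation over pre-partitioned data.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("agg-no-exchange").getOrCreate()
import spark.implicits._

val testData = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value")

val df = testData.repartition($"key").groupBy("key").count()
df.explain()   // expect no Exchange between the two HashAggregate operators

val df2 = testData.repartition($"key", $"value").groupBy("key").count()
df2.explain()  // grouping by only part of the partitioning keys still needs an exchange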

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 15 additions & 44 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,65 +37,36 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
-  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
-    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
-    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
-      .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    planned.collect { case n if n.nodeName contains "Aggregate" => n }
+    val plannedOption = Aggregation(query).headOption
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be four aggregate operator for
+    // distinct aggregations.
+    assert(
+      aggregations.size == 2 || aggregations.size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    assert(testPartialAggregationPlan(query).size == 2,
-      s"The plan of query $query does not have partial aggregations.")
+    testPartialAggregationPlan(query)
  }
 
  test("count distinct is partially aggregated") {
    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
    testPartialAggregationPlan(query)
-    // For the new aggregation code path, there will be four aggregate operator for distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
  }
 
  test("mixed aggregates are partially aggregated") {
    val query =
      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    // For the new aggregation code path, there will be four aggregate operator for distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
-  }
-
-  test("non-partial aggregation for aggregates") {
-    withTempView("testNonPartialAggregation") {
-      val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
-      val row = Row.fromSeq(Seq.fill(1)(null))
-      val rowRDD = sparkContext.parallelize(row :: Nil)
-      spark.createDataFrame(rowRDD, schema).repartition($"value")
-        .createOrReplaceTempView("testNonPartialAggregation")
-
-      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
-        .queryExecution.executedPlan
-
-      // If input data are already partitioned and the same columns are used in grouping keys and
-      // aggregation values, no partial aggregation exist in query plans.
-      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
-
-      val planned2 = sql(
-        """
-          |SELECT t.value, SUM(DISTINCT t.value)
-          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
-          |GROUP BY t.value
-        """.stripMargin).queryExecution.executedPlan
-
-      val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
-    }
+    testPartialAggregationPlan(query)
  }
 
  test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
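
The simplified helper counts aggregate operators in the planned query: two for a plain aggregate (partial plus final), four for a distinct aggregate. A rough sketch of the same check applied to a DataFrame's executed plan, assuming a DataFrame named df is already defined; this is an illustration, not the suite's exact code path.

// Count physical operators whose node name contains "Aggregate".
val planned = df.queryExecution.executedPlan
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
assert(
  aggregations.size == 2 || aggregations.size == 4,
  "The plan does not have partial aggregations.")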
