Skip to content

Commit 5479066

Browse files
committed
Merge pull request apache#36 from marmbrus/partialAgg
Implement partial aggregation.
2 parents 67128b8 + 8017afb commit 5479066

File tree

5 files changed

+179
-11
lines changed

5 files changed

+179
-11
lines changed

src/main/scala/catalyst/execution/PlanningStrategies.scala

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,60 @@ trait PlanningStrategies {
119119
expr.references subsetOf plan.outputSet
120120
}
121121

122+
object PartialAggregation extends Strategy {
123+
def apply(plan: LogicalPlan): Seq[SharkPlan] = plan match {
124+
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
125+
// Collect all aggregate expressions.
126+
val allAggregates =
127+
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
128+
// Collect all aggregate expressions that can be computed partially.
129+
val partialAggregates =
130+
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
131+
132+
// Only do partial aggregation if supported by all aggregate expressions.
133+
if (allAggregates.size == partialAggregates.size) {
134+
// Create a map of expressions to their partial evaluations for all aggregate expressions.
135+
val partialEvaluations: Map[Long, SplitEvaluation] =
136+
partialAggregates.map(a => (a.id, a.asPartial)).toMap
137+
138+
// We need to pass all grouping expressions though so the grouping can happen a second
139+
// time. However some of them might be unnamed so we alias them allowing them to be
140+
// referenced in the second aggregation.
141+
val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
142+
case n: NamedExpression => (n, n)
143+
case other => (other, Alias(other, "PartialGroup")())
144+
}.toMap
145+
146+
// Replace aggregations with a new expression that computes the result from the already
147+
// computed partial evaluations and grouping values.
148+
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
149+
case e: Expression if partialEvaluations.contains(e.id) =>
150+
partialEvaluations(e.id).finalEvaluation
151+
case e: Expression if namedGroupingExpressions.contains(e) =>
152+
namedGroupingExpressions(e).toAttribute
153+
}).asInstanceOf[Seq[NamedExpression]]
154+
155+
val partialComputation =
156+
(namedGroupingExpressions.values ++
157+
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
158+
159+
// Construct two phased aggregation.
160+
execution.Aggregate(
161+
partial = false,
162+
namedGroupingExpressions.values.map(_.toAttribute).toSeq,
163+
rewrittenAggregateExpressions,
164+
execution.Aggregate(
165+
partial = true,
166+
groupingExpressions,
167+
partialComputation,
168+
planLater(child))(sc))(sc) :: Nil
169+
} else {
170+
Nil
171+
}
172+
case _ => Nil
173+
}
174+
}
175+
122176
object BroadcastNestedLoopJoin extends Strategy {
123177
def apply(plan: LogicalPlan): Seq[SharkPlan] = plan match {
124178
case logical.Join(left, right, joinType, condition) =>
@@ -143,7 +197,8 @@ trait PlanningStrategies {
143197
object BasicOperators extends Strategy {
144198
def apply(plan: LogicalPlan): Seq[SharkPlan] = plan match {
145199
case logical.Distinct(child) =>
146-
execution.Aggregate(child.output, child.output, planLater(child))(sc) :: Nil
200+
execution.Aggregate(
201+
partial = false, child.output, child.output, planLater(child))(sc) :: Nil
147202
case logical.Sort(sortExprs, child) =>
148203
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
149204
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -156,7 +211,7 @@ trait PlanningStrategies {
156211
case logical.Filter(condition, child) =>
157212
execution.Filter(condition, planLater(child)) :: Nil
158213
case logical.Aggregate(group, agg, child) =>
159-
execution.Aggregate(group, agg, planLater(child))(sc) :: Nil
214+
execution.Aggregate(partial = false, group, agg, planLater(child))(sc) :: Nil
160215
case logical.Sample(fraction, withReplacement, seed, child) =>
161216
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
162217
case logical.LocalRelation(output, data) =>

src/main/scala/catalyst/execution/SharkInstance.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ abstract class SharkInstance extends Logging {
7373
object TrivialPlanner extends QueryPlanner[SharkPlan] with PlanningStrategies {
7474
val sc = self.sc
7575
val strategies =
76+
PartialAggregation ::
7677
SparkEquiInnerJoin ::
7778
PartitionPrunings ::
7879
HiveTableScans ::

src/main/scala/catalyst/execution/aggregates.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,37 @@ import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, AbstractGene
55

66
import catalyst.errors._
77
import catalyst.expressions._
8-
import catalyst.plans.physical.{ClusteredDistribution, AllTuples}
8+
import catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples}
99

1010
/* Implicits */
1111
import org.apache.spark.rdd.SharkPairRDDFunctions._
1212

13+
/**
14+
* Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
15+
* group.
16+
*
17+
* @param partial if true then aggregation is done partially on local data without shuffling to
18+
* ensure all values where `groupingExpressions` are equal are present.
19+
* @param groupingExpressions expressions that are evaluated to determine grouping.
20+
* @param aggregateExpressions expressions that are computed for each group.
21+
* @param child the input data source.
22+
*/
1323
case class Aggregate(
24+
partial: Boolean,
1425
groupingExpressions: Seq[Expression],
1526
aggregateExpressions: Seq[NamedExpression],
1627
child: SharkPlan)(@transient sc: SharkContext)
1728
extends UnaryNode {
1829

1930
override def requiredChildDistribution =
20-
if (groupingExpressions == Nil) {
21-
AllTuples :: Nil
31+
if (partial) {
32+
UnspecifiedDistribution :: Nil
2233
} else {
23-
ClusteredDistribution(groupingExpressions) :: Nil
34+
if (groupingExpressions == Nil) {
35+
AllTuples :: Nil
36+
} else {
37+
ClusteredDistribution(groupingExpressions) :: Nil
38+
}
2439
}
2540

2641
override def otherCopyArgs = sc :: Nil

src/main/scala/catalyst/expressions/aggregates.scala

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,34 @@ abstract class AggregateExpression extends Expression {
88

99
}
1010

11+
/**
12+
* Represents an aggregation that has been rewritten to be performed in two steps.
13+
*
14+
* @param finalEvaluation an aggregate expression that evaluates to same final result as the
15+
* original aggregation.
16+
* @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
17+
* data sets and are required to compute the `finalEvaluation`.
18+
*/
19+
case class SplitEvaluation(
20+
finalEvaluation: Expression,
21+
partialEvaluations: Seq[NamedExpression])
22+
23+
/**
24+
* An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples.
25+
* These partial evaluations can then be combined to compute the actual answer.
26+
*/
27+
abstract class PartialAggregate extends AggregateExpression {
28+
self: Product =>
29+
30+
/**
31+
* Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
32+
*/
33+
def asPartial: SplitEvaluation
34+
}
35+
1136
/**
1237
* A specific implementation of an aggregate function. Used to wrap a generic
13-
* [[AggregateExpression]] with an algorithm that will be used to compute the result.
38+
* [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
1439
*/
1540
abstract class AggregateFunction
1641
extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
@@ -26,11 +51,16 @@ abstract class AggregateFunction
2651
def result: Any
2752
}
2853

29-
case class Count(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] {
54+
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
3055
def references = child.references
3156
def nullable = false
3257
def dataType = IntegerType
3358
override def toString = s"COUNT($child)"
59+
60+
def asPartial: SplitEvaluation = {
61+
val partialCount = Alias(Count(child), "PartialCount")()
62+
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
63+
}
3464
}
3565

3666
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
@@ -41,23 +71,48 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
4171
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
4272
}
4373

44-
case class Average(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] {
74+
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
4575
def references = child.references
4676
def nullable = false
4777
def dataType = DoubleType
4878
override def toString = s"AVG($child)"
79+
80+
override def asPartial: SplitEvaluation = {
81+
val partialSum = Alias(Sum(child), "PartialSum")()
82+
val partialCount = Alias(Count(child), "PartialCount")()
83+
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
84+
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
85+
86+
SplitEvaluation(
87+
Divide(castedSum, castedCount),
88+
partialCount :: partialSum :: Nil)
89+
}
4990
}
5091

51-
case class Sum(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] {
92+
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
5293
def references = child.references
5394
def nullable = false
5495
def dataType = child.dataType
5596
override def toString = s"SUM($child)"
97+
98+
override def asPartial: SplitEvaluation = {
99+
val partialSum = Alias(Sum(child), "PartialSum")()
100+
SplitEvaluation(
101+
Sum(partialSum.toAttribute),
102+
partialSum :: Nil)
103+
}
56104
}
57105

58-
case class First(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] {
106+
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
59107
def references = child.references
60108
def nullable = child.nullable
61109
def dataType = child.dataType
62110
override def toString = s"FIRST($child)"
111+
112+
override def asPartial: SplitEvaluation = {
113+
val partialFirst = Alias(First(child), "PartialFirst")()
114+
SplitEvaluation(
115+
First(partialFirst.toAttribute),
116+
partialFirst :: Nil)
117+
}
63118
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package catalyst
2+
package execution
3+
4+
import org.scalatest.FunSuite
5+
6+
import catalyst.expressions._
7+
import catalyst.plans.logical
8+
import catalyst.dsl._
9+
10+
class PlannerSuite extends FunSuite {
11+
import TestData._
12+
13+
test("unions are collapsed") {
14+
val query = testData.unionAll(testData).unionAll(testData)
15+
val planned = TestShark.TrivialPlanner.BasicOperators(query).head
16+
val logicalUnions = query collect { case u: logical.Union => u}
17+
val physicalUnions = planned collect { case u: execution.Union => u}
18+
19+
assert(logicalUnions.size === 2)
20+
assert(physicalUnions.size === 1)
21+
}
22+
23+
test("count is partially aggregated") {
24+
val query = testData.groupBy('value)(Count('key)).analyze
25+
val planned = TestShark.TrivialPlanner.PartialAggregation(query).head
26+
val aggregations = planned.collect { case a: Aggregate => a }
27+
28+
assert(aggregations.size === 2)
29+
}
30+
31+
test("count distinct is not partially aggregated") {
32+
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).analyze
33+
val planned = TestShark.TrivialPlanner.PartialAggregation(query)
34+
assert(planned.isEmpty)
35+
}
36+
37+
test("mixed aggregates are not partially aggregated") {
38+
val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).analyze
39+
val planned = TestShark.TrivialPlanner.PartialAggregation(query)
40+
assert(planned.isEmpty)
41+
}
42+
}

0 commit comments

Comments
 (0)