Skip to content

Commit 599e9e0

Browse files
committed
Add pivot to dataframe api
1 parent 6175d6c commit 599e9e0

5 files changed

Lines changed: 123 additions & 0 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Analyzer(
7272
ResolveRelations ::
7373
ResolveReferences ::
7474
ResolveGroupingAnalytics ::
75+
ResolvePivot ::
7576
ResolveSortReferences ::
7677
ResolveGenerate ::
7778
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
166167
if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
167168
g.withNewAggs(assignAliases(g.aggregations))
168169

170+
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child)
171+
if child.resolved && groupByExprs.exists(_.isInstanceOf[UnresolvedAlias]) =>
172+
Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregate, child)
173+
169174
case Project(projectList, child)
170175
if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
171176
Project(assignAliases(projectList), child)
@@ -249,6 +254,27 @@ class Analyzer(
249254
}
250255
}
251256

257+
/**
 * Rewrites a [[Pivot]] operator into an [[Aggregate]] that emits the grouping expressions
 * followed by one conditional aggregate column per requested pivot value.
 */
object ResolvePivot extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Leave the node untouched while its children are still unresolved.
    case pivot: Pivot if !pivot.childrenResolved => pivot

    case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child) =>
      aggregate match {
        case agg: UnaryExpression if agg.isInstanceOf[AggregateExpression] =>
          // Strip any UnresolvedAlias wrapper to obtain the bare grouping expressions.
          val groupingExprs = groupByExprs.map {
            case UnresolvedAlias(inner) => inner
            case other => other
          }
          // For each pivot value, the aggregate's input becomes its original child when the
          // pivot column equals that value and a null literal otherwise; the resulting
          // aggregate is aliased with the pivot value itself.
          val perValueAggregates = for (value <- pivotValues) yield {
            val guardedInput =
              If(EqualTo(pivotColumn, Literal(value)), agg.child, Literal(null))
            Alias(agg.withNewChildren(Seq(guardedInput)), value)()
          }
          Aggregate(groupingExprs, groupByExprs ++ perValueAggregates, child)

        case unknown =>
          throw new AnalysisException(s"$unknown is not an aggregate expression")
      }
  }
}
277+
252278
/**
253279
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
254280
*/
@@ -924,6 +950,7 @@ class Analyzer(
924950
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
925951
case p: Project => p
926952
case f: Filter => f
953+
case p: Pivot => p
927954

928955
// todo: It's hard to write a general rule to pull out nondeterministic expressions
929956
// from LogicalPlan, currently we only do it for UnaryNode which has same output

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,16 @@ case class Rollup(
373373
this.copy(aggregations = aggs)
374374
}
375375

376+
/**
 * Logical operator describing a pivot over `child`.
 *
 * @param groupByExprs named expressions whose attributes form the leading output columns
 * @param pivotColumn expression identifying the column being pivoted
 * @param pivotValues names of the trailing output columns, one per pivot value
 * @param aggregate expression whose data type is assigned to every pivot-value column
 * @param child the input plan
 */
case class Pivot(
    groupByExprs: Seq[NamedExpression],
    pivotColumn: Expression,
    pivotValues: Seq[String],
    aggregate: Expression,
    child: LogicalPlan) extends UnaryNode {

  override def output: Seq[Attribute] = {
    val groupingAttributes = groupByExprs.map(_.toAttribute)
    // The aggregate's data type is looked up per value, so it is never touched when
    // pivotValues is empty.
    val pivotedAttributes =
      pivotValues.map(name => AttributeReference(name, aggregate.dataType)())
    groupingAttributes ++ pivotedAttributes
  }
}
385+
376386
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
377387
override def output: Seq[Attribute] = child.output
378388

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,41 @@ class DataFrame private[sql](
918918
GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
919919
}
920920

921+
/**
 * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
 * aggregation.
 * {{{
 *   // Compute the sum of earnings for each year by course with each course as a separate column.
 *   df.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))
 * }}}
 * @param groupBy Columns to group by.
 * @param pivotColumn Column to pivot.
 * @param pivotValues Values of pivotColumn that will be translated to columns in the output
 *                    data frame.
 * @param aggregate Aggregate expression to perform for each combination of groupBy and
 *                  pivotValues.
 * @group dfops
 * @since 1.5.0
 */
def pivot(
    groupBy: Seq[Column],
    pivotColumn: Column,
    pivotValues: Seq[String],
    aggregate: Column): DataFrame = {

  // Every grouping column has to reach the plan as a NamedExpression.
  val namedGroupBy: Seq[NamedExpression] = groupBy.map { column =>
    column.expr match {
      // Wrap UnresolvedAttribute with UnresolvedAlias: resolving an UnresolvedAttribute removes
      // the intermediate Alias for an ExtractValue chain, so it must be aliased again afterwards
      // to make it a NamedExpression.
      case attribute: UnresolvedAttribute => UnresolvedAlias(attribute)
      case named: NamedExpression => named
      case other: Expression => Alias(other, other.prettyString)()
    }
  }

  val pivotPlan =
    Pivot(namedGroupBy, pivotColumn.expr, pivotValues, aggregate.expr, this.logicalPlan)
  new DataFrame(sqlContext, pivotPlan)
}
955+
921956
/**
922957
* (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
923958
* {{{
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.sql.TestData._
21+
import org.apache.spark.sql.functions._
22+
23+
/** Tests for `DataFrame.pivot`, run against the `courseSales` data imported from `TestData`. */
class DataFramePivotSuite extends QueryTest {

  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  test("pivot courses") {
    checkAnswer(
      courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings")),
      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
    )
  }

  test("pivot year") {
    checkAnswer(
      courseSales.pivot(Seq($"course"), $"year", Seq("2012", "2013"), sum($"earnings")),
      // The grouping value must use the same spelling as the input rows ("dotNET"); a
      // differently-cased string such as "dotNet" would never equal the grouped course value.
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }
}

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,13 @@ object TestData {
194194
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
195195
:: Nil).toDF()
196196
complexData.registerTempTable("complexData")
197+
198+
// Sample course earnings keyed by course name and year.
case class CourseSales(course: String, year: Int, earnings: Double)
val courseSales = {
  val rows = List(
    CourseSales("dotNET", 2012, 10000),
    CourseSales("Java", 2012, 20000),
    CourseSales("dotNET", 2012, 5000),
    CourseSales("dotNET", 2013, 48000),
    CourseSales("Java", 2013, 30000))
  TestSQLContext.sparkContext.parallelize(rows).toDF()
}
// Make the data reachable from SQL under the table name "courseSales".
courseSales.registerTempTable("courseSales")
197206
}

0 commit comments

Comments
 (0)