Changes from 6 commits
python/pyspark/sql/group.py (8 additions, 0 deletions)

@@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None):

>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
if values is None:
jgd = self._jgd.pivot(pivot_col)
@@ -296,6 +298,12 @@ def _test():
Row(course="dotNET", year=2012, earnings=5000),
Row(course="dotNET", year=2013, earnings=48000),
Row(course="Java", year=2013, earnings=30000)]).toDF()
globs['df5'] = sc.parallelize([
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,
RelationalGroupedDataset.scala

@@ -340,36 +340,52 @@ class RelationalGroupedDataset protected[sql](

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.

> Review comment (Member): Shall we note this in Column API too, or note that this is an overloaded version of string's?
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
* }}}
*
- * @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
- * @since 1.6.0
* @since 2.4.0
*/
- def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
> Review comment (Member): To make diffs smaller, can you move this under the signature `def pivot(pivotColumn: String, values: Seq[Any])`?

groupType match {
case RelationalGroupedDataset.GroupByType =>
new RelationalGroupedDataset(
df,
groupingExprs,
- RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
}
}
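
> Note: because the pivot key is now a full `Column` whose `expr` is handed straight to `PivotType`, derived expressions (not just bare column names) should work as pivot keys too. A minimal sketch of that reading of the signature; `upper` here is illustrative and is not exercised by this PR's tests:

    import org.apache.spark.sql.functions.{sum, upper}

    // Hypothetical: pivot on a derived expression instead of a raw column name,
    // using the courseSales fixture from this PR's test suite.
    courseSales
      .groupBy($"year")
      .pivot(upper($"course"), Seq("DOTNET", "JAVA"))
      .agg(sum($"earnings"))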

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}
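
> Note: the String overload now only wraps the name in a `Column` and delegates, so for ordinary top-level column names the two spellings below should be interchangeable; a sketch of the intended equivalence, not a guarantee stated in this diff:

    import org.apache.spark.sql.Column

    // Both calls end up in pivot(pivotColumn: Column, values: Seq[Any]) after this change.
    df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
    df.groupBy("year").pivot(Column("course"), Seq("dotNET", "Java")).sum("earnings")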

/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
DataFramePivotSuite.scala

@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("pivot courses") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
- Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
- )
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
expected)
}

test("pivot year") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
expected)
}

test("pivot courses with multiple aggregations") {
val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy($"year")
.pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
- Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
- Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
- )
expected)
checkAnswer(
courseSales.groupBy($"year")
.pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
expected)
}

test("pivot year with string values (cast)") {
@@ -181,10 +193,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
}

test("pivot with datatype not supported by PivotFirst") {
val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
checkAnswer(
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
- Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
- )
expected)
checkAnswer(
complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
expected)
}

test("pivot with datatype not supported by PivotFirst 2") {
@@ -246,4 +261,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
}
}

test("SPARK-24722: pivot trainings - nested columns") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
trainingSales
.groupBy($"sales.year")
.pivot($"sales.course", Seq("dotNET", "Java"))
.agg(sum($"sales.earnings")),
expected)
}
}
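
> Note: grouping and pivoting on `sales.year` / `sales.course` reaches into the nested struct, but the output shape matches the flat-column tests above. Rendered from the `expected` rows in the nested-column test, the result looks roughly like:

    trainingSales
      .groupBy($"sales.year")
      .pivot($"sales.course", Seq("dotNET", "Java"))
      .agg(sum($"sales.earnings"))
      .show()
    // +----+-------+-------+
    // |year| dotNET|   Java|
    // +----+-------+-------+
    // |2012|15000.0|20000.0|
    // |2013|48000.0|30000.0|
    // +----+-------+-------+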
SQLTestData.scala

@@ -255,6 +255,17 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val trainingSales: DataFrame = {
val df = spark.sparkContext.parallelize(
TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) ::
TrainingSales("Experts", CourseSales("Java", 2012, 20000)) ::
TrainingSales("Dummies", CourseSales("dotNET", 2012, 5000)) ::
TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) ::
TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF()
df.createOrReplaceTempView("trainingSales")
df
}

/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -310,4 +321,5 @@ private[sql] object SQLTestData {
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
}
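
> Note: for readers who want to run the nested pivot end to end, a self-contained sketch using the same fixture data as `trainingSales` above. The SparkSession boilerplate and object/main scaffolding are illustrative assumptions, not part of this diff, and it presumes a Spark build that includes this change:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.sum

    // Same shapes as the fixtures above.
    case class CourseSales(course: String, year: Int, earnings: Double)
    case class TrainingSales(training: String, sales: CourseSales)

    object PivotNestedSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("pivot-nested").getOrCreate()
        import spark.implicits._

        val trainingSales = Seq(
          TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)),
          TrainingSales("Experts", CourseSales("Java", 2012, 20000)),
          TrainingSales("Dummies", CourseSales("dotNET", 2012, 5000)),
          TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)),
          TrainingSales("Dummies", CourseSales("Java", 2013, 30000))).toDF()

        // Pivot on the nested struct fields, exactly as the new test exercises.
        trainingSales
          .groupBy($"sales.year")
          .pivot($"sales.course", Seq("dotNET", "Java"))
          .agg(sum($"sales.earnings"))
          .show()

        spark.stop()
      }
    }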