
Commit 55c3347

EnricoMi authored and cloud-fan committed
[SPARK-38864][SQL] Add unpivot / melt to Dataset
### What changes were proposed in this pull request?

This proposes a Scala implementation of the `melt` (a.k.a. `unpivot`) operation.

### Why are the changes needed?

The Scala Dataset API provides the `pivot` operation, but not its reverse operation `unpivot` / `melt`. The `melt` operation is part of the [Pandas API](https://pandas.pydata.org/docs/reference/api/pandas.melt.html), which is why this method is provided by the PySpark Pandas API, implemented purely in Python. [It should be implemented in Scala](#26912 (review)) to make this operation available to Scala / Java, SQL, and PySpark, and to reuse the implementation in the PySpark Pandas API.

The `melt` / `unpivot` operation exists in other systems such as [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unpivot_operator), [T-SQL](https://docs.microsoft.com/en-us/sql/t-sql/queries/from-using-pivot-and-unpivot?view=sql-server-ver15#unpivot-example), and [Oracle](https://www.oracletutorial.com/oracle-basics/oracle-unpivot/).

It supports expressions for id and value columns, including `*` expansion and structs, so this also fixes / includes SPARK-39292.

### Does this PR introduce _any_ user-facing change?

It adds `melt` to the `Dataset` API (Scala and Java).

### How was this patch tested?

It is tested in the `DatasetMeltSuite` and `JavaDatasetSuite`.

Closes #36150 from EnricoMi/branch-melt.

Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 5c0d551 commit 55c3347
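As a quick illustration of the user-facing change, here is a minimal usage sketch (assuming a Spark build that contains this commit; it reuses the example from the `Dataset.unpivot` scaladoc below):

```scala
import org.apache.spark.sql.SparkSession

object UnpivotDemo extends App {
  // Minimal sketch, assuming a Spark build that includes this commit (3.4.0+).
  val spark = SparkSession.builder().master("local[*]").appName("unpivot-demo").getOrCreate()
  import spark.implicits._

  val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")

  // Keep "id" as the identifier column; unpivot the remaining columns into
  // (variable, value) rows. IntegerType and LongType widen to LongType.
  df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
  // +---+--------+-----+
  // | id|variable|value|
  // +---+--------+-----+
  // |  1|     int|   11|
  // |  1|    long|   12|
  // |  2|     int|   21|
  // |  2|    long|   22|
  // +---+--------+-----+

  // melt is an alias for unpivot, mirroring pandas.melt.
  df.melt(Array($"id"), Array($"int", $"long"), "variable", "value").show()

  spark.stop()
}
```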

File tree

13 files changed: +837 -2 lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 12 additions & 0 deletions
@@ -375,6 +375,18 @@
       "Unable to acquire <requestedBytes> bytes of memory, got <receivedBytes>"
     ]
   },
+  "UNPIVOT_REQUIRES_VALUE_COLUMNS" : {
+    "message" : [
+      "At least one value column needs to be specified for UNPIVOT, all columns specified as ids"
+    ],
+    "sqlState" : "42000"
+  },
+  "UNPIVOT_VALUE_DATA_TYPE_MISMATCH" : {
+    "message" : [
+      "Unpivot value columns must share a least common type, some types do not: [<types>]"
+    ],
+    "sqlState" : "42000"
+  },
   "UNRECOGNIZED_SQL_TYPE" : {
     "message" : [
       "Unrecognized SQL type <typeName>"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 41 additions & 0 deletions
@@ -293,6 +293,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       ResolveUpCast ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUnpivot ::
       ResolveOrdinalInOrderByAndGroupBy ::
       ResolveAggAliasInGroupBy ::
       ResolveMissingReferences ::
@@ -514,6 +515,10 @@
           if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
         Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)
 
+      case up: Unpivot if up.child.resolved &&
+        (hasUnresolvedAlias(up.ids) || hasUnresolvedAlias(up.values)) =>
+        up.copy(ids = assignAliases(up.ids), values = assignAliases(up.values))
+
       case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
         Project(assignAliases(projectList), child)
 
@@ -859,6 +864,36 @@
     }
   }
 
+  object ResolveUnpivot extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+      _.containsPattern(UNPIVOT), ruleId) {
+
+      // once children and ids are resolved, we can determine values, if none were given
+      case up: Unpivot if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+        up.copy(values = up.child.output.diff(up.ids))
+
+      case up: Unpivot if !up.childrenResolved || !up.ids.forall(_.resolved) ||
+        up.values.isEmpty || !up.values.forall(_.resolved) || !up.valuesTypeCoercioned => up
+
+      // TypeCoercionBase.UnpivotCoercion determines valueType
+      // and casts values once values are set and resolved
+      case Unpivot(ids, values, variableColumnName, valueColumnName, child) =>
+        // construct unpivot expressions for Expand
+        val exprs: Seq[Seq[Expression]] = values.map {
+          value => ids ++ Seq(Literal(value.name), value)
+        }
+
+        // construct output attributes
+        val output = ids.map(_.toAttribute) ++ Seq(
+          AttributeReference(variableColumnName, StringType, nullable = false)(),
+          AttributeReference(valueColumnName, values.head.dataType, values.exists(_.nullable))()
+        )
+
+        // expand the unpivot expressions
+        Expand(exprs, output, child)
+    }
+  }
+
   private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty
   private def isReferredTempViewName(nameParts: Seq[String]): Boolean = {
     AnalysisContext.get.referredTempViewNames.exists { n =>
@@ -1349,6 +1384,12 @@
       case g: Generate if containsStar(g.generator.children) =>
         throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF",
           extractStar(g.generator.children))
+      // If the Unpivot ids or values contain Stars, expand them.
+      case up: Unpivot if containsStar(up.ids) || containsStar(up.values) =>
+        up.copy(
+          ids = buildExpandedProjectList(up.ids, up.child),
+          values = buildExpandedProjectList(up.values, up.child)
+        )
 
       case u @ Union(children, _, _)
         // if there are duplicate output columns, give them unique expr ids
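The core of `ResolveUnpivot` is the rewrite into `Expand`: each input row is emitted once per value column, carrying the id columns, the value column's name as a string literal, and the value itself. A minimal sketch of that expansion on plain Scala collections (the `Row`, `rows`, and `values` names are illustrative, not Spark API):

```scala
// Plain-Scala analogue of the Unpivot-to-Expand rewrite above: for every row,
// emit one output row per value column, built as ids ++ (name literal, value).
case class Row(id: Int, int: Int, long: Long)

val rows = Seq(Row(1, 11, 12L), Row(2, 21, 22L))

// (column name, extractor) pairs stand in for the resolved value expressions;
// both are widened to Long here, as UnpivotCoercion would do for the real plan.
val values: Seq[(String, Row => Long)] = Seq(("int", _.int.toLong), ("long", _.long))

val expanded: Seq[(Int, String, Long)] =
  rows.flatMap(row => values.map { case (name, get) => (row.id, name, get(row)) })

expanded.foreach(println)
// (1,int,11)
// (1,long,12)
// (2,int,21)
// (2,long,22)
```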

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ import org.apache.spark.sql.types._
  */
 object AnsiTypeCoercion extends TypeCoercionBase {
   override def typeCoercionRules: List[Rule[LogicalPlan]] =
+    UnpivotCoercion ::
     WidenSetOperationTypes ::
     new AnsiCombinedTypeCoercionRule(
       InConversion ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 8 additions & 0 deletions
@@ -422,6 +422,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
         }
         metrics.foreach(m => checkMetric(m, m))
 
+      // see Analyzer.ResolveUnpivot
+      case up: Unpivot
+          if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+        throw QueryCompilationErrors.unpivotRequiresValueColumns()
+      // see TypeCoercionBase.UnpivotCoercion
+      case up: Unpivot if !up.valuesTypeCoercioned =>
+        throw QueryCompilationErrors.unpivotValDataTypeMismatchError(up.values)
+
       case Sort(orders, _, _) =>
         orders.foreach { order =>
           if (!RowOrdering.isOrderable(order.dataType)) {
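These two checks surface as `AnalysisException`s carrying the error classes added to error-classes.json above. A hedged sketch of queries that would trip each one (column names are illustrative, and `spark.implicits._` is assumed in scope):

```scala
// Sketch of the two analysis-time failures added by this commit.
val df = Seq((1, 11, "a")).toDF("id", "int", "str")

// 1) All columns listed as ids: after ResolveUnpivot substitutes the (empty)
//    set of remaining columns, CheckAnalysis throws an AnalysisException with
//    error class UNPIVOT_REQUIRES_VALUE_COLUMNS.
df.unpivot(Array($"id", $"int", $"str"), "variable", "value")

// 2) Value columns with no least common type (IntegerType vs. StringType, and
//    UnpivotCoercion does no string promotion): valuesTypeCoercioned stays false,
//    so CheckAnalysis throws with error class UNPIVOT_VALUE_DATA_TYPE_MISMATCH.
df.unpivot(Array($"id"), Array($"int", $"str"), "variable", "value")
```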

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 16 additions & 0 deletions
@@ -198,6 +198,21 @@ abstract class TypeCoercionBase {
     }
   }
 
+  /**
+   * Widens the data types of the [[Unpivot]] values.
+   */
+  object UnpivotCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case up: Unpivot
+          if up.values.nonEmpty && up.values.forall(_.resolved) && !up.valuesTypeCoercioned =>
+        val valueDataType = findWiderTypeWithoutStringPromotion(up.values.map(_.dataType))
+        val values = valueDataType.map(valueType =>
+          up.values.map(value => Alias(Cast(value, valueType), value.name)())
+        ).getOrElse(up.values)
+        up.copy(values = values)
+    }
+  }
+
   /**
    * Widens the data types of the children of Union/Except/Intersect.
    * 1. When ANSI mode is off:
@@ -806,6 +821,7 @@
 object TypeCoercion extends TypeCoercionBase {
 
   override def typeCoercionRules: List[Rule[LogicalPlan]] =
+    UnpivotCoercion ::
     WidenSetOperationTypes ::
     new CombinedTypeCoercionRule(
       InConversion ::
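`UnpivotCoercion` asks `findWiderTypeWithoutStringPromotion` for a single type all value columns can be cast to, and wraps each value in `Alias(Cast(...), name)` when one exists. A toy restatement of that decision, with a hypothetical two-type numeric lattice standing in for Spark's full coercion rules:

```scala
// Toy model of UnpivotCoercion's decision: find one wider type for all value
// columns, or None — in which case the plan is left alone for CheckAnalysis
// to reject. Only Int <: Long is modeled; strings are never promoted.
sealed trait ToyType
case object IntT extends ToyType
case object LongT extends ToyType
case object StringT extends ToyType

def widerTypeWithoutStringPromotion(types: Seq[ToyType]): Option[ToyType] =
  if (types.nonEmpty && types.forall(t => t == IntT || t == LongT))
    Some(if (types.contains(LongT)) LongT else IntT)
  else if (types.nonEmpty && types.forall(_ == StringT)) Some(StringT)
  else None

assert(widerTypeWithoutStringPromotion(Seq(IntT, LongT)) == Some(LongT)) // cast ints to long
assert(widerTypeWithoutStringPromotion(Seq(IntT, StringT)).isEmpty)      // mismatch error
```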

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 39 additions & 0 deletions
@@ -1354,6 +1354,45 @@ case class Pivot(
   override protected def withNewChildInternal(newChild: LogicalPlan): Pivot = copy(child = newChild)
 }
 
+/**
+ * A constructor for creating an Unpivot, which will later be converted to an [[Expand]]
+ * during the query analysis.
+ *
+ * An empty values array will be replaced during analysis with all resolved outputs of child
+ * except the ids. This expansion makes it easy to unpivot all non-id columns.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveUnpivot`
+ *
+ * The type of the value column is derived from all value columns during analysis once all values
+ * are resolved. All values' types have to be compatible, otherwise the result value column cannot
+ * be assigned the individual values and an AnalysisException is thrown.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.TypeCoercionBase.UnpivotCoercion`
+ *
+ * @param ids Id columns
+ * @param values Value columns to unpivot
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ * @param child Child operator
+ */
+case class Unpivot(
+    ids: Seq[NamedExpression],
+    values: Seq[NamedExpression],
+    variableColumnName: String,
+    valueColumnName: String,
+    child: LogicalPlan) extends UnaryNode {
+  override lazy val resolved = false  // Unpivot will be replaced after being resolved.
+  override def output: Seq[Attribute] = Nil
+  override def metadataOutput: Seq[Attribute] = Nil
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNPIVOT)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Unpivot =
+    copy(child = newChild)
+
+  def valuesTypeCoercioned: Boolean = values.nonEmpty && values.forall(_.resolved) &&
+    values.tail.forall(v => v.dataType.sameType(values.head.dataType))
+}
+
 /**
  * A constructor for creating a logical limit, which is split into two separate logical nodes:
  * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]].
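`valuesTypeCoercioned` is the handshake between the rules above: `UnpivotCoercion` casts until it holds, `ResolveUnpivot` only rewrites to `Expand` once it holds, and `CheckAnalysis` rejects the plan if it never does. Its core reduces to a one-liner over the value types; a standalone restatement (approximating `DataType.sameType` with plain equality):

```scala
// Standalone restatement of the valuesTypeCoercioned predicate over plain type
// tags; the real method additionally requires all value expressions be resolved.
def valuesTypeCoercioned[A](valueTypes: Seq[A]): Boolean =
  valueTypes.nonEmpty && valueTypes.tail.forall(_ == valueTypes.head)

assert(valuesTypeCoercioned(Seq("LongType", "LongType")))
assert(!valuesTypeCoercioned(Seq("IntegerType", "StringType")))
assert(!valuesTypeCoercioned(Seq.empty[String]))
```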

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTables" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTempViews" ::
+      "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUnpivot" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUserSpecifiedColumns" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ object TreePattern extends Enumeration {
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
   val UNARY_POSITIVE: Value = Value
+  val UNPIVOT: Value = Value
   val UPDATE_FIELDS: Value = Value
   val UPPER_OR_LOWER: Value = Value
   val UP_CAST: Value = Value

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

Lines changed: 18 additions & 0 deletions
@@ -92,6 +92,24 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
       pivotVal.toString, pivotVal.dataType.simpleString, pivotCol.dataType.catalogString))
   }
 
+  def unpivotRequiresValueColumns(): Throwable = {
+    new AnalysisException(
+      errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+      messageParameters = Array.empty)
+  }
+
+  def unpivotValDataTypeMismatchError(values: Seq[NamedExpression]): Throwable = {
+    val dataTypes = values
+      .groupBy(_.dataType)
+      .mapValues(values => values.map(value => toSQLId(value.toString)))
+      .mapValues(values => if (values.length > 3) values.take(3) :+ "..." else values)
+      .map { case (dataType, values) => s"${toSQLType(dataType)} (${values.mkString(", ")})" }
+
+    new AnalysisException(
+      errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+      messageParameters = Array(dataTypes.mkString(", ")))
+  }
+
   def unsupportedIfNotExistsError(tableName: String): Throwable = {
     new AnalysisException(
       errorClass = "UNSUPPORTED_FEATURE",

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 137 additions & 1 deletion
@@ -1065,7 +1065,7 @@
    * @param joinType Type of join to perform. Default `inner`. Must be one of:
    *                 `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
    *                 `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
-   *                 `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, left_anti`.
+   *                 `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
    *
    * @note If you perform a self-join using this function without aliasing the input
    * `DataFrame`s, you will NOT be able to reference any columns after the join, since
@@ -2036,6 +2036,142 @@
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)
 
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed.
+   *
+   * This function is useful to massage a DataFrame into a format where some
+   * columns are identifier columns ("ids"), while all other columns ("values")
+   * are "unpivoted" to the rows, leaving just two non-id columns, named as given
+   * by `variableColumnName` and `valueColumnName`.
+   *
+   * {{{
+   *   val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
+   *   df.show()
+   *   // output:
+   *   // +---+---+----+
+   *   // | id|int|long|
+   *   // +---+---+----+
+   *   // |  1| 11|  12|
+   *   // |  2| 21|  22|
+   *   // +---+---+----+
+   *
+   *   df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
+   *   // output:
+   *   // +---+--------+-----+
+   *   // | id|variable|value|
+   *   // +---+--------+-----+
+   *   // |  1|     int|   11|
+   *   // |  1|    long|   12|
+   *   // |  2|     int|   21|
+   *   // |  2|    long|   22|
+   *   // +---+--------+-----+
+   *   // schema:
+   *   // root
+   *   //  |-- id: integer (nullable = false)
+   *   //  |-- variable: string (nullable = false)
+   *   //  |-- value: long (nullable = true)
+   * }}}
+   *
+   * When no "id" columns are given, the unpivoted DataFrame consists of only the
+   * "variable" and "value" columns.
+   *
+   * All "value" columns must share a least common data type. Unless they are the same data type,
+   * all "value" columns are cast to the nearest common data type. For instance, types
+   * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
+   * do not have a common data type and `unpivot` fails.
+   *
+   * @param ids Id columns
+   * @param values Value columns to unpivot
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def unpivot(
+      ids: Array[Column],
+      values: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame = withPlan {
+    Unpivot(
+      ids.map(_.named),
+      values.map(_.named),
+      variableColumnName,
+      valueColumnName,
+      logicalPlan
+    )
+  }
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+   * where `values` is set to all non-id columns that exist in the DataFrame.
+   *
+   * @param ids Id columns
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def unpivot(
+      ids: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, Array.empty, variableColumnName, valueColumnName)
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed. This is an alias for `unpivot`.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * @param ids Id columns
+   * @param values Value columns to unpivot
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def melt(
+      ids: Array[Column],
+      values: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, values, variableColumnName, valueColumnName)
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed. This is an alias for `unpivot`.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+   * where `values` is set to all non-id columns that exist in the DataFrame.
+   *
+   * @param ids Id columns
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def melt(
+      ids: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, variableColumnName, valueColumnName)
+
   /**
    * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
    * that returns the same result as the input, with the following guarantees:
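Taken together, the four overloads relate as follows; a brief sketch (assuming `spark.implicits._` is in scope, as in the scaladoc example above):

```scala
val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")

// Explicit value columns:
df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value")

// Value columns omitted (an empty array is passed through): analysis
// substitutes all non-id columns, here "int" and "long".
df.unpivot(Array($"id"), "variable", "value")

// melt is a pure alias for unpivot, matching the pandas name:
df.melt(Array($"id"), Array($"int", $"long"), "variable", "value")
df.melt(Array($"id"), "variable", "value")
```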
