diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9965cd654bcb..b2e99b325f93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -529,6 +529,10 @@ class Analyzer(
         || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
         || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
       case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
+        if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
+          throw new AnalysisException(
+            s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.")
+        }
         // Check all aggregate expressions.
         aggregates.foreach(checkValidAggregateExpression)
         // Check all pivot values are literal and match pivot column data type.
@@ -574,10 +578,14 @@ class Analyzer(
         // Since evaluating |pivotValues| if statements for each input row can get slow this is an
         // alternate plan that instead uses two steps of aggregation.
         val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
-        val bigGroup = groupByExprs ++ pivotColumn.references
+        val namedPivotCol = pivotColumn match {
+          case n: NamedExpression => n
+          case _ => Alias(pivotColumn, "__pivot_col")()
+        }
+        val bigGroup = groupByExprs :+ namedPivotCol
         val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
         val pivotAggs = namedAggExps.map { a =>
-          Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues)
+          Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
             .toAggregateExpression()
           , "__pivot_" + a.sql)()
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 523714869242..33bc5b5821b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -17,11 +17,11 @@

 package org.apache.spark.sql.catalyst.expressions.aggregate

-import scala.collection.immutable.HashMap
+import scala.collection.immutable.{HashMap, TreeMap}

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._

 object PivotFirst {
@@ -83,7 +83,12 @@ case class PivotFirst(

   override val dataType: DataType = ArrayType(valueDataType)

-  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
+  val pivotIndex = if (pivotColumn.dataType.isInstanceOf[AtomicType]) {
+    HashMap(pivotColumnValues.zipWithIndex: _*)
+  } else {
+    TreeMap(pivotColumnValues.zipWithIndex: _*)(
+      TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
+  }

   val indexSize = pivotIndex.size
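
A minimal, standalone sketch of the index choice made in PivotFirst above, using plain Scala collections rather than Catalyst internals: atomic pivot column types keep the hash-based index, while other orderable types fall back to a comparison-based TreeMap. The `lexicographic` ordering here is only a stand-in for what TypeUtils.getInterpretedOrdering derives from the column's data type.

import scala.collection.immutable.{HashMap, TreeMap}

object PivotIndexSketch {
  // Stand-in for the interpreted ordering Spark derives from the data type.
  val lexicographic: Ordering[Seq[Int]] = new Ordering[Seq[Int]] {
    def compare(x: Seq[Int], y: Seq[Int]): Int =
      x.zip(y).map { case (a, b) => a.compare(b) }
        .find(_ != 0)
        .getOrElse(x.length.compare(y.length))
  }

  def main(args: Array[String]): Unit = {
    val pivotValues: Seq[Seq[Int]] = Seq(Seq(1, 1), Seq(2, 2))

    // Hash-based index, as used for atomic pivot column types.
    val hashIndex = HashMap(pivotValues.zipWithIndex: _*)
    // Ordering-based index, as used for complex but orderable types.
    val treeIndex = TreeMap(pivotValues.zipWithIndex: _*)(lexicographic)

    println(hashIndex.get(Seq(1, 1))) // Some(0)
    println(treeIndex.get(Seq(2, 2))) // Some(1)
  }
}

Lookups behave the same in both maps; the point of the TreeMap branch is that it only requires an Ordering for the key type, which Catalyst can always supply for orderable (non-map) types.
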
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
index a6c8d4854ff3..1f607b334dc1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
@@ -11,10 +11,10 @@ create temporary view years as select * from values
   (2013, 2)
   as years(y, s);

-create temporary view yearsWithArray as select * from values
-  (2012, array(1, 1)),
-  (2013, array(2, 2))
-  as yearsWithArray(y, a);
+create temporary view yearsWithComplexTypes as select * from values
+  (2012, array(1, 1), map('1', 1), struct(1, 'a')),
+  (2013, array(2, 2), map('2', 2), struct(2, 'b'))
+  as yearsWithComplexTypes(y, a, m, s);

 -- pivot courses
 SELECT * FROM (
@@ -204,7 +204,7 @@ PIVOT (
 SELECT * FROM (
   SELECT course, year, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   min(a)
@@ -215,9 +215,75 @@ PIVOT (
 SELECT * FROM (
   SELECT course, year, y, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   max(a)
   FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java'))
 );
+
+-- pivot on pivot column of array type
+SELECT * FROM (
+  SELECT earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR a IN (array(1, 1), array(2, 2))
+);
+
+-- pivot on multiple pivot columns containing array type
+SELECT * FROM (
+  SELECT course, earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
+);
+
+-- pivot on pivot column of struct type
+SELECT * FROM (
+  SELECT earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR s IN ((1, 'a'), (2, 'b'))
+);
+
+-- pivot on multiple pivot columns containing struct type
+SELECT * FROM (
+  SELECT course, earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
+);
+
+-- pivot on pivot column of map type
+SELECT * FROM (
+  SELECT earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR m IN (map('1', 1), map('2', 2))
+);
+
+-- pivot on multiple pivot columns containing map type
+SELECT * FROM (
+  SELECT course, earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
+);
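
As a usage sketch only (assuming a Spark build with this patch, a spark-shell session, and that the courseSales and yearsWithComplexTypes temporary views from pivot.sql above have been created), the first new array-typed pivot test can be reproduced interactively:

// Hypothetical spark-shell reproduction of the array-typed pivot test above.
spark.sql(
  """
    |SELECT * FROM (
    |  SELECT earnings, year, a
    |  FROM courseSales
    |  JOIN yearsWithComplexTypes ON year = y
    |)
    |PIVOT (
    |  sum(earnings)
    |  FOR a IN (array(1, 1), array(2, 2))
    |)
  """.stripMargin).show()
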
diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
index 6bb51b946f96..2dd92930f92a 100644
--- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 25
+-- Number of queries: 31


 -- !query 0
@@ -28,10 +28,10 @@ struct<>


 -- !query 2
-create temporary view yearsWithArray as select * from values
-  (2012, array(1, 1)),
-  (2013, array(2, 2))
-  as yearsWithArray(y, a)
+create temporary view yearsWithComplexTypes as select * from values
+  (2012, array(1, 1), map('1', 1), struct(1, 'a')),
+  (2013, array(2, 2), map('2', 2), struct(2, 'b'))
+  as yearsWithComplexTypes(y, a, m, s)
 -- !query 2 schema
 struct<>
 -- !query 2 output
@@ -346,7 +346,7 @@ Literal expressions required for pivot values, found 'course#x';
 SELECT * FROM (
   SELECT course, year, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   min(a)
@@ -363,7 +363,7 @@ struct<year:int,dotNET:array<int>,Java:array<int>>
 SELECT * FROM (
   SELECT course, year, y, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   max(a)
@@ -374,3 +374,105 @@ struct<year:int,[2012, dotNET]:array<int>,[2013, Java]:array<int>>
 -- !query 24 output
 2012 [1,1] NULL
 2013 NULL [2,2]
+
+
+-- !query 25
+SELECT * FROM (
+  SELECT earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR a IN (array(1, 1), array(2, 2))
+)
+-- !query 25 schema
+struct<year:int,[1, 1]:bigint,[2, 2]:bigint>
+-- !query 25 output
+2012 35000 NULL
+2013 NULL 78000
+
+
+-- !query 26
+SELECT * FROM (
+  SELECT course, earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
+)
+-- !query 26 schema
+struct<year:int,[dotNET, [1, 1]]:bigint,[Java, [2, 2]]:bigint>
+-- !query 26 output
+2012 15000 NULL
+2013 NULL 30000
+
+
+-- !query 27
+SELECT * FROM (
+  SELECT earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR s IN ((1, 'a'), (2, 'b'))
+)
+-- !query 27 schema
+struct<year:int,[1, a]:bigint,[2, b]:bigint>
+-- !query 27 output
+2012 35000 NULL
+2013 NULL 78000
+
+
+-- !query 28
+SELECT * FROM (
+  SELECT course, earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
+)
+-- !query 28 schema
+struct<year:int,[dotNET, [1, a]]:bigint,[Java, [2, b]]:bigint>
+-- !query 28 output
+2012 15000 NULL
+2013 NULL 30000
+
+
+-- !query 29
+SELECT * FROM (
+  SELECT earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR m IN (map('1', 1), map('2', 2))
+)
+-- !query 29 schema
+struct<>
+-- !query 29 output
+org.apache.spark.sql.AnalysisException
+Invalid pivot column 'm#x'. Pivot columns must be comparable.;
+
+
+-- !query 30
+SELECT * FROM (
+  SELECT course, earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
+)
+-- !query 30 schema
+struct<>
+-- !query 30 output
+org.apache.spark.sql.AnalysisException
+Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.;
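
For reference, the analysis errors for the two map-typed pivot columns above come from the RowOrdering.isOrderable guard added in Analyzer.scala. A simplified restatement of that rule over Spark's public DataType classes (an approximation for illustration, not Spark's actual implementation) is:

import org.apache.spark.sql.types._

object OrderableSketch {
  // Approximation of the orderability rule the new pivot check relies on:
  // maps are never comparable; arrays and structs are comparable when their
  // element/field types are; atomic types (int, string, ...) are comparable.
  def isOrderable(dataType: DataType): Boolean = dataType match {
    case _: MapType => false
    case a: ArrayType => isOrderable(a.elementType)
    case s: StructType => s.fields.forall(f => isOrderable(f.dataType))
    case _ => true
  }
}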