Skip to content

Commit d57164a

Browse files
stevomitric and stefankandic
authored and committed
[SPARK-47430][SQL] Support GROUP BY for MapType
### What changes were proposed in this pull request? Changes proposed in this PR include: - Relaxed checks that prevent aggregating of map types - Added new analyzer rule that uses `MapSort` expression proposed in [this PR](#45639) - Created codegen that compares two sorted maps ### Why are the changes needed? Adding new functionality to GROUP BY map types ### Does this PR introduce _any_ user-facing change? Yes, ability to use `GROUP BY MapType` ### How was this patch tested? With new UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #45549 from stevomitric/stevomitric/map-group-by. Lead-authored-by: Stevo Mitric <stevo.mitric@databricks.com> Co-authored-by: Stefan Kandic <stefan.kandic@databricks.com> Co-authored-by: Stevo Mitric <stevomitric2000@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent b540cc5 commit d57164a

12 files changed

Lines changed: 161 additions & 157 deletions

File tree

common/utils/src/main/resources/error/error-classes.json

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,12 +1373,6 @@
13731373
],
13741374
"sqlState" : "42805"
13751375
},
1376-
"GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
1377-
"message" : [
1378-
"The expression <sqlExpr> cannot be used as a grouping expression because its data type <dataType> is not an orderable data type."
1379-
],
1380-
"sqlState" : "42822"
1381-
},
13821376
"HLL_INVALID_INPUT_SKETCH_BUFFER" : {
13831377
"message" : [
13841378
"Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."

docs/sql-error-conditions.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -852,12 +852,6 @@ GROUP BY `<index>` refers to an expression `<aggExpr>` that contains an aggregat
852852

853853
GROUP BY position `<index>` is not in select list (valid range is [1, `<size>`]).
854854

855-
### GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE
856-
857-
[SQLSTATE: 42822](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
858-
859-
The expression `<sqlExpr>` cannot be used as a grouping expression because its data type `<dataType>` is not an orderable data type.
860-
861855
### HLL_INVALID_INPUT_SKETCH_BUFFER
862856

863857
[SQLSTATE: 22546](sql-error-conditions-sqlstates.html#class-22-data-exception)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,6 @@ object ExprUtils extends QueryErrorsBase {
193193
messageParameters = Map("sqlExpr" -> expr.sql))
194194
}
195195

196-
// Check if the data type of expr is orderable.
197-
if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
198-
expr.failAnalysis(
199-
errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
200-
messageParameters = Map(
201-
"sqlExpr" -> toSQLExpr(expr),
202-
"dataType" -> toSQLType(expr.dataType)))
203-
}
204-
205196
if (!expr.deterministic) {
206197
// This is just a sanity check, our analysis rule PullOutNondeterministic should
207198
// already pull out those nondeterministic expressions and evaluate them in

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -660,13 +660,8 @@ class CodegenContext extends Logging {
660660
case NullType => "0"
661661
case array: ArrayType =>
662662
val elementType = array.elementType
663-
val elementA = freshName("elementA")
664-
val isNullA = freshName("isNullA")
665-
val elementB = freshName("elementB")
666-
val isNullB = freshName("isNullB")
667663
val compareFunc = freshName("compareArray")
668664
val minLength = freshName("minLength")
669-
val jt = javaType(elementType)
670665
val funcCode: String =
671666
s"""
672667
public int $compareFunc(ArrayData a, ArrayData b) {
@@ -679,22 +674,7 @@ class CodegenContext extends Logging {
679674
int lengthB = b.numElements();
680675
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
681676
for (int i = 0; i < $minLength; i++) {
682-
boolean $isNullA = a.isNullAt(i);
683-
boolean $isNullB = b.isNullAt(i);
684-
if ($isNullA && $isNullB) {
685-
// Nothing
686-
} else if ($isNullA) {
687-
return -1;
688-
} else if ($isNullB) {
689-
return 1;
690-
} else {
691-
$jt $elementA = ${getValue("a", elementType, "i")};
692-
$jt $elementB = ${getValue("b", elementType, "i")};
693-
int comp = ${genComp(elementType, elementA, elementB)};
694-
if (comp != 0) {
695-
return comp;
696-
}
697-
}
677+
${genCompElementsAt("a", "b", "i", elementType)}
698678
}
699679

700680
if (lengthA < lengthB) {
@@ -722,12 +702,71 @@ class CodegenContext extends Logging {
722702
}
723703
"""
724704
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
705+
case map: MapType =>
706+
val compareFunc = freshName("compareMapData")
707+
val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc)
708+
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
725709
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
726710
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
727711
case _ =>
728712
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType)
729713
}
730714

715+
/**
 * Generates a Java helper function that compares two `MapData` values element by element.
 *
 * The emitted comparator walks both maps in parallel up to the shorter length, comparing the
 * key at each index first and then the value (each via [[genCompElementsAt]], whose emitted
 * code `return`s from this function on the first difference). If all shared positions are
 * equal, the shorter map sorts first; equal lengths compare as 0.
 *
 * NOTE(review): this is an index-wise comparison, so it is only a consistent ordering when
 * both maps have their entries in a canonical (sorted-by-key) order — presumably guaranteed
 * by the `InsertMapSortInGroupingExpressions` rule wrapping grouping keys in `MapSort`;
 * confirm at call sites.
 *
 * @param keyType     data type of the map keys
 * @param valueType   data type of the map values
 * @param compareFunc name to give the generated Java function
 * @return Java source for a function `int compareFunc(MapData a, MapData b)`
 */
private def genCompMapData(
    keyType: DataType,
    valueType: DataType,
    compareFunc: String): String = {
  s"""
     |public int $compareFunc(MapData a, MapData b) {
     |  int lengthA = a.numElements();
     |  int lengthB = b.numElements();
     |  ArrayData keyArrayA = a.keyArray();
     |  ArrayData valueArrayA = a.valueArray();
     |  ArrayData keyArrayB = b.keyArray();
     |  ArrayData valueArrayB = b.valueArray();
     |  int minLength = (lengthA > lengthB) ? lengthB : lengthA;
     |  for (int i = 0; i < minLength; i++) {
     |    ${genCompElementsAt("keyArrayA", "keyArrayB", "i", keyType)}
     |    ${genCompElementsAt("valueArrayA", "valueArrayB", "i", valueType)}
     |  }
     |
     |  if (lengthA < lengthB) {
     |    return -1;
     |  } else if (lengthA > lengthB) {
     |    return 1;
     |  }
     |  return 0;
     |}
  """.stripMargin
}
742+
743+
/**
 * Generates inline Java code that compares the elements at index `i` of two `ArrayData`
 * expressions, using null-first ordering: two nulls are considered equal (control falls
 * through), a null on either side sorts before a non-null value, and otherwise the elements
 * are compared with [[genComp]] for `elementType`.
 *
 * The emitted snippet `return`s the comparison result from the *enclosing* generated Java
 * function on the first difference, so it must be spliced into a comparator body (see the
 * array and map comparators in this class).
 *
 * @param arrayA      Java expression for the left-hand `ArrayData`
 * @param arrayB      Java expression for the right-hand `ArrayData`
 * @param i           Java expression for the index being compared
 * @param elementType data type of the elements at that index
 * @return Java statements performing the element comparison
 */
private def genCompElementsAt(
    arrayA: String,
    arrayB: String,
    i: String,
    elementType: DataType): String = {
  // Fresh names avoid collisions when this snippet is emitted more than once
  // in the same generated function (e.g. keys and values of a map).
  val elementA = freshName("elementA")
  val isNullA = freshName("isNullA")
  val elementB = freshName("elementB")
  val isNullB = freshName("isNullB")
  val jt = javaType(elementType)
  s"""
     |boolean $isNullA = $arrayA.isNullAt($i);
     |boolean $isNullB = $arrayB.isNullAt($i);
     |if ($isNullA && $isNullB) {
     |  // Nothing
     |} else if ($isNullA) {
     |  return -1;
     |} else if ($isNullB) {
     |  return 1;
     |} else {
     |  $jt $elementA = ${getValue(arrayA, elementType, i)};
     |  $jt $elementB = ${getValue(arrayB, elementType, i)};
     |  int comp = ${genComp(elementType, elementA, elementB)};
     |  if (comp != 0) {
     |    return comp;
     |  }
     |}
  """.stripMargin
}
769+
731770
/**
732771
* Generates code for greater of two expressions.
733772
*
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.expressions.MapSort
21+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
22+
import org.apache.spark.sql.catalyst.rules.Rule
23+
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
24+
import org.apache.spark.sql.types.MapType
25+
26+
/**
 * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
 * in the correct order before grouping:
 *
 *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
 *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
 *
 * This canonicalizes map ordering so that maps with the same entries hash/compare equal
 * during aggregation, regardless of insertion order.
 */
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(AGGREGATE), ruleId) {
    case a @ Aggregate(groupingExpr, _, _) =>
      // Wrap every map-typed grouping expression in MapSort; the isInstanceOf[MapSort]
      // guard makes the rule idempotent across repeated optimizer runs.
      val newGrouping = groupingExpr.map { expr =>
        if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
          MapSort(expr)
        } else {
          expr
        }
      }
      a.copy(groupingExpressions = newGrouping)
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.SparkException
21-
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
21+
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression}
2222
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2323
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2424
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -98,9 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
9898
case FloatType | DoubleType => true
9999
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
100100
case ArrayType(et, _) => needNormalize(et)
101-
// Currently MapType is not comparable and analyzer should fail earlier if this case happens.
102-
case _: MapType =>
103-
throw SparkException.internalError("grouping/join/window partition keys cannot be map type.")
101+
case MapType(_, vt, _) => needNormalize(vt)
104102
case _ => false
105103
}
106104

@@ -144,6 +142,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
144142
val function = normalize(lv)
145143
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
146144

145+
case _ if expr.dataType.isInstanceOf[MapType] =>
146+
val MapType(kt, vt, containsNull) = expr.dataType
147+
val keys = NamedLambdaVariable("arg", kt, containsNull)
148+
val values = NamedLambdaVariable("arg", vt, containsNull)
149+
val function = normalize(values)
150+
KnownFloatingPointNormalized(TransformValues(expr,
151+
LambdaFunction(function, Seq(keys, values))))
152+
147153
case _ => throw SparkException.internalError(s"fail to normalize $expr")
148154
}
149155

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
243243
CollapseProject,
244244
RemoveRedundantAliases,
245245
RemoveNoopOperators) :+
246+
Batch("InsertMapSortInGroupingExpressions", Once,
247+
InsertMapSortInGroupingExpressions) :+
246248
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
247249
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
248250
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
2828
import org.apache.spark.sql.catalyst.types.DataTypeUtils
2929
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
3030
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
31-
import org.apache.spark.sql.types.{MapType, StructType}
31+
import org.apache.spark.sql.types.StructType
3232

3333

3434
abstract class LogicalPlan
@@ -348,23 +348,6 @@ object LogicalPlanIntegrity {
348348
}.flatten
349349
}
350350

351-
/**
352-
* Validate that the grouping key types in Aggregate plans are valid.
353-
* Returns an error message if the check fails, or None if it succeeds.
354-
*/
355-
def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
356-
plan.collectFirst {
357-
case a @ Aggregate(groupingExprs, _, _) =>
358-
val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
359-
if (badExprs.nonEmpty) {
360-
Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
361-
s"for plan:\n ${a.treeString}")
362-
} else {
363-
None
364-
}
365-
}.flatten
366-
}
367-
368351
/**
369352
* Validate that the aggregation expressions in Aggregate plans are valid.
370353
* Returns an error message if the check fails, or None if it succeeds.
@@ -417,7 +400,6 @@ object LogicalPlanIntegrity {
417400
.orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan))
418401
.orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan))
419402
.orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
420-
.orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
421403
.orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
422404
.map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" +
423405
s"\nPrevious plan: ${previousPlan.treeString}")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ object RuleIdCollection {
126126
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
127127
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
128128
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
129+
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
129130
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
130131
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
131132
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::

0 commit comments

Comments
 (0)