Skip to content

Commit 8cdb81f

Browse files
koertkuipers authored and cloud-fan committed
[SPARK-15204][SQL] improve nullability inference for Aggregator
## What changes were proposed in this pull request?

TypedAggregateExpression sets nullable based on the schema of the outputEncoder

## How was this patch tested?

Add test in DatasetAggregatorSuite

Author: Koert Kuipers <[email protected]>

Closes #13532 from koertkuipers/feat-aggregator-nullable.
1 parent 88134e7 commit 8cdb81f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,8 @@ object TypedAggregateExpression {
5151
bufferDeserializer,
5252
outputEncoder.serializer,
5353
outputEncoder.deserializer.dataType,
54-
outputType)
54+
outputType,
55+
!outputEncoder.flat || outputEncoder.schema.head.nullable)
5556
}
5657
}
5758

@@ -65,9 +66,8 @@ case class TypedAggregateExpression(
6566
bufferDeserializer: Expression,
6667
outputSerializer: Seq[Expression],
6768
outputExternalType: DataType,
68-
dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {
69-
70-
override def nullable: Boolean = true
69+
dataType: DataType,
70+
nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression {
7171

7272
override def deterministic: Boolean = true
7373

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -305,4 +305,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
305305
val ds = Seq(1, 2, 3).toDS()
306306
checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1)
307307
}
308+
309+
test("SPARK-15204 improve nullability inference for Aggregator") {
310+
val ds1 = Seq(1, 3, 2, 5).toDS()
311+
assert(ds1.select(typed.sum((i: Int) => i)).schema.head.nullable === false)
312+
val ds2 = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
313+
assert(ds2.select(SeqAgg.toColumn).schema.head.nullable === true)
314+
val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
315+
assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true)
316+
}
308317
}

0 commit comments

Comments
 (0)