[SQL][WIP][Test] Supports object-based aggregation function which can store arbitrary objects in aggregation buffer. #14723
Changes from 1 commit
@@ -389,3 +389,89 @@ abstract class DeclarativeAggregate
    def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
  }
}
/**
 * This trait allows users to define an AggregateFunction that can store **arbitrary** Java
 * objects in the aggregation buffer during the aggregation of each key group. This trait must
 * be mixed in with class ImperativeAggregate.
 *
 * Here is how it works in a typical aggregation flow (Partial mode aggregate at the Mapper
 * side, and Final mode aggregate at the Reducer side).
 *
 * Stage 1: Partial aggregate at Mapper side:
 *
 * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores an arbitrary
 *    empty object, object A for example, in aggBuffer. Object A will be used to store the
 *    accumulated aggregation result.
 * 1. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` for the
 *    current group (group by key), the user extracts object A from mutableAggBuffer, updates
 *    object A with the current inputRow, and stores object A back to mutableAggBuffer.
 * 1. After processing all rows of the current group, the framework calls method
 *    `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to serialize
 *    object A to a serializable format in place.
 * 1. The framework may spill the aggregationBuffer to disk if there is not enough memory.
 *    This is safe since we have already converted the aggregationBuffer to a serializable
 *    format.
 * 1. The framework moves on to the next group, until all groups have been processed.
 *
 * Shuffling exchanges data to Reducer tasks...
 *
 * Stage 2: Final mode aggregate at Reducer side:
 *
 * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores a new empty
 *    object A1 in aggBuffer. Object A1 will be used to store the accumulated aggregation
 *    result.
 * 1. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`,
 *    the user extracts object A1 from mutableAggBuffer and object A2 from inputAggBuffer,
 *    merges A1 and A2, and stores the merged result back to mutableAggBuffer.
 * 1. After processing all inputAggBuffers of the current group (group by key), the framework
 *    calls method `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to
 *    serialize object A1 to a serializable format in place.
 * 1. The framework may spill the aggregationBuffer to disk if there is not enough memory.
 *    This is safe since we have already converted the aggregationBuffer to a serializable
 *    format.
 * 1. The framework moves on to the next group, until all groups have been processed.
 */
trait WithObjectAggregateBuffer {
  this: ImperativeAggregate =>
Contributor: Seems we do not really need this line.

Contributor: I guess having this line will make this trait hard to be used in Java.

Contributor: oh, seems this trait will be still a Java …
  /**
   * Serializes and in-place replaces the object stored in the aggregation buffer. The framework
   * calls this method every time after finishing updating/merging one group (group by key).
   *
   * aggregationBuffer before serialization:
   *
   * The object stored in aggregationBuffer can be an **arbitrary** user-defined Java object.
   *
   * aggregationBuffer after serialization:
   *
   * The object's type must be one of:
   *
   * - Null
   * - Boolean
   * - Byte
   * - Short
   * - Int
   * - Long
   * - Float
   * - Double
   * - Array[Byte]
   * - org.apache.spark.sql.types.Decimal
   * - org.apache.spark.unsafe.types.UTF8String
   * - org.apache.spark.unsafe.types.CalendarInterval
   * - org.apache.spark.sql.catalyst.util.MapData
   * - org.apache.spark.sql.catalyst.util.ArrayData
   * - org.apache.spark.sql.catalyst.InternalRow
   *
   * Code example:
   *
   * {{{
   *   override def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit = {
   *     val obj = aggregationBuffer.get(mutableAggBufferOffset, ObjectType(classOf[A])).asInstanceOf[A]
   *     // Converts obj to bytes, which is a serializable format.
   *     aggregationBuffer(mutableAggBufferOffset) = toBytes(obj)
   *   }
   * }}}
   *
   * @param aggregationBuffer aggregation buffer before serialization
   */
  def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit
}
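To make the flow described in the Scaladoc above concrete, here is a minimal sketch (not part of this patch) of the call sequence the framework performs for one group on the map side. The names `agg`, `rows`, and `buffer` are illustrative; assume `agg: ImperativeAggregate with WithObjectAggregateBuffer` and `rows: Iterator[InternalRow]`.

    // Stage 1 (Partial mode) for a single group, following steps 1-5 above:
    val buffer = new GenericMutableRow(agg.aggBufferSchema.length)
    agg.initialize(buffer)                               // step 1: store an empty object A
    rows.foreach(row => agg.update(buffer, row))         // step 2: mutate A with each input row
    agg.serializeObjectAggregationBufferInPlace(buffer)  // step 3: replace A with a serializable value
    // Steps 4-5: the buffer may now be safely spilled; the framework moves to the next group.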
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, WithObjectAggregateBuffer}
 import org.apache.spark.sql.execution.metric.SQLMetric

 /**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
   val bufferRowSize: Int = bufferSchema.length

   val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-  val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+  val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+  val isUsingObjectAggregationBuffer = aggregateFunctions.exists {
+    case agg: WithObjectAggregateBuffer => true
+    case _ => false
+  }
+
+  val useUnsafeBuffer = allFieldsMutable && !isUsingObjectAggregationBuffer

   val buffer = if (useUnsafeBuffer) {
     val unsafeProjection =
@@ -90,6 +98,21 @@ class SortBasedAggregationIterator(
   // compared to MutableRow (aggregation buffer) directly.
   private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

+  // Aggregate functions that store a generic object in the aggregation buffer.
+  // @see [[WithObjectAggregateBuffer]] for more information.
+  private val aggFunctionsWithObjectAggregationBuffer = aggregateFunctions.collect {
+    case ag: ImperativeAggregate with WithObjectAggregateBuffer => ag
+  }

Contributor: how about we make …

Contributor (Author): ImperativeAggregate is an abstract class, that will make …

Contributor: Heavy? In what sense?

+  // For aggregate functions with a generic object stored in the aggregation buffer, we need to
+  // call serializeObjectAggregationBufferInPlace() explicitly to convert the generic object
+  // stored in the aggregation buffer to a serializable format.
+  private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
+    aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
+      agg.serializeObjectAggregationBufferInPlace(aggregationBuffer)
+    }
+  }
+
   protected def initialize(): Unit = {
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +154,10 @@ class SortBasedAggregationIterator(
         firstRowInNextGroup = currentRow.copy()
       }
     }

+    // Serializes the generic object stored in the aggregation buffer.
+    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
+
     // We have not seen a new group. It means that there is no new row in the input
     // iter. The current group is the last group of the iter.
     if (!findNextPartition) {

@@ -162,6 +189,8 @@ class SortBasedAggregationIterator(

   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
+    // Serializes the generic object stored in the aggregation buffer.
+    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
     generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }
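A note on the `useUnsafeBuffer` logic above: an UnsafeRow lays its fields out in a flat byte region, so it can only hold the fixed-width and binary types for which `UnsafeRow.isMutable` returns true; it has no slot for an arbitrary JVM reference. A GenericMutableRow backs its fields with an `Array[Any]`, which is why the iterator falls back to it whenever any aggregate mixes in WithObjectAggregateBuffer. A standalone illustration (not part of the diff):

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

    // A GenericMutableRow field slot can hold any JVM reference:
    val generic = new GenericMutableRow(1)
    generic.update(0, new StringBuilder("arbitrary object"))  // fine
    // An UnsafeRow offers no such update path for object references, hence
    // useUnsafeBuffer additionally requires !isUsingObjectAggregationBuffer.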
@@ -0,0 +1,156 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.AggregateWithObjectAggregateBufferSuite.MaxWithObjectAggregateBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericMutableRow, MutableRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StructType}

class AggregateWithObjectAggregateBufferSuite extends QueryTest with SharedSQLContext {
Contributor: We should also put a basic test in …

Contributor (Author): We will not use HashAggregationExec, so there is no point to fall back from HashAggregationExec?

Contributor: oh right, I misread the code.
  import testImplicits._

  private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))

  test("aggregate with object aggregate buffer, should not use HashAggregate") {
    val df = data.toDF("a", "b")
    val max = new MaxWithObjectAggregateBuffer($"a".expr)

    // Always uses SortAggregateExec instead of HashAggregateExec for planning, even if all
    // aggregate buffer attributes are mutable fields (every field can be mutated inline,
    // like int, long, ...).
    val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
    val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
    assert(allFieldsMutable && sparkPlan.isInstanceOf[SortAggregateExec])
  }
| test("aggregate with object aggregate buffer, no group by") { | ||
| val df = data.toDF("a", "b").coalesce(2) | ||
| checkAnswer( | ||
| df.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")), | ||
| Seq(Row(6, 7, 3, 7)) | ||
| ) | ||
| } | ||
|
|
||
| test("aggregate with object aggregate buffer, with group by") { | ||
| val df = data.toDF("a", "b").coalesce(2) | ||
| checkAnswer( | ||
| df.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")), | ||
| Seq( | ||
| Row(0, 5, 3, 5), | ||
| Row(1, 4, 3, 4), | ||
| Row(3, 6, 1, 6) | ||
| ) | ||
| ) | ||
| } | ||
|
|
||
| test("aggregate with object aggregate buffer, empty inputs, no group by") { | ||
| val empty = Seq.empty[(Int, Int)].toDF("a", "b") | ||
| checkAnswer( | ||
| empty.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")), | ||
| Seq(Row(Int.MinValue, 0, Int.MinValue, 0))) | ||
| } | ||
|
|
||
| test("aggregate with object aggregate buffer, empty inputs, with group by") { | ||
| val empty = Seq.empty[(Int, Int)].toDF("a", "b") | ||
| checkAnswer( | ||
| empty.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")), | ||
| Seq.empty[Row]) | ||
| } | ||
|
|
||
| private def objectAggregateMax(column: Column): Column = { | ||
| val max = MaxWithObjectAggregateBuffer(column.expr) | ||
| Column(max.toAggregateExpression()) | ||
| } | ||
| } | ||
|
|
||
object AggregateWithObjectAggregateBufferSuite {

Contributor: (we do not need to put the example class inside this object.)

Contributor (Author): I use the companion object to define a private scope.
  /**
   * Calculates the max value with an object aggregation buffer. This stores an object of class
   * MaxValue in the aggregation buffer.
   */
  private case class MaxWithObjectAggregateBuffer(
      child: Expression,
      mutableAggBufferOffset: Int = 0,
      inputAggBufferOffset: Int = 0) extends ImperativeAggregate with WithObjectAggregateBuffer {

    override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
      copy(mutableAggBufferOffset = newOffset)

    override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
      copy(inputAggBufferOffset = newOffset)

    // Stores a generic object MaxValue in the aggregation buffer.
    override def initialize(buffer: MutableRow): Unit = {
      // Makes sure we are NOT using an UnsafeRow as the aggregation buffer, since it cannot
      // hold an arbitrary object.
      assert(buffer.isInstanceOf[GenericMutableRow])
      buffer.update(mutableAggBufferOffset, new MaxValue(Int.MinValue))
    }

    override def update(buffer: MutableRow, input: InternalRow): Unit = {
      val inputValue = child.eval(input).asInstanceOf[Int]
      val maxValue = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      if (inputValue > maxValue.value) {
        maxValue.value = inputValue
      }
    }

    override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      val inputMax = deserialize(inputBuffer, inputAggBufferOffset)
      if (inputMax.value > bufferMax.value) {
        bufferMax.value = inputMax.value
      }
    }

    private def deserialize(buffer: InternalRow, offset: Int): MaxValue = {
      new MaxValue(buffer.getInt(offset))
    }

    override def serializeObjectAggregationBufferInPlace(buffer: MutableRow): Unit = {
      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      buffer(mutableAggBufferOffset) = bufferMax.value
    }

    override def eval(buffer: InternalRow): Any = {
      val max = deserialize(buffer, mutableAggBufferOffset)
      max.value
    }

    override val aggBufferAttributes: Seq[AttributeReference] =
      Seq(AttributeReference("buf", IntegerType)())

    override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
      aggBufferAttributes.map(_.newInstance())

    override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
    override def dataType: DataType = IntegerType
    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
    override def nullable: Boolean = true
    override def deterministic: Boolean = false
    override def children: Seq[Expression] = Seq(child)
  }

  private class MaxValue(var value: Int)
}
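As a hedged illustration of the example above (the Literal child and the use of InternalRow.empty are for demonstration only, not part of the patch), a single buffer slot in MaxWithObjectAggregateBuffer holds a MaxValue reference while updating and a plain Int after serialization:

    import org.apache.spark.sql.catalyst.expressions.Literal

    val agg = MaxWithObjectAggregateBuffer(Literal(42))
    val buffer = new GenericMutableRow(1)
    agg.initialize(buffer)                               // slot 0 holds MaxValue(Int.MinValue)
    agg.update(buffer, InternalRow.empty)                // Literal(42) ignores the row; MaxValue(42)
    agg.serializeObjectAggregationBufferInPlace(buffer)  // slot 0 now holds the Int 42
    assert(agg.eval(buffer) == 42)                       // eval deserializes the Int back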
Contributor: This trait allows an AggregateFunction to use …

Contributor: I think it is better to remove "allow users" because it is not exposed to end-users for defining UDAFs.