@@ -389,3 +389,89 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}

/**
* This traits allow user to define an AggregateFunction which can store **arbitrary** Java objects
Contributor: This trait allows an AggregateFunction to use ...

Contributor: I think it is better to remove "allow user" because it is not exposed to end-users for defining UDAFs.

* in Aggregation buffer during aggregation of each key group. This trait must be mixed with
* class ImperativeAggregate.
Contributor: I think here we need to emphasize that the buffer is an internal buffer, because we will emit this buffer as the result of an aggregate operator.

*
* Here is how it works in a typical aggregation flow (Partial mode aggregate at Mapper side, and
* Final mode aggregate at Reducer side).
*
* Stage 1: Partial aggregate at Mapper side:
*
* 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores an arbitrary empty
* object, object A for example, in aggBuffer. The object A will be used to store the
* accumulated aggregation result.
* 1. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` in
* current group (group by key), user extracts object A from mutableAggBuffer, and then updates
* object A with current inputRow. After updating, object A is stored back to mutableAggBuffer.
* 1. After processing all rows of current group, the framework will call method
* `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to serialize object A
* to a serializable format in place.
Contributor: to a Spark SQL internal format (mostly BinaryType) in place

* 1. The framework may spill the aggregationBuffer to disk if there is not enough memory.
* It is safe since we have already converted aggregationBuffer to a serializable format.
* 1. Spark framework moves on to next group, until all groups have been
Member: So many 1. : )

* processed.
*
* Shuffling exchange data to Reducer tasks...
*
* Stage 2: Final mode aggregate at Reducer side:
*
* 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores a new empty object A1
* in aggBuffer. The object A1 will be used to store the accumulated aggregation result.
Contributor: "accumulated aggregation result"? It's still a buffer, right?

* 1. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, user
* extracts object A1 from mutableAggBuffer, and extracts object A2 from inputAggBuffer. then
Contributor: I think we should extract the Spark SQL format from inputAggBuffer and deserialize it to object A2.

Contributor Author: It is 'merge',

* user needs to merge A1, and A2, and stores the merged result back to mutableAggBuffer.
* 1. After processing all inputAggBuffer of current group (group by key), the Spark framework will
* call method `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to
* serialize object A1 to a serializable format in place.
* 1. The Spark framework may spill the aggregationBuffer to disk if there is not enough memory.
* It is safe since we have already converted aggregationBuffer to a serializable format.
* 1. Spark framework moves on to next group, until all groups have been processed.
*/
trait WithObjectAggregateBuffer {
this: ImperativeAggregate =>
Contributor: Seems we do not really need this line.

Contributor: I guess having this line will make this trait hard to use in Java.

Contributor: oh, seems this trait will still be a Java interface. But I think in general we do not really need to have this line.


/**
* Serializes and in-place replaces the object stored in Aggregation buffer. The framework
* calls this method every time after finishing updating/merging one group (group by key).
*
* aggregationBuffer before serialization:
*
* The object stored in aggregationBuffer can be **arbitrary** Java objects defined by user.
Contributor: Seems we want to mention that the data type is ObjectType?

*
* aggregationBuffer after serialization:
*
* The object's type must be one of:
Contributor: How about we rephrase this part? We mentioned that we can use arbitrary Java objects, but here we are saying that "The object's type must be one of:".

Contributor Author: OK, I will rephrase this part. I meant to say the object type after serialization.

Contributor: how about "The serialized object must be in Spark SQL internal format."

*
* - Null
* - Boolean
* - Byte
* - Short
* - Int
* - Long
* - Float
* - Double
* - Array[Byte]
* - org.apache.spark.sql.types.Decimal
* - org.apache.spark.unsafe.types.UTF8String
* - org.apache.spark.unsafe.types.CalendarInterval
* - org.apache.spark.sql.catalyst.util.MapData
* - org.apache.spark.sql.catalyst.util.ArrayData
* - org.apache.spark.sql.catalyst.InternalRow
*
* Code example:
*
* {{{
* override def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit = {
* val obj = aggregationBuffer.get(mutableAggBufferOffset, ObjectType(classOf[A])).asInstanceOf[A]
* // Convert the obj to bytes, which is a serializable format.
* aggregationBuffer(mutableAggBufferOffset) = toBytes(obj)
Contributor: I am not sure it is the best example. Here, we are showing that the value of a field can be a Java object or a byte array.

I guess a more general question for this method is whether this approach works for all "supported" serialized types (e.g. the serialized type is a primitive type)?

* }
* }}}
*
* @param aggregationBuffer aggregation buffer before serialization
*/
def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit
}
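For illustration only, here is a minimal sketch of the in-place serialization step described above for a state object that is not itself one of the supported types; the DistinctItems class, the serializeInPlace name, and the use of plain Java serialization are hypothetical and not part of this patch. It follows the reviewer's suggestion that the result should be a Spark SQL internal value: after the call, the buffer slot holds an Array[Byte] (BinaryType). The MaxWithObjectAggregateBuffer example later in this diff shows the same step writing back a plain Int instead.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.sql.catalyst.expressions.MutableRow

// Hypothetical state object; any Serializable class could play this role.
class DistinctItems(val items: java.util.HashSet[Long]) extends java.io.Serializable

// Reads the state object out of its buffer slot, serializes it to Array[Byte] (a BinaryType
// value), and writes the bytes back into the same slot, so that afterwards the row holds only
// Spark SQL internal values and can be spilled or emitted as the operator's output.
def serializeInPlace(aggregationBuffer: MutableRow, mutableAggBufferOffset: Int): Unit = {
  val state = aggregationBuffer.get(mutableAggBufferOffset, null).asInstanceOf[DistinctItems]
  val bytesOut = new ByteArrayOutputStream()
  val objOut = new ObjectOutputStream(bytesOut)
  objOut.writeObject(state)
  objOut.close()
  aggregationBuffer(mutableAggBufferOffset) = bytesOut.toByteArray
}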
@@ -54,9 +54,16 @@ object AggUtils {
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(

val isUsingObjectAggregationBuffer: Boolean = aggregateExpressions.exists {
case AggregateExpression(agg: WithObjectAggregateBuffer, _, _, _) => true
case _ => false
}

val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {

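// An arbitrary Java object cannot live in the UnsafeRow-based buffers used by HashAggregateExec,
// so any aggregate mixing in WithObjectAggregateBuffer forces the sort-based path below.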
if (aggBufferAttributesSupportedByHashAggregate && !isUsingObjectAggregationBuffer) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.metric.SQLMetric

/**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
val bufferRowSize: Int = bufferSchema.length

val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val isUsingObjectAggregationBuffer = aggregateFunctions.exists {
case agg: WithObjectAggregateBuffer => true
case _ => false
}

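// Even when every buffer field is mutable in an UnsafeRow, an arbitrary Java object cannot be
// stored in one, so fall back to the GenericMutableRow buffer for object aggregation buffers.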
val useUnsafeBuffer = allFieldsMutable && !isUsingObjectAggregationBuffer

val buffer = if (useUnsafeBuffer) {
val unsafeProjection =
@@ -90,6 +98,21 @@ class SortBasedAggregationIterator(
// compared to MutableRow (aggregation buffer) directly.
private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

// AggregateFunctions which store a generic object in the aggregation buffer.
// @see [[WithObjectAggregateBuffer]] for more information
private val aggFunctionsWithObjectAggregationBuffer = aggregateFunctions.collect {
case (ag: ImperativeAggregate with WithObjectAggregateBuffer) => ag
Contributor: how about we make WithObjectAggregateBuffer extend ImperativeAggregate?

Contributor Author: ImperativeAggregate is an abstract class; that would make WithObjectAggregateBuffer quite heavy.

Contributor: Heavy? In what sense?

}

// For AggregateFunctions that store a generic object in the aggregation buffer, we need to
// call serializeObjectAggregationBufferInPlace() explicitly to convert the generic object
// stored in the aggregation buffer to a serializable format.
private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
agg.serializeObjectAggregationBufferInPlace(aggregationBuffer)
}
}

protected def initialize(): Unit = {
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +154,10 @@ class SortBasedAggregationIterator(
firstRowInNextGroup = currentRow.copy()
}
}

// Serializes the generic object stored in aggregation buffer.
serializeObjectAggregationBuffer(sortBasedAggregationBuffer)

// We have not seen a new group. It means that there is no new row in the input
// iter. The current group is the last group of the iter.
if (!findNextPartition) {
@@ -162,6 +189,8 @@

def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
initializeBuffer(sortBasedAggregationBuffer)
// Serializes the generic object stored in aggregation buffer.
serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
}
}
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.AggregateWithObjectAggregateBufferSuite.MaxWithObjectAggregateBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericMutableRow, MutableRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.aggregate.{SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StructType}

class AggregateWithObjectAggregateBufferSuite extends QueryTest with SharedSQLContext {
Contributor: We should also put a basic test in HashAggregationQueryWithControlledFallbackSuite, to test the fallback.

Contributor Author: We will not use HashAggregateExec, so there is no point in falling back from HashAggregateExec?

Contributor: oh right, I misread the code.


import testImplicits._

private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))


test("aggregate with object aggregate buffer, should not use HashAggregate") {
val df = data.toDF("a", "b")
val max = new MaxWithObjectAggregateBuffer($"a".expr)

// Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate
// buffer attributes are mutable fields (every field can be mutated inline like int, long...)
val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec])
}

test("aggregate with object aggregate buffer, no group by") {
val df = data.toDF("a", "b").coalesce(2)
checkAnswer(
df.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
Seq(Row(6, 7, 3, 7))
)
}

test("aggregate with object aggregate buffer, with group by") {
val df = data.toDF("a", "b").coalesce(2)
checkAnswer(
df.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
Seq(
Row(0, 5, 3, 5),
Row(1, 4, 3, 4),
Row(3, 6, 1, 6)
)
)
}

test("aggregate with object aggregate buffer, empty inputs, no group by") {
val empty = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
empty.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
}

test("aggregate with object aggregate buffer, empty inputs, with group by") {
val empty = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
empty.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
Seq.empty[Row])
}

private def objectAggregateMax(column: Column): Column = {
val max = MaxWithObjectAggregateBuffer(column.expr)
Column(max.toAggregateExpression())
}
}

object AggregateWithObjectAggregateBufferSuite {
Contributor: (we do not need to put the example class inside this object.)

Contributor Author: I use the companion object to define a private scope.


/**
* Calculate the max value with an object aggregation buffer. This stores an object of class
* MaxValue in the aggregation buffer.
*/
private case class MaxWithObjectAggregateBuffer(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends ImperativeAggregate with WithObjectAggregateBuffer {

override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newOffset)

override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newOffset)

// Stores a generic object MaxValue in aggregation buffer.
override def initialize(buffer: MutableRow): Unit = {
// Makes sure we are using a GenericMutableRow, not an UnsafeRow, as the aggregation buffer.
assert(buffer.isInstanceOf[GenericMutableRow])
buffer.update(mutableAggBufferOffset, new MaxValue(Int.MinValue))
}

override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputValue = child.eval(input).asInstanceOf[Int]
val maxValue = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
if (inputValue > maxValue.value) {
maxValue.value = inputValue
}
}

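// The mutable buffer on this side still holds the MaxValue object, while inputBuffer arrives
// from the shuffle already serialized (an Int at inputAggBufferOffset), so deserialize it first.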
override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
val inputMax = deserialize(inputBuffer, inputAggBufferOffset)
if (inputMax.value > bufferMax.value) {
bufferMax.value = inputMax.value
}
}

private def deserialize(buffer: InternalRow, offset: Int): MaxValue = {
new MaxValue((buffer.getInt(offset)))
}

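// Replaces the MaxValue object with its plain Int value, one of the supported serialized
// formats, so the buffer can be safely spilled or emitted after the group is finished.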
override def serializeObjectAggregationBufferInPlace(buffer: MutableRow): Unit = {
val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
buffer(mutableAggBufferOffset) = bufferMax.value
}

override def eval(buffer: InternalRow): Any = {
val max = deserialize(buffer, mutableAggBufferOffset)
max.value
}

override val aggBufferAttributes: Seq[AttributeReference] =
Seq(AttributeReference("buf", IntegerType)())

override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
override def nullable: Boolean = true
override def deterministic: Boolean = false
override def children: Seq[Expression] = Seq(child)
}

private class MaxValue(var value: Int)
}