@@ -389,3 +389,153 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}

/**
* Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
* internal aggregation buffer object.
*
* {{{
* aggregation buffer for normal aggregation function `avg`
* |
* v
* +--------------+---------------+-----------------------------------+
* | sum1 (Long) | count1 (Long) | generic user-defined java objects |
* +--------------+---------------+-----------------------------------+
* ^
* |
* Aggregation buffer object for `TypedImperativeAggregate` aggregation function
* }}}
Contributor: Let's also add a normal agg buffer after the generic one, so readers will not assume that the generic buffers are always placed at the end.

*
* Workflow (Partial mode aggregation at the Mapper side, and Final mode aggregation at the
* Reducer side):
*
* Stage 1: Partial mode aggregation at the Mapper side:
*
* 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
* buffer object.
* 2. For each input row, the framework calls
* `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
* 3. After processing all rows of the current group (group-by key), the framework serializes the
* aggregation buffer object T to the storage format (Array[Byte]) and persists the Array[Byte]
* to disk if needed.
* 4. The framework moves on to the next group, until all groups have been processed.
*
* The shuffle exchange then sends the data to the Reducer tasks...
*
* Stage 2: Final mode aggregation at the Reducer side:
*
* 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
* buffer object (type T) for merging.
* 2. For each aggregation output of Stage 1, the framework de-serializes the storage
* format (Array[Byte]) and produces one input aggregation object (type T).
* 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
* to merge the input aggregation object into the aggregation buffer object.
* 4. After processing all input aggregation objects of the current group (group-by key), the
* framework calls `eval(buffer: T)` to generate the final output for this group.
* 5. The framework moves on to the next group, until all groups have been processed.
*
* NOTE: SQL with TypedImperativeAggregate functions is planned with sort-based aggregation
* instead of hash-based aggregation, because TypedImperativeAggregate uses BinaryType as the
* aggregation buffer's storage format, which is not supported by hash-based aggregation.
* Hash-based aggregation only supports aggregation buffers of mutable types (such as LongType
* and IntType, which have a fixed length and can be mutated in place in an UnsafeRow).
*/
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
Contributor: Isn't this the wrong way around? Isn't ImperativeAggregate the untyped version of a TypedImperativeAggregate, much like Dataset and DataFrame? I know this has been done for engineering purposes, but I still wonder if we shouldn't reverse the hierarchy here.

Contributor: ImperativeAggregate only defines the interface. It does not specify which buffer types are accepted, right?
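Note: to make the documented workflow concrete, here is a hedged, partial sketch of a typed max aggregate over long inputs. Only the typed callbacks introduced in this patch are shown; the remaining Expression / ImperativeAggregate members (dataType, nullable, inputTypes, the buffer-offset fields and their withNew* copies, etc.) are omitted, and the class and buffer names are hypothetical, not part of this change.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical mutable buffer object: holds the running maximum for one group.
class MaxHolder(var isSet: Boolean = false, var value: Long = Long.MinValue)

// Hypothetical aggregate; only the typed callbacks from this patch are implemented.
case class TypedMax(child: Expression) extends TypedImperativeAggregate[MaxHolder] {

  override def createAggregationBuffer(): MaxHolder = new MaxHolder()

  override def update(buffer: MaxHolder, input: InternalRow): Unit = {
    // The input expression has to be evaluated against the input row first.
    val v = child.eval(input)
    if (v != null) {
      val candidate = v.asInstanceOf[Long]
      if (!buffer.isSet || candidate > buffer.value) {
        buffer.isSet = true
        buffer.value = candidate
      }
    }
  }

  override def merge(buffer: MaxHolder, input: MaxHolder): Unit = {
    if (input.isSet && (!buffer.isSet || input.value > buffer.value)) {
      buffer.isSet = true
      buffer.value = input.value
    }
  }

  override def eval(buffer: MaxHolder): Any = if (buffer.isSet) buffer.value else null

  override def aggregationBufferClass: Class[MaxHolder] = classOf[MaxHolder]

  // Storage format: 1 flag byte followed by the 8-byte max value.
  override def serialize(buffer: MaxHolder): Array[Byte] = {
    val bytes = ByteBuffer.allocate(9)
    bytes.put(if (buffer.isSet) 1.toByte else 0.toByte).putLong(buffer.value)
    bytes.array()
  }

  override def deserialize(storageFormat: Array[Byte]): MaxHolder = {
    val bytes = ByteBuffer.wrap(storageFormat)
    new MaxHolder(bytes.get() == 1.toByte, bytes.getLong())
  }

  // Expression / ImperativeAggregate boilerplate (dataType, nullable, inputTypes,
  // mutableAggBufferOffset, inputAggBufferOffset and their withNew* copies) omitted here.
}
```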


/**
* Creates an empty aggregation buffer object. This is called before processing each key group
* (group by key).
*
* @return an aggregation buffer object
*/
def createAggregationBuffer(): T

/**
* In-place updates the aggregation buffer object with an input row. buffer = buffer + input.
* This is typically called when doing Partial or Complete mode aggregation.
*
* @param buffer The aggregation buffer object.
* @param input an input row
*/
def update(buffer: T, input: InternalRow): Unit
Contributor: This assumes the buffer object type T can do an in-place update, which is not always true, e.g. percentile_approx. How about `def update(buffer: T, input: InternalRow): T`?

Contributor Author: @cloud-fan The user can define a wrapper to do the in-place update.

Contributor: It seems update needs to evaluate the input. We need to document it.
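Note: as a hedged illustration of the wrapper idea raised in the thread above (the names and the immutable digest type are hypothetical, not part of this PR), the typed buffer can be a small mutable holder whose field is reassigned on every update, so the reference handed to the framework never changes even though the underlying value is immutable.

```scala
// Hypothetical immutable summary (e.g. a quantile digest) that returns a new instance on add/merge.
case class ImmutableDigest(values: Vector[Double]) {
  def add(v: Double): ImmutableDigest = ImmutableDigest(values :+ v)
  def merge(other: ImmutableDigest): ImmutableDigest = ImmutableDigest(values ++ other.values)
}

// Mutable wrapper used as the buffer type T, so update(buffer: T, ...): Unit can work in place.
class DigestHolder(var digest: ImmutableDigest = ImmutableDigest(Vector.empty)) {
  def add(v: Double): Unit = { digest = digest.add(v) }
  def mergeWith(other: DigestHolder): Unit = { digest = digest.merge(other.digest) }
}
```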


/**
* Merges an input aggregation object into the aggregation buffer object. buffer = buffer + input.
* This is typically called when doing PartialMerge or Final mode aggregation.
*
* @param buffer the aggregation buffer object used to store the aggregation result.
* @param input an input aggregation object, which can be produced by de-serializing the partial
* aggregate's output from the Mapper side.
*/
def merge(buffer: T, input: T): Unit
Contributor: Here too.

Contributor Author: @cloud-fan You can find an example at 0a777cc.


/**
* Generates the final aggregation result value for the current key group with the aggregation
* buffer object.
*
* @param buffer aggregation buffer object.
* @return the aggregation result of the current key group
*/
def eval(buffer: T): Any

/** Returns the class of aggregation buffer object */
def aggregationBufferClass: Class[T]
Contributor: Can we just do `TypedImperativeAggregate[T : ClassTag]`?

Contributor Author: We need to consider Java API compatibility.

Contributor: Do we need to? This is an internal API and I think we will only use it in Scala.

Contributor Author: I think so. @yhuai mentioned Java API compatibility multiple times: #14753 (comment)

Contributor: Oh, I was thinking about just avoiding Scala features unless we have to.


/** Serializes the aggregation buffer object T to Array[Byte] */
def serialize(buffer: T): Array[Byte]
Contributor Author: Here we limit the serialized format to Array[Byte].

The reason is that SpecialMutableRow does a type check for atomic types on every update call of the aggregation buffer. If we declared the storage format to be IntegerType but actually stored an arbitrary object in the aggregation buffer, SpecialMutableRow would catch this error and report an exception.

Contributor: This detail deserves a comment in the code.


/** De-serializes the storage format Array[Byte] and produces the aggregation buffer object T */
def deserialize(storageFormat: Array[Byte]): T
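Note: a minimal, hedged sketch of one possible serialize/deserialize pair for a buffer type that happens to be java.io.Serializable. The BufferCodec name is hypothetical; this patch does not prescribe a codec, and a real implementation would likely use a more compact, versioned encoding.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical helper, assuming the buffer type T is java.io.Serializable.
object BufferCodec {
  def serialize[T](buffer: T): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(buffer) finally out.close()
    bytes.toByteArray
  }

  def deserialize[T](storageFormat: Array[Byte]): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(storageFormat))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}
```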

final override def initialize(buffer: MutableRow): Unit = {
val bufferObject = createAggregationBuffer()
buffer.update(mutableAggBufferOffset, bufferObject)
}

final override def update(buffer: MutableRow, input: InternalRow): Unit = {
val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T]
Member: `getField[T]`?

update(bufferObject, input)
}

final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
val bufferObject = getField[T](buffer, mutableAggBufferOffset)
val inputObject = deserialize(getField[Array[Byte]](inputBuffer, inputAggBufferOffset))
Contributor: Add a comment to explain why we are calling deserialize.

Contributor: nit: we should use `inputBuffer.getBinary(inputAggBufferOffset)` instead of `getField[Array[Byte]](inputBuffer, inputAggBufferOffset)`, as the data type is BinaryType, not ObjectType(classOf[Any]).

Contributor Author: The inputBuffer is a safe row in SortAggregateExec (`processRow(sortBasedAggregationBuffer, safeProj(currentRow))`), so `inputBuffer.getBinary(inputAggBufferOffset)` and `getField[Array[Byte]](inputBuffer, inputAggBufferOffset)` are equivalent. Yes, it is better to use `inputBuffer.getBinary(inputAggBufferOffset)` directly.

merge(bufferObject, inputObject)
}

final override def eval(buffer: InternalRow): Any = {
val bufferObject = getField[AnyRef](buffer, mutableAggBufferOffset)
if (bufferObject.getClass == aggregationBufferClass) {
// When used in Window frame aggregation, eval(buffer: InternalRow) is called directly
// on the object aggregation buffer without intermediate serializing/de-serializing.
eval(bufferObject.asInstanceOf[T])
} else {
eval(deserialize(bufferObject.asInstanceOf[Array[Byte]]))
}
}

private[this] val anyObjectType = ObjectType(classOf[AnyRef])
private def getField[U](input: InternalRow, fieldIndex: Int): U = {
Contributor: Seems we only need `getField(input: InternalRow, fieldIndex: Int): T`?

input.get(fieldIndex, anyObjectType).asInstanceOf[U]
}

final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
// Underlying storage type for the aggregation buffer object
Seq(AttributeReference("buf", BinaryType)())
}

final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

/**
* In-place replaces the aggregation buffer object stored at the buffer's index
* `mutableAggBufferOffset` with a storage format that Spark SQL supports internally.
Contributor: Re "with SparkSQL internally supported underlying storage format": it can only be BinaryType now.

*
* The framework calls this method every time after updating/merging one group (group by key).
Contributor (@cloud-fan, Aug 24, 2016): "... every time before we output the buffer object to the parent operator"?

Contributor Author: I think the current description is clear?

Contributor: It's not correct; we may not call this method after updating/merging one group, if the buffer is used for eval rather than for shuffle.

*/
final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T]
Member: `getField[T]` too?

buffer(mutableAggBufferOffset) = serialize(bufferObject)
}
}
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, TypedImperativeAggregate}
import org.apache.spark.sql.execution.metric.SQLMetric

/**
@@ -54,6 +54,7 @@ class SortBasedAggregationIterator(
val bufferRowSize: Int = bufferSchema.length

val genericMutableBuffer = new GenericMutableRow(bufferRowSize)

val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val buffer = if (useUnsafeBuffer) {
@@ -90,6 +91,24 @@ class SortBasedAggregationIterator(
// compared to MutableRow (aggregation buffer) directly.
private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

// Aggregation function which uses generic aggregation buffer object.
// @see [[TypedImperativeAggregate]] for more information
private val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
aggregateFunctions.collect {
case (ag: TypedImperativeAggregate[_]) => ag
}
}

// For TypedImperativeAggregate with generic aggregation buffer object, we need to call
// serializeAggregateBufferInPlace(...) explicitly to convert the aggregation buffer object
// to a serializable storage format that Spark SQL supports internally.
private def serializeTypedAggregateBuffer(aggregationBuffer: MutableRow): Unit = {
Member: Unused parameter aggregationBuffer. Or replace the following sortBasedAggregationBuffer with aggregationBuffer?

typedImperativeAggregates.foreach { agg =>
// In-place serialization
agg.serializeAggregateBufferInPlace(sortBasedAggregationBuffer)
}
}

protected def initialize(): Unit = {
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +150,11 @@ class SortBasedAggregationIterator(
firstRowInNextGroup = currentRow.copy()
}
}

// Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate
// aggregation functions.
serializeTypedAggregateBuffer(sortBasedAggregationBuffer)
Contributor: (basically, when we call eval, we always get the original object)


// We have not seen a new group. It means that there is no new row in the input
// iter. The current group is the last group of the iter.
if (!findNextPartition) {
@@ -162,6 +186,9 @@ class SortBasedAggregationIterator(

def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
initializeBuffer(sortBasedAggregationBuffer)
// Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate
// aggregation functions.
serializeTypedAggregateBuffer(sortBasedAggregationBuffer)
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
}
}