@@ -18,7 +18,7 @@
 package org.apache.spark.ml.linalg
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -46,7 +46,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
   }
 
   override def serialize(obj: Matrix): InternalRow = {
-    val row = new GenericMutableRow(7)
+    val row = new GenericInternalRow(7)
     obj match {
       case sm: SparseMatrix =>
         row.setByte(0, 0)
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.linalg
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -42,14 +42,14 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
-        val row = new GenericMutableRow(4)
+        val row = new GenericInternalRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
         row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
-        val row = new GenericMutableRow(4)
+        val row = new GenericInternalRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
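
The serializer change in VectorUDT is purely a type rename; the 4-field encoding (type tag, size, indices, values) is unchanged. As a reference point, here is a minimal sketch of that layout round trip; the helper names serializeSparse/deserializeSparse are illustrative, not part of the patch:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}

    // Write a sparse vector into VectorUDT's 4-field layout:
    // (tag: Byte = 0 for sparse, size: Int, indices: Array[Int], values: Array[Double]).
    def serializeSparse(size: Int, indices: Array[Int], values: Array[Double]): InternalRow = {
      val row = new GenericInternalRow(4)
      row.setByte(0, 0)
      row.setInt(1, size)
      row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
      row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      row
    }

    // Read it back; getArray returns ArrayData with primitive bulk accessors.
    def deserializeSparse(row: InternalRow): (Int, Array[Int], Array[Double]) =
      (row.getInt(1), row.getArray(2).toIntArray(), row.getArray(3).toDoubleArray())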
@@ -28,7 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -189,7 +189,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
   }
 
   override def serialize(obj: Matrix): InternalRow = {
-    val row = new GenericMutableRow(7)
+    val row = new GenericInternalRow(7)
     obj match {
       case sm: SparseMatrix =>
         row.setByte(0, 0)
@@ -34,7 +34,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -214,14 +214,14 @@ class VectorUDT extends UserDefinedType[Vector] {
   override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
-        val row = new GenericMutableRow(4)
+        val row = new GenericInternalRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
         row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
-        val row = new GenericMutableRow(4)
+        val row = new GenericInternalRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
@@ -59,7 +59,7 @@
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
-public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable {
+public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
 
   //////////////////////////////////////////////////////////////////////////////
   // Static methods
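
With the mutating setters folded into InternalRow (next file), UnsafeRow keeps its word-level overrides, and only the generic update(i, value) remains unsupported on it. A small usage sketch; the two-column schema here is invented for illustration:

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
    import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

    val proj = UnsafeProjection.create(Array[DataType](IntegerType, LongType))
    val input = new GenericInternalRow(2)
    input.setInt(0, 1)               // routed through update() on the generic row
    input.setLong(1, 2L)
    val unsafe = proj.apply(input)   // compact binary row
    unsafe.setInt(0, 42)             // fine: fixed-width field, written in place
    // unsafe.update(0, 42)          // still throws UnsupportedOperationException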
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, Decimal, StructType}
 
 /**
  * An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -31,6 +31,27 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
   // This is only use for test and will throw a null pointer exception if the position is null.
   def getString(ordinal: Int): String = getUTF8String(ordinal).toString
 
+  def setNullAt(i: Int): Unit
+
+  def update(i: Int, value: Any): Unit
+
+  // default implementation (slow)
+  def setBoolean(i: Int, value: Boolean): Unit = update(i, value)
+  def setByte(i: Int, value: Byte): Unit = update(i, value)
+  def setShort(i: Int, value: Short): Unit = update(i, value)
+  def setInt(i: Int, value: Int): Unit = update(i, value)
+  def setLong(i: Int, value: Long): Unit = update(i, value)
+  def setFloat(i: Int, value: Float): Unit = update(i, value)
+  def setDouble(i: Int, value: Double): Unit = update(i, value)
+
+  /**
+   * Update the decimal column at `i`.
+   *
+   * Note: In order to support update decimal with precision > 18 in UnsafeRow,
+   * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision).
+   */
+  def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
+
   /**
    * Make a copy of the current [[InternalRow]] object.
    */
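
These defaults make any InternalRow usable as a write target, with concrete implementations free to override the typed setters for speed. A minimal sketch of the kind of code this enables, generic over the row implementation:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

    // Zero out n consecutive long slots starting at offset; works unchanged on a
    // GenericInternalRow (via the slow update() default) or an UnsafeRow (in place).
    def zeroFill(buffer: InternalRow, offset: Int, n: Int): Unit = {
      var i = 0
      while (i < n) {
        buffer.setLong(offset + i, 0L)
        i += 1
      }
    }

    val row = new GenericInternalRow(4)
    zeroFill(row, 0, 4)

Note the decimal caveat in the added doc comment: for a decimal column with precision > 18 on an UnsafeRow, a null must be written as setDecimal(i, null, precision) rather than setNullAt(i), so the fixed-width slot stays consistent.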
@@ -256,7 +256,7 @@ case class ExpressionEncoder[T](
   private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
 
   @transient
-  private lazy val inputRow = new GenericMutableRow(1)
+  private lazy val inputRow = new GenericInternalRow(1)
 
   @transient
   private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
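
The encoder keeps one reusable single-slot row for feeding values of T into the serializer projection; only the class of that scratch row changes here. A sketch of the pattern (the wrap helper is hypothetical, not the encoder's actual method):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

    val inputRow = new GenericInternalRow(1)
    def wrap(value: Any): InternalRow = { inputRow.update(0, value); inputRow }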
@@ -403,7 +403,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       case (fromField, toField) => cast(fromField.dataType, toField.dataType)
     }
     // TODO: Could be faster?
-    val newRow = new GenericMutableRow(from.fields.length)
+    val newRow = new GenericInternalRow(from.fields.length)
     buildCast[InternalRow](_, row => {
       var i = 0
       while (i < row.numFields) {
@@ -892,7 +892,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     val fieldsCasts = from.fields.zip(to.fields).map {
       case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
     }
-    val rowClass = classOf[GenericMutableRow].getName
+    val rowClass = classOf[GenericInternalRow].getName
     val result = ctx.freshName("result")
     val tmpRow = ctx.freshName("tmpRow")
 
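
Both the interpreted and the code-generated struct casts allocate this generic row as scratch space. A sketch of the interpreted shape, with the per-field cast functions and field types passed in as assumptions rather than taken from the real Cast internals:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.DataType

    def castStruct(row: InternalRow, fieldTypes: Array[DataType],
        casts: Array[Any => Any]): InternalRow = {
      val newRow = new GenericInternalRow(casts.length)
      var i = 0
      while (i < row.numFields) {
        if (row.isNullAt(i)) newRow.setNullAt(i)
        else newRow.update(i, casts(i)(row.get(i, fieldTypes(i))))
        i += 1
      }
      newRow
    }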
@@ -123,6 +123,22 @@ class JoinedRow extends InternalRow {
 
   override def anyNull: Boolean = row1.anyNull || row2.anyNull
 
+  override def setNullAt(i: Int): Unit = {
+    if (i < row1.numFields) {
+      row1.setNullAt(i)
+    } else {
+      row2.setNullAt(i - row1.numFields)
+    }
+  }
+
+  override def update(i: Int, value: Any): Unit = {
+    if (i < row1.numFields) {
+      row1.update(i, value)
+    } else {
+      row2.update(i - row1.numFields, value)
+    }
+  }
+
   override def copy(): InternalRow = {
     val copy1 = row1.copy()
     val copy2 = row2.copy()
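
Since JoinedRow is a view over two backing rows, the new mutators route by ordinal rather than store anything themselves. Usage sketch (arities and values invented):

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}

    val left = new GenericInternalRow(2)
    val right = new GenericInternalRow(3)
    val joined = new JoinedRow(left, right)
    joined.setInt(1, 7)   // 1 < left.numFields, lands in left(1)
    joined.setInt(3, 9)   // 3 - left.numFields == 1, lands in right(1)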
@@ -69,10 +69,10 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
   })
 
   private[this] val exprArray = expressions.toArray
-  private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
+  private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)
   def currentValue: InternalRow = mutableRow
 
-  override def target(row: MutableRow): MutableProjection = {
+  override def target(row: InternalRow): MutableProjection = {
     mutableRow = row
     this
  }
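
target() re-points the projection at a caller-owned row (for example an aggregation buffer) so results are written without extra copies; the only change is the accepted type. A standalone sketch of that pattern, deliberately outside Spark's real Projection hierarchy:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.DataType

    class CopyProjection(fieldTypes: Array[DataType]) {
      private var out: InternalRow = new GenericInternalRow(fieldTypes.length)
      def target(row: InternalRow): this.type = { out = row; this }
      def apply(input: InternalRow): InternalRow = {
        var i = 0
        while (i < fieldTypes.length) {
          out.update(i, input.get(i, fieldTypes(i)))
          i += 1
        }
        out
      }
    }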
@@ -22,7 +22,7 @@ import org.apache.spark.sql.types._
 
 /**
  * A parent class for mutable container objects that are reused when the values are changed,
- * resulting in less garbage. These values are held by a [[SpecificMutableRow]].
+ * resulting in less garbage. These values are held by a [[SpecificInternalRow]].
  *
  * The following code was roughly used to generate these objects:
  * {{{
@@ -191,8 +191,7 @@ final class MutableAny extends MutableValue {
 * based on the dataTypes of each column. The intent is to decrease garbage when modifying the
 * values of primitive columns.
 */
-final class SpecificMutableRow(val values: Array[MutableValue])
-  extends MutableRow with BaseGenericInternalRow {
+final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow {
 
   def this(dataTypes: Seq[DataType]) =
     this(
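
SpecificInternalRow keeps one specialized MutableValue container per column, so repeated writes to primitive columns create no boxing garbage. Usage sketch:

    import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
    import org.apache.spark.sql.types.{DoubleType, IntegerType}

    val row = new SpecificInternalRow(Seq(IntegerType, DoubleType))
    row.setInt(0, 1)        // writes into a MutableInt slot, no boxing
    row.setDouble(1, 2.5)   // writes into a MutableDouble slot
    row.setNullAt(0)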
@@ -155,7 +155,7 @@ case class HyperLogLogPlusPlus(
     aggBufferAttributes.map(_.newInstance())
 
   /** Fill all words with zeros. */
-  override def initialize(buffer: MutableRow): Unit = {
+  override def initialize(buffer: InternalRow): Unit = {
     var word = 0
     while (word < numWords) {
       buffer.setLong(mutableAggBufferOffset + word, 0)
@@ -168,7 +168,7 @@
    *
    * Variable names in the HLL++ paper match variable names in the code.
    */
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+  override def update(buffer: InternalRow, input: InternalRow): Unit = {
     val v = child.eval(input)
     if (v != null) {
       // Create the hashed value 'x'.
@@ -200,7 +200,7 @@
    * Merge the HLL buffers by iterating through the registers in both buffers and select the
    * maximum number of leading zeros for each register.
    */
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
     var idx = 0
     var wordOffset = 0
     while (wordOffset < numWords) {
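
The HLL++ registers live inline in the shared aggregation buffer, addressed as numWords long slots starting at mutableAggBufferOffset; initialize is exactly a zero-fill loop over that range, and with the setters now on InternalRow such buffer code no longer needs a separate mutable row type. A sketch of the register packing this buffer holds (6-bit registers in 64-bit words; constants here follow the general scheme and are illustrative, not copied from Spark's helper):

    // Each 64-bit word packs 10 six-bit registers.
    val REGISTERS_PER_WORD = 10
    def wordOf(registerIdx: Int): Int = registerIdx / REGISTERS_PER_WORD
    def shiftOf(registerIdx: Int): Int = (registerIdx % REGISTERS_PER_WORD) * 6
    def readRegister(words: Array[Long], idx: Int): Long =
      (words(wordOf(idx)) >>> shiftOf(idx)) & 0x3f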
@@ -30,7 +30,7 @@ object PivotFirst {
 
   // Currently UnsafeRow does not support the generic update method (throws
   // UnsupportedOperationException), so we need to explicitly support each DataType.
-  private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
+  private val updateFunction: PartialFunction[DataType, (InternalRow, Int, Any) => Unit] = {
     case DoubleType =>
       (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
     case IntegerType =>
@@ -89,9 +89,9 @@ case class PivotFirst(
 
   val indexSize = pivotIndex.size
 
-  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)
+  private val updateRow: (InternalRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)
 
-  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
+  override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = {
     val pivotColValue = pivotColumn.eval(inputRow)
     if (pivotColValue != null) {
       // We ignore rows whose pivot column value is not in the list of pivot column values.
@@ -105,7 +105,7 @@
     }
   }
 
-  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
+  override def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit = {
     for (i <- 0 until indexSize) {
       if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
         val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
@@ -114,7 +114,7 @@
       }
     }
 
-  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
+  override def initialize(mutableAggBuffer: InternalRow): Unit = valueDataType match {
     case d: DecimalType =>
       // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
       for (i <- 0 until indexSize) {
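
PivotFirst resolves a concrete setter once per value type, because the generic update is unavailable on UnsafeRow buffers. Sketch of that dispatch shape, trimmed to two cases for brevity:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types._

    val updateFn: PartialFunction[DataType, (InternalRow, Int, Any) => Unit] = {
      case DoubleType  => (row, offset, v) => row.setDouble(offset, v.asInstanceOf[Double])
      case IntegerType => (row, offset, v) => row.setInt(offset, v.asInstanceOf[Int])
    }
    val setter = updateFn(IntegerType)   // resolved once per aggregate instance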
@@ -60,11 +60,11 @@ abstract class Collect extends ImperativeAggregate {
 
   protected[this] val buffer: Growable[Any] with Iterable[Any]
 
-  override def initialize(b: MutableRow): Unit = {
+  override def initialize(b: InternalRow): Unit = {
     buffer.clear()
   }
 
-  override def update(b: MutableRow, input: InternalRow): Unit = {
+  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     val value = child.eval(input)
@@ -73,7 +73,7 @@
     }
   }
 
-  override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
     sys.error("Collect cannot be used in partial aggregations.")
   }
 
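
Collect keeps its state in a plain Scala collection on the executor rather than in the row buffer, which is why merge must fail fast: there is nothing in a shuffled buffer to combine. Sketch of the buffer contract:

    import scala.collection.generic.Growable
    import scala.collection.mutable.ArrayBuffer

    val buffer: Growable[Any] with Iterable[Any] = ArrayBuffer.empty[Any]
    def updateWith(value: Any): Unit = if (value != null) buffer += value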
@@ -307,14 +307,14 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
    *
    * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
    */
-  def initialize(mutableAggBuffer: MutableRow): Unit
+  def initialize(mutableAggBuffer: InternalRow): Unit
 
   /**
    * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
    *
    * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
    */
-  def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit
+  def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit
 
   /**
    * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate
@@ -323,7 +323,7 @@
    * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
    * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
    */
-  def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
+  def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit
 }
 
 /**
@@ -504,16 +504,16 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
   def deserialize(storageFormat: Array[Byte]): T
 
-  final override def initialize(buffer: MutableRow): Unit = {
+  final override def initialize(buffer: InternalRow): Unit = {
     val bufferObject = createAggregationBuffer()
     buffer.update(mutableAggBufferOffset, bufferObject)
   }
 
-  final override def update(buffer: MutableRow, input: InternalRow): Unit = {
+  final override def update(buffer: InternalRow, input: InternalRow): Unit = {
     update(getBufferObject(buffer), input)
   }
 
-  final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
+  final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
     val bufferObject = getBufferObject(buffer)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
@@ -547,7 +547,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
    * This is only called when doing Partial or PartialMerge mode aggregation, before the framework
    * shuffle out aggregate buffers.
    */
-  final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
+  final def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
     buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
   }
 }
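
All three lifecycle hooks now share one row abstraction. A standalone sketch of the contract for a count-like aggregate, kept outside Spark's aggregation framework for brevity (the offsets are positions into a shared, flat buffer row, as the doc comments above describe):

    import org.apache.spark.sql.catalyst.InternalRow

    class CountSketch(mutableAggBufferOffset: Int, inputAggBufferOffset: Int) {
      def initialize(mutableAggBuffer: InternalRow): Unit =
        mutableAggBuffer.setLong(mutableAggBufferOffset, 0L)

      def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit =
        mutableAggBuffer.setLong(mutableAggBufferOffset,
          mutableAggBuffer.getLong(mutableAggBufferOffset) + 1)

      def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit =
        mutableAggBuffer.setLong(mutableAggBufferOffset,
          mutableAggBuffer.getLong(mutableAggBufferOffset) +
            inputAggBuffer.getLong(inputAggBufferOffset))
    }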
@@ -819,7 +819,7 @@ class CodeAndComment(val body: String, val comment: collection.Map[String, Strin
  */
 abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
 
-  protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
+  protected val genericMutableRowType: String = classOf[GenericInternalRow].getName
 
   /**
    * Generates a class for a given input expression. Called when there is not cached code
@@ -889,7 +889,6 @@ object CodeGenerator extends Logging {
       classOf[UnsafeArrayData].getName,
       classOf[MapData].getName,
       classOf[UnsafeMapData].getName,
-      classOf[MutableRow].getName,
       classOf[Expression].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])