@@ -71,10 +71,6 @@ case class KafkaStreamWriterFactory(
epochId: Long): DataWriter[InternalRow] = {
new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes)
}

// `KafkaRowWriter` copies the input row immediately via a unsafe projection, so we can skip the
// copy at Spark side.
override def reuseDataObject(): Boolean = true
}

/**
@@ -33,7 +33,10 @@
public interface DataWriterFactory<T> extends Serializable {

/**
* Returns a data writer to do the actual writing work.
* Returns a data writer to do the actual writing work. Note that Spark will reuse the same data
* object instance when sending data to the data writer, for better performance. Data writers
* are responsible for defensive copies if necessary, e.g. copy the data before buffering it in
* a list.
Member:
nit: the description about defensive copies in data writers may be put in DataWriter.

Contributor Author:
I'll fix it in my next PR, thanks!

*
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
@@ -50,15 +53,4 @@ public interface DataWriterFactory<T> extends Serializable {
* this ID will always be 0.
*/
DataWriter<T> createDataWriter(int partitionId, long taskId, long epochId);

/**
* When true, Spark will reuse the same data object instance when sending data to the data writer,
* for better performance. Data writers should carefully handle the data objects if it's reused,
* e.g. do not buffer the data objects in a list. By default it returns false for safety, data
* sources can override it if their data writers immediately write the data object to somewhere
* else like a memory buffer or disk.
*/
default boolean reuseDataObject() {
return false;
}
}
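
With the reuseDataObject hook removed, the contract described in the new Javadoc applies to every writer: Spark may hand the same mutable data object to each write call, so a writer that buffers rows has to copy them itself. A minimal sketch of such a writer, assuming the DataSourceV2 writer API used elsewhere in this diff; the class and commit-message names below are hypothetical, not part of this PR:

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical commit message carrying the buffered rows back to the driver.
case class BufferedRowsCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage

// Hypothetical writer that buffers rows before committing. Spark may pass the
// same mutable InternalRow instance to every write() call, so each row must be
// copied before it is buffered.
class BufferingDataWriter extends DataWriter[InternalRow] {
  private val buffer = mutable.Buffer[InternalRow]()

  override def write(row: InternalRow): Unit = buffer.append(row.copy())

  override def commit(): WriterCommitMessage = BufferedRowsCommitMessage(buffer.toArray)

  override def abort(): Unit = buffer.clear()
}

This is essentially the pattern the PackedRowDataWriter change further down follows; writers that forward each row immediately (like the Kafka writer above) need no copy.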
@@ -109,15 +109,11 @@ object DataWritingSparkTask extends Logging {
val attemptId = context.attemptNumber()
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong)
val copyIfNeeded: InternalRow => InternalRow =
if (writeTask.reuseDataObject()) identity else _.copy()

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
while (iter.hasNext) {
// Internally Spark reuse the same UnsafeRow instance when producing output rows, here we
// copy it to avoid troubles at data source side.
dataWriter.write(copyIfNeeded(iter.next()))
dataWriter.write(iter.next())
}

val msg = if (useCommitCoordinator) {
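
With the copyIfNeeded indirection removed, the task passes the reused row straight to the writer. A small illustration, not taken from this PR, of why a writer that buffers without copying goes wrong: UnsafeProjection reuses its result row, so only copies survive.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType}

// UnsafeProjection returns the same mutable UnsafeRow for every input, so
// buffering the references without copying silently overwrites earlier rows.
val proj = UnsafeProjection.create(Array[DataType](IntegerType))

val buffered = Seq(InternalRow(1), InternalRow(2)).map(r => proj(r))        // both elements share one row
val copied   = Seq(InternalRow(1), InternalRow(2)).map(r => proj(r).copy()) // independent copies

// Expected: buffered.map(_.getInt(0)) == Seq(2, 2)  -- last projected value wins
// Expected: copied.map(_.getInt(0))   == Seq(1, 2)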
@@ -89,7 +89,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR
start.runTimeMs,
i,
numPartitions,
perPartitionRate): InputPartition[InternalRow]
perPartitionRate)
.asInstanceOf[InputPartition[InternalRow]]
Contributor:
Why is this cast necessary?

Contributor Author:
This is to address your comment about not mixing in read-side changes, so I reverted it here.

Contributor:
Okay, sorry about that. I should have looked at the whole diff.

}.asJava
}

@@ -44,9 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
SparkEnv.get)
EpochTracker.initializeCurrentEpoch(
context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
val copyIfNeeded: InternalRow => InternalRow =
if (writeTask.reuseDataObject()) identity else _.copy()

while (!context.isInterrupted() && !context.isCompleted()) {
var dataWriter: DataWriter[InternalRow] = null
// write the data and commit this writer.
@@ -58,7 +55,7 @@
context.taskAttemptId(),
EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(copyIfNeeded(dataIterator.next()))
dataWriter.write(dataIterator.next())
}
logInfo(s"Writer for partition ${context.partitionId()} " +
s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
@@ -51,7 +51,8 @@ case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommit
class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
private val data = mutable.Buffer[InternalRow]()

override def write(row: InternalRow): Unit = data.append(row)
// Spark reuses the same `InternalRow` instance, so we copy it before buffering it.
override def write(row: InternalRow): Unit = data.append(row.copy())

override def commit(): PackedRowCommitMessage = {
val msg = PackedRowCommitMessage(data.toArray)
@@ -117,19 +117,17 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
override def toString(): String = "MemorySinkV2"
}

case class MemoryWriterCommitMessage(partition: Int, data: Seq[InternalRow])
case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
extends WriterCommitMessage {}

class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType)
extends DataSourceWriter with Logging {

private val encoder = RowEncoder(schema).resolveAndBind()

override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)

def commit(messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow)
case message: MemoryWriterCommitMessage => message.data
}
sink.write(batchId, outputMode, newRows)
}
@@ -142,13 +140,11 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, sc
class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType)
extends StreamWriter {

private val encoder = RowEncoder(schema).resolveAndBind()

override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow)
case message: MemoryWriterCommitMessage => message.data
}
sink.write(epochId, outputMode, newRows)
}
@@ -158,22 +154,26 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema:
}
}

case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[InternalRow] {
case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)
extends DataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new MemoryDataWriter(partitionId, outputMode)
new MemoryDataWriter(partitionId, outputMode, schema)
}
}

class MemoryDataWriter(partition: Int, outputMode: OutputMode)
class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType)
extends DataWriter[InternalRow] with Logging {

private val data = mutable.Buffer[InternalRow]()
private val data = mutable.Buffer[Row]()

private val encoder = RowEncoder(schema).resolveAndBind()

override def write(row: InternalRow): Unit = {
data.append(row)
data.append(encoder.fromRow(row))
}

override def commit(): MemoryWriterCommitMessage = {
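
MemoryDataWriter above avoids an explicit copy() by converting each incoming InternalRow to an external Row with a resolved and bound RowEncoder; the conversion materializes a fresh object, so the buffered data is independent of the reused input instance. A minimal sketch of that pattern, using the same RowEncoder calls the diff uses; the schema and value are made up for illustration:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StructType}

// A resolved and bound RowEncoder converts InternalRow -> Row. The resulting
// Row is a new object, so it is safe to buffer even though Spark reuses the
// input InternalRow instance between records.
val schema  = new StructType().add("i", IntegerType)
val encoder = RowEncoder(schema).resolveAndBind()

val external: Row = encoder.fromRow(InternalRow(42)) // fresh Row(42), safe to keep around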
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
@@ -27,7 +28,8 @@ import org.apache.spark.sql.types.StructType
class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
test("data writer") {
val partition = 1234
val writer = new MemoryDataWriter(partition, OutputMode.Append())
val writer = new MemoryDataWriter(
partition, OutputMode.Append(), new StructType().add("i", "int"))
writer.write(InternalRow(1))
writer.write(InternalRow(2))
writer.write(InternalRow(44))
@@ -44,16 +46,16 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int"))
writer.commit(0,
Array(
MemoryWriterCommitMessage(0, Seq(InternalRow(1), InternalRow(2))),
MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))),
MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7)))
MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
Contributor:
Why was this changed back to Row?

Contributor Author:
Now the DataWriter needs to copy the input row before buffering it, which can be done by the RowEncoder when converting InternalRow to Row. Then the commit message carries Rows to the driver side.

Contributor:
Why not use InternalRow.copy? I'd rather keep the update to InternalRow, but as long as the tests pass I wouldn't block this.

Contributor Author:
Because the memory sink needs Rows in the end. Instead of collecting InternalRows via copy and then converting them to Rows, I think it's more efficient to collect Rows directly.

Contributor:
Sounds good. Thanks for explaining your rationale.

MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
))
assert(sink.latestBatchId.contains(0))
assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
writer.commit(19,
Array(
MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))),
MemoryWriterCommitMessage(0, Seq(InternalRow(33)))
MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
MemoryWriterCommitMessage(0, Seq(Row(33)))
))
assert(sink.latestBatchId.contains(19))
assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
@@ -66,16 +68,16 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
val schema = new StructType().add("i", "int")
new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit(
Array(
MemoryWriterCommitMessage(0, Seq(InternalRow(1), InternalRow(2))),
MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))),
MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7)))
MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
))
assert(sink.latestBatchId.contains(0))
assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit(
Array(
MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))),
MemoryWriterCommitMessage(0, Seq(InternalRow(33)))
MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
MemoryWriterCommitMessage(0, Seq(Row(33)))
))
assert(sink.latestBatchId.contains(19))
assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))