@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
*/
class KafkaStreamWriter(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends StreamWriter with SupportsWriteInternalRow {
extends StreamWriter {

validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)

override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
override def createWriterFactory(): KafkaStreamWriterFactory =
KafkaStreamWriterFactory(topic, producerParams, schema)

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
@@ -71,6 +71,10 @@ case class KafkaStreamWriterFactory(
epochId: Long): DataWriter[InternalRow] = {
new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes)
}

// `KafkaRowWriter` copies the input row immediately via an unsafe projection, so we can skip the
// copy on the Spark side.
override def reuseDataObject(): Boolean = true
}

/**
@@ -18,8 +18,8 @@
package org.apache.spark.sql.sources.v2.writer;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.StreamWriteSupport;
import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
DataWriterFactory<Row> createWriterFactory();
DataWriterFactory<InternalRow> createWriterFactory();

/**
* Returns whether Spark should use the commit coordinator to ensure that at most one task for
@@ -53,9 +53,7 @@
* successfully, and have a way to revert committed data writers without the commit message, because
* Spark only accepts the commit message that arrives first and ignore others.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
* source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
* that mix in {@link SupportsWriteInternalRow}.
* Note that currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
*/
@InterfaceStability.Evolving
public interface DataWriter<T> {
@@ -50,4 +50,15 @@ public interface DataWriterFactory<T> extends Serializable {
* this ID will always be 0.
*/
DataWriter<T> createDataWriter(int partitionId, long taskId, long epochId);

/**
* When true, Spark will reuse the same data object instance when sending data to the data writer,
* for better performance. Data writers should carefully handle the data objects if it's reused,

Member: Nit: in "if it's reused", the "it" here is ambiguous. Maybe change to "if they are reused"?

* e.g. do not buffer the data objects in a list. By default it returns false for safety, data

Contributor Author: @rdblue did you hit this problem in iceberg?

Contributor: No, Iceberg assumes that data objects are reused.

Member: nit: "By default it returns false, data sources" => "By default the method returns false. Data sources"

* sources can override it if their data writers immediately write the data object to somewhere
* else like a memory buffer or disk.
*/
default boolean reuseDataObject() {
return false;
}

Contributor @rdblue (Aug 2, 2018): I don't think this should be added in this commit. This is to move to InternalRow and should not alter the API. If we want to add this, then let's discuss it in a PR for this as a feature. I'm fine documenting the default reuse behavior in this commit, though.

I think writers are responsible for defensive copies if necessary. This default is going to cause sources to be slower and I don't think it is necessary for implementations that aren't tests buffering data in memory.

In general, I think it's okay for us to have higher expectations of sources than of users. It's okay to simply document that rows are reused.

}
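
To make the reuse contract concrete, here is a minimal sketch (not part of this PR) of a source that can safely opt in because its writer turns each row into a String before `write()` returns. `CsvLineWriterFactory`, `CsvCommitMessage` and the CSV formatting are hypothetical; only `DataWriterFactory`, `DataWriter`, `WriterCommitMessage` and `reuseDataObject` come from the API shown above.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

// Hypothetical commit message; the interface itself carries no data.
case object CsvCommitMessage extends WriterCommitMessage

// Hypothetical factory: reusing the incoming InternalRow is safe here because the
// writer extracts everything it needs from the row before write() returns.
case class CsvLineWriterFactory(schema: StructType) extends DataWriterFactory[InternalRow] {

  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new DataWriter[InternalRow] {

    private val lines = ArrayBuffer.empty[String]

    override def write(record: InternalRow): Unit = {
      // The row's fields are copied into a String immediately; no reference to
      // `record` is kept, so it does not matter that Spark mutates it afterwards.
      lines += schema.fields.indices
        .map(i => String.valueOf(record.get(i, schema.fields(i).dataType)))
        .mkString(",")
    }

    override def commit(): WriterCommitMessage = CsvCommitMessage

    override def abort(): Unit = lines.clear()
  }

  // Tell Spark it can skip the defensive row copy on the write path.
  override def reuseDataObject(): Boolean = true
}
```

A source that buffers the InternalRow objects themselves (for example to sort them before writing) would have to keep the default of false or call `copy()` itself, which is the hazard the review discussion above is about.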

This file was deleted.

Original file line number Diff line number Diff line change
@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
override def output: Seq[Attribute] = Nil

override protected def doExecute(): RDD[InternalRow] = {
val writeTask = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val writeTask = writer.createWriterFactory()
val useCommitCoordinator = writer.useCommitCoordinator
val rdd = query.execute()
val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -113,11 +109,15 @@ object DataWritingSparkTask extends Logging {
val attemptId = context.attemptNumber()
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong)
val copyIfNeeded: InternalRow => InternalRow =
if (writeTask.reuseDataObject()) identity else _.copy()

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
while (iter.hasNext) {
dataWriter.write(iter.next())
// Internally Spark reuse the same UnsafeRow instance when producing output rows, here we
// copy it to avoid troubles at data source side.

Member: nit: reuse => reuses

dataWriter.write(copyIfNeeded(iter.next()))
}

val msg = if (useCommitCoordinator) {
@@ -155,27 +155,3 @@
})
}
}

class InternalRowDataWriterFactory(
rowWriterFactory: DataWriterFactory[Row],
schema: StructType) extends DataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new InternalRowDataWriter(
rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
RowEncoder.apply(schema).resolveAndBind())
}
}

class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
extends DataWriter[InternalRow] {

override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))

override def commit(): WriterCommitMessage = rowWriter.commit()

override def abort(): Unit = rowWriter.abort()
}
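
The copy added above guards against the row-reuse behaviour described in the new comment: upstream operators hand the writer the same UnsafeRow instance on every call, so buffering references without `copy()` silently turns every buffered row into the last one. Below is a small standalone sketch (not part of the patch; the demo object and its hard-coded rows are illustrative only) that reproduces the effect with `UnsafeProjection`, which also reuses its result row:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object RowReuseDemo {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("i", IntegerType)))
    // The generated projection writes every input into one shared UnsafeRow buffer.
    val project = UnsafeProjection.create(schema)

    val rows = Seq(InternalRow(1), InternalRow(2), InternalRow(3))

    // Buffering the returned references keeps three pointers to the same object.
    val aliased = rows.map(project(_))
    println(aliased.map(_.getInt(0)))   // List(3, 3, 3)

    // Defensive copies keep each row's contents, at the cost of an allocation per row.
    val copied = rows.map(r => project(r).copy())
    println(copied.map(_.getInt(0)))    // List(1, 2, 3)
  }
}
```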
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.{Clock, Utils}

@@ -498,12 +497,7 @@ class MicroBatchExecution(
newAttributePlan.schema,
outputMode,
new DataSourceOptions(extraOptions.asJava))
if (writer.isInstanceOf[SupportsWriteInternalRow]) {
WriteToDataSourceV2(
new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
} else {
WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
}
WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}

@@ -89,8 +89,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR
start.runTimeMs,
i,
numPartitions,
perPartitionRate)
.asInstanceOf[InputPartition[InternalRow]]
perPartitionRate): InputPartition[InternalRow]

Contributor Author: This is needed to make it compile, but at least we don't need to do cast.

Contributor: Sounds good to me.

Contributor: This should be in a separate commit. I didn't notice yesterday that this is for the writer until it was linked from the other issue. I think this change needs to get in, but it should not be mixed into changes for the write path.

}.asJava
}
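
For context on the ascription discussed in the thread above: the surrounding method (not fully shown in this hunk) returns a java.util.List of InputPartition[InternalRow], and Java collections are invariant, so the Scala sequence's element type has to be widened to InputPartition[InternalRow] before `.asJava`. The ascription does that at compile time, whereas the old `asInstanceOf` was an unchecked runtime cast. A tiny sketch with hypothetical `Part`/`IntPart` types standing in for the real classes:

```scala
import scala.collection.JavaConverters._

// Hypothetical stand-ins for InputPartition[T] and a concrete partition class.
trait Part[T]
class IntPart extends Part[Int]

object AscriptionDemo {
  // Does not compile: Seq(new IntPart).asJava is a java.util.List[IntPart], and
  // java.util.List[IntPart] is not a java.util.List[Part[Int]] (Java generics are invariant).
  // val bad: java.util.List[Part[Int]] = Seq(new IntPart).asJava

  // Ascribing the element type keeps the Scala collection typed as Seq[Part[Int]],
  // so .asJava produces the expected list type with no runtime cast.
  val good: java.util.List[Part[Int]] = Seq(new IntPart: Part[Int]).asJava
}
```

In the hunk above, the same trick types the mapped sequence as Seq[InputPartition[InternalRow]] so that `.asJava` produces the list type the API expects.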

@@ -17,13 +17,10 @@

package org.apache.spark.sql.execution.streaming.continuous

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
import org.apache.spark.util.Utils

/**
Expand All @@ -47,6 +44,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
SparkEnv.get)
EpochTracker.initializeCurrentEpoch(
context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
val copyIfNeeded: InternalRow => InternalRow =
if (writeTask.reuseDataObject()) identity else _.copy()

while (!context.isInterrupted() && !context.isCompleted()) {
var dataWriter: DataWriter[InternalRow] = null
@@ -59,7 +58,7 @@
context.taskAttemptId(),
EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(dataIterator.next())
dataWriter.write(copyIfNeeded(dataIterator.next()))
}
logInfo(s"Writer for partition ${context.partitionId()} " +
s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous

import scala.util.control.NonFatal

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.util.Utils

/**
* The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
override def output: Seq[Attribute] = Nil

override protected def doExecute(): RDD[InternalRow] = {
val writerFactory = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val writerFactory = writer.createWriterFactory()
val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)

logInfo(s"Start processing data source writer: $writer. " +
@@ -17,10 +17,10 @@

package org.apache.spark.sql.execution.streaming.sources

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
assert(SparkSession.getActiveSession.isDefined)
protected val spark = SparkSession.getActiveSession.get

def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
// We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
println(printMessage)
println("-------------------------------------------")
// scalastyle:off println
spark
.createDataFrame(rows.toList.asJava, schema)
Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
.show(numRowsToShow, isTruncated)
}

@@ -17,13 +17,13 @@

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.python.PythonForeachWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -46,11 +46,11 @@ case class ForeachWriterProvider[T](
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
new StreamWriter with SupportsWriteInternalRow {
new StreamWriter {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
override def createWriterFactory(): DataWriterFactory[InternalRow] = {
val rowConverter: InternalRow => T = converter match {
case Left(enc) =>
val boundEnc = enc.resolveAndBind(
@@ -17,9 +17,8 @@

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

/**
@@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr

override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
}

class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
extends DataSourceWriter with SupportsWriteInternalRow {
override def commit(messages: Array[WriterCommitMessage]): Unit = {
writer.commit(batchId, messages)
}

override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => throw new IllegalStateException(
"InternalRowMicroBatchWriter should only be created with base writer support")
}
override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory()
}