From 852c6f332bd8f7264cd9c3aae6325e3c84c80ff5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Aug 2018 00:48:20 +0800 Subject: [PATCH 1/3] use InternalRow in DataSourceWriter --- .../sql/kafka010/KafkaStreamWriter.scala | 4 +- .../sources/v2/writer/DataSourceWriter.java | 4 +- .../sql/sources/v2/writer/DataWriter.java | 4 +- .../v2/writer/SupportsWriteInternalRow.java | 41 ------------------- .../datasources/v2/WriteToDataSourceV2.scala | 6 +-- .../streaming/MicroBatchExecution.scala | 10 +---- .../ContinuousRateStreamSource.scala | 3 +- .../WriteToContinuousDataSourceExec.scala | 6 +-- .../streaming/sources/ConsoleWriter.scala | 11 +++-- .../sources/ForeachWriterProvider.scala | 10 ++--- .../streaming/sources/MicroBatchWriter.scala | 21 +--------- .../sources/PackedRowWriterFactory.scala | 14 +++---- .../streaming/sources/memoryV2.scala | 29 ++++++++----- .../streaming/MemorySinkV2Suite.scala | 36 ++++++++-------- .../sources/v2/SimpleWritableDataSource.scala | 20 ++++----- 15 files changed, 74 insertions(+), 145 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 32923dc9f5a6..5f0802b46603 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage */ class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamWriter with SupportsWriteInternalRow { + extends StreamWriter { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + override def createWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7eedc85a5d6f..385fc294fea8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -18,8 +18,8 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; @@ -61,7 +61,7 @@ public interface DataSourceWriter { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. 
*/ - DataWriterFactory createWriterFactory(); + DataWriterFactory createWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 1626c0013e4e..27dc5ea224fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -53,9 +53,7 @@ * successfully, and have a way to revert committed data writers without the commit message, because * Spark only accepts the commit message that arrives first and ignore others. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers - * that mix in {@link SupportsWriteInternalRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ @InterfaceStability.Evolving public interface DataWriter { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java deleted file mode 100644 index d2cf7e01c08c..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.InternalRow; - -/** - * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this - * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. - * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get - * changed in the future Spark versions. 
- */ - -@InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceWriter { - - @Override - default DataWriterFactory createWriterFactory() { - throw new IllegalStateException( - "createWriterFactory should not be called with SupportsWriteInternalRow."); - } - - DataWriterFactory createInternalRowWriterFactory(); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index b1148c0f62f7..1d958d682a54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - + val writeTask = writer.createWriterFactory() val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index abb807def623..c759f5be8ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -498,12 +497,7 @@ class MicroBatchExecution( newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - if (writer.isInstanceOf[SupportsWriteInternalRow]) { - WriteToDataSourceV2( - new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) - } else { - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) - } + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 
551e07c3db86..c65f0636db0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -89,8 +89,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR start.runTimeMs, i, numPartitions, - perPartitionRate) - .asInstanceOf[InputPartition[InternalRow]] + perPartitionRate): InputPartition[InternalRow] }.asJava } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index e0af3a2f1b85..facfa34f113a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -41,11 +41,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - + val writerFactory = writer.createWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d276403190b3..fd45ba509091 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources -import scala.collection.JavaConverters._ - import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 @@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark - .createDataFrame(rows.toList.asJava, schema) + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) .show(numRowsToShow, isTruncated) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index bc9b6d93ce7d..e8ce21cc1204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,11 +46,11 @@ case class ForeachWriterProvider[T]( schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new StreamWriter with SupportsWriteInternalRow { + new StreamWriter { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index 56f7ff25cbed..d023a35ea20b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** @@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() -} - -class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceWriter with SupportsWriteInternalRow { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - 
override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = - writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => throw new IllegalStateException( - "InternalRowMicroBatchWriter should only be created with base writer support") - } + override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index b501d90c81f0..613f58c70a5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** @@ -30,11 +30,11 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[Row] { +case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { + epochId: Long): DataWriter[InternalRow] = { new PackedRowDataWriter() } } @@ -43,15 +43,15 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] { * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most * recent interval. */ -case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage +case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage /** * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. 
*/ -class PackedRowDataWriter() extends DataWriter[Row] with Logging { - private val data = mutable.Buffer[Row]() +class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging { + private val data = mutable.Buffer[InternalRow]() - override def write(row: Row): Unit = data.append(row) + override def write(row: InternalRow): Unit = data.append(row) override def commit(): PackedRowCommitMessage = { val msg = PackedRowCommitMessage(data.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f2a35a90af24..7cfdf6bbe8b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils @@ -46,7 +48,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -115,16 +117,19 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB override def toString(): String = "MemorySinkV2" } -case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} +case class MemoryWriterCommitMessage(partition: Int, data: Seq[InternalRow]) + extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) extends DataSourceWriter with Logging { + private val encoder = RowEncoder(schema).resolveAndBind() + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data + case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow) } sink.write(batchId, outputMode, newRows) } @@ -134,14 +139,16 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) extends StreamWriter { + private val encoder = RowEncoder(schema).resolveAndBind() + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data + case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow) } sink.write(epochId, outputMode, newRows) } @@ -151,21 +158,21 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } } -case class MemoryWriterFactory(outputMode: 
OutputMode) extends DataWriterFactory[Row] { +case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { + epochId: Long): DataWriter[InternalRow] = { new MemoryDataWriter(partitionId, outputMode) } } class MemoryDataWriter(partition: Int, outputMode: OutputMode) - extends DataWriter[Row] with Logging { + extends DataWriter[InternalRow] with Logging { - private val data = mutable.Buffer[Row]() + private val data = mutable.Buffer[InternalRow]() - override def write(row: Row): Unit = { + override def write(row: InternalRow): Unit = { data.append(row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b565..3a4dbe12d0bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 val writer = new MemoryDataWriter(partition, OutputMode.Append()) - writer.write(Row(1)) - writer.write(Row(2)) - writer.write(Row(44)) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) val msg = writer.commit() assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) assert(msg.partition == partition) @@ -40,19 +41,19 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) writer.commit(0, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + MemoryWriterCommitMessage(0, Seq(InternalRow(1), InternalRow(2))), + MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))), + MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) writer.commit(19, Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) + MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))), + MemoryWriterCommitMessage(0, Seq(InternalRow(33))) )) assert(sink.latestBatchId.contains(19)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) @@ -62,18 +63,19 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + val schema = new StructType().add("i", "int") + new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, 
Seq(Row(6), Row(7))) + MemoryWriterCommitMessage(0, Seq(InternalRow(1), InternalRow(2))), + MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))), + MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) + MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))), + MemoryWriterCommitMessage(0, Seq(InternalRow(33))) )) assert(sink.latestBatchId.contains(19)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 183d0399d3bc..32202b6bc270 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ @@ -65,7 +65,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[Row] = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } @@ -98,13 +98,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } class InternalRowWriter(jobId: String, path: String, conf: Configuration) - extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + extends Writer(jobId, path, conf) { - override def createWriterFactory(): DataWriterFactory[Row] = { - throw new IllegalArgumentException("not expected!") - } - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } } @@ -205,12 +201,12 @@ private[v2] object SimpleCounter { } class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[Row] { + extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) @@ -218,11 +214,11 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable } } -class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { +class SimpleCSVDataWriter(fs: FileSystem, 
file: Path) extends DataWriter[InternalRow] { private val out = fs.create(file) - override def write(record: Row): Unit = { + override def write(record: InternalRow): Unit = { out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") } From 1e0cb908d9e17117690405837a66805de584b34f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Aug 2018 11:16:02 +0800 Subject: [PATCH 2/3] more work --- .../sql/kafka010/KafkaStreamWriter.scala | 4 ++ .../sources/v2/writer/DataWriterFactory.java | 11 ++++ .../datasources/v2/WriteToDataSourceV2.scala | 30 ++------- .../continuous/ContinuousWriteRDD.scala | 9 ++- .../WriteToContinuousDataSourceExec.scala | 6 +- .../sql/sources/v2/DataSourceV2Suite.scala | 8 --- .../sources/v2/SimpleWritableDataSource.scala | 64 ++----------------- 7 files changed, 30 insertions(+), 102 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 5f0802b46603..1aa73af888de 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -71,6 +71,10 @@ case class KafkaStreamWriterFactory( epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } + + // `KafkaRowWriter` copies the input row immediately via a unsafe projection, so we can skip the + // copy at Spark side. + override def reuseDataObject(): Boolean = true } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 0932ff8f8f8a..152020dd129e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -50,4 +50,15 @@ public interface DataWriterFactory extends Serializable { * this ID will always be 0. */ DataWriter createDataWriter(int partitionId, long taskId, long epochId); + + /** + * When true, Spark will reuse the same data object instance when sending data to the data writer, + * for better performance. Data writers should carefully handle the data objects if it's reused, + * e.g. do not buffer the data objects in a list. By default it returns false for safety, data + * sources can override it if their data writers immediately write the data object to somewhere + * else like a memory buffer or disk. 
+ */ + default boolean reuseDataObject() { + return false; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 1d958d682a54..03c81682d2b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -109,11 +109,15 @@ object DataWritingSparkTask extends Logging { val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) + val copyIfNeeded: InternalRow => InternalRow = + if (writeTask.reuseDataObject()) identity else _.copy() // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { - dataWriter.write(iter.next()) + // Internally Spark reuse the same UnsafeRow instance when producing output rows, here we + // copy it to avoid troubles at data source side. + dataWriter.write(copyIfNeeded(iter.next())) } val msg = if (useCommitCoordinator) { @@ -151,27 +155,3 @@ object DataWritingSparkTask extends Logging { }) } } - -class InternalRowDataWriterFactory( - rowWriterFactory: DataWriterFactory[Row], - schema: StructType) extends DataWriterFactory[InternalRow] { - - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { - new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, taskId, epochId), - RowEncoder.apply(schema).resolveAndBind()) - } -} - -class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) - extends DataWriter[InternalRow] { - - override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) - - override def commit(): WriterCommitMessage = rowWriter.commit() - - override def abort(): Unit = rowWriter.abort() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 76f3f5baa8d5..8840dd3bb09f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} import org.apache.spark.util.Utils /** @@ -47,6 +44,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor SparkEnv.get) EpochTracker.initializeCurrentEpoch( context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + val copyIfNeeded: InternalRow => InternalRow = + if (writeTask.reuseDataObject()) identity else _.copy() while (!context.isInterrupted() && !context.isCompleted()) { 
var dataWriter: DataWriter[InternalRow] = null @@ -59,7 +58,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { - dataWriter.write(dataIterator.next()) + dataWriter.write(copyIfNeeded(dataIterator.next())) } logInfo(s"Writer for partition ${context.partitionId()} " + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index facfa34f113a..927d3a84e296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.util.Utils /** * The physical plan for writing data into a continuous processing [[StreamWriter]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index c7da13721989..2496ac7bfdce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,6 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -243,13 +242,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. 
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - - // test internal row writer - spark.range(5).select('id, -'id).write.format(cls.getName) - .option("path", path).option("internal", "true").mode("overwrite").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 32202b6bc270..e1b8e9c44d72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -67,7 +67,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[InternalRow] = { SimpleCounter.resetCounter - new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -97,14 +97,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class InternalRowWriter(jobId: String, path: String, conf: Configuration) - extends Writer(jobId, path, conf) { - - override def createWriterFactory(): DataWriterFactory[InternalRow] = { - new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) - } - } - override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration @@ -120,7 +112,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) val path = new Path(options.get("path").get()) - val internal = options.get("internal").isPresent val conf = SparkContext.getActive.get.hadoopConfiguration val fs = path.getFileSystem(conf) @@ -138,17 +129,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS fs.delete(path, true) } - Optional.of(createWriter(jobId, path, conf, internal)) - } - - private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString - if (internal) { - new InternalRowWriter(jobId, pathStr, conf) - } else { - new Writer(jobId, pathStr, conf) - } + Optional.of(new Writer(jobId, pathStr, conf)) } } @@ -200,43 +182,7 @@ private[v2] object SimpleCounter { } } -class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") - val fs = filePath.getFileSystem(conf.value) - new SimpleCSVDataWriter(fs, filePath) - } -} - -class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { - - private val out = fs.create(file) - - override def write(record: InternalRow): Unit = { - out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") - } - - 
override def commit(): WriterCommitMessage = { - out.close() - null - } - - override def abort(): Unit = { - try { - out.close() - } finally { - fs.delete(file, false) - } - } -} - -class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) +class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { override def createDataWriter( @@ -246,11 +192,11 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) - new InternalRowCSVDataWriter(fs, filePath) + new CSVDataWriter(fs, filePath) } } -class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { +class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { private val out = fs.create(file) From 86817c7ee36f1344e977bb5af14aeb56232c17d5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 3 Aug 2018 14:11:45 +0800 Subject: [PATCH 3/3] address comment --- .../sql/kafka010/KafkaStreamWriter.scala | 4 --- .../sources/v2/writer/DataWriterFactory.java | 16 +++-------- .../datasources/v2/WriteToDataSourceV2.scala | 6 +--- .../ContinuousRateStreamSource.scala | 3 +- .../continuous/ContinuousWriteRDD.scala | 5 +--- .../sources/PackedRowWriterFactory.scala | 3 +- .../streaming/sources/memoryV2.scala | 28 +++++++++---------- .../streaming/MemorySinkV2Suite.scala | 24 ++++++++-------- 8 files changed, 37 insertions(+), 52 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 1aa73af888de..5f0802b46603 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -71,10 +71,6 @@ case class KafkaStreamWriterFactory( epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } - - // `KafkaRowWriter` copies the input row immediately via a unsafe projection, so we can skip the - // copy at Spark side. - override def reuseDataObject(): Boolean = true } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 152020dd129e..3d337b6e0bdf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -33,7 +33,10 @@ public interface DataWriterFactory extends Serializable { /** - * Returns a data writer to do the actual writing work. + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a + * list. * * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. @@ -50,15 +53,4 @@ public interface DataWriterFactory extends Serializable { * this ID will always be 0. 
*/ DataWriter createDataWriter(int partitionId, long taskId, long epochId); - - /** - * When true, Spark will reuse the same data object instance when sending data to the data writer, - * for better performance. Data writers should carefully handle the data objects if it's reused, - * e.g. do not buffer the data objects in a list. By default it returns false for safety, data - * sources can override it if their data writers immediately write the data object to somewhere - * else like a memory buffer or disk. - */ - default boolean reuseDataObject() { - return false; - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 03c81682d2b4..0399970495be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -109,15 +109,11 @@ object DataWritingSparkTask extends Logging { val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) - val copyIfNeeded: InternalRow => InternalRow = - if (writeTask.reuseDataObject()) identity else _.copy() // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { - // Internally Spark reuse the same UnsafeRow instance when producing output rows, here we - // copy it to avoid troubles at data source side. - dataWriter.write(copyIfNeeded(iter.next())) + dataWriter.write(iter.next()) } val msg = if (useCommitCoordinator) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c65f0636db0e..551e07c3db86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -89,7 +89,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR start.runTimeMs, i, numPartitions, - perPartitionRate): InputPartition[InternalRow] + perPartitionRate) + .asInstanceOf[InputPartition[InternalRow]] }.asJava } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 8840dd3bb09f..967dbe24a370 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -44,9 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor SparkEnv.get) EpochTracker.initializeCurrentEpoch( context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) - val copyIfNeeded: InternalRow => InternalRow = - if (writeTask.reuseDataObject()) identity else _.copy() - while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. 
@@ -58,7 +55,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { - dataWriter.write(copyIfNeeded(dataIterator.next())) + dataWriter.write(dataIterator.next()) } logInfo(s"Writer for partition ${context.partitionId()} " + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 613f58c70a5c..f26e11d842b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -51,7 +51,8 @@ case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommit class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[InternalRow]() - override def write(row: InternalRow): Unit = data.append(row) + // Spark reuses the same `InternalRow` instance, here we copy it before buffer it. + override def write(row: InternalRow): Unit = data.append(row.copy()) override def commit(): PackedRowCommitMessage = { val msg = PackedRowCommitMessage(data.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 7cfdf6bbe8b5..afacb2f72c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -117,19 +117,17 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB override def toString(): String = "MemorySinkV2" } -case class MemoryWriterCommitMessage(partition: Int, data: Seq[InternalRow]) +case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) extends DataSourceWriter with Logging { - private val encoder = RowEncoder(schema).resolveAndBind() - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow) + case message: MemoryWriterCommitMessage => message.data } sink.write(batchId, outputMode, newRows) } @@ -142,13 +140,11 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, sc class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) extends StreamWriter { - private val encoder = RowEncoder(schema).resolveAndBind() - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data.map(encoder.fromRow) + case message: MemoryWriterCommitMessage => 
message.data } sink.write(epochId, outputMode, newRows) } @@ -158,22 +154,26 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: } } -case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[InternalRow] { +case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode) + new MemoryDataWriter(partitionId, outputMode, schema) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode) +class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) extends DataWriter[InternalRow] with Logging { - private val data = mutable.Buffer[InternalRow]() + private val data = mutable.Buffer[Row]() + + private val encoder = RowEncoder(schema).resolveAndBind() override def write(row: InternalRow): Unit = { - data.append(row) + data.append(encoder.fromRow(row)) } override def commit(): MemoryWriterCommitMessage = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 3a4dbe12d0bb..b4d9b68c7815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} @@ -27,7 +28,8 @@ import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 - val writer = new MemoryDataWriter(partition, OutputMode.Append()) + val writer = new MemoryDataWriter( + partition, OutputMode.Append(), new StructType().add("i", "int")) writer.write(InternalRow(1)) writer.write(InternalRow(2)) writer.write(InternalRow(44)) @@ -44,16 +46,16 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) writer.commit(0, Array( - MemoryWriterCommitMessage(0, Seq(InternalRow(1), InternalRow(2))), - MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))), - MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7))) + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) writer.commit(19, Array( - MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))), - MemoryWriterCommitMessage(0, Seq(InternalRow(33))) + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) )) assert(sink.latestBatchId.contains(19)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) @@ -66,16 +68,16 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { val schema = new StructType().add("i", "int") new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( Array( - MemoryWriterCommitMessage(0, 
Seq(InternalRow(1), InternalRow(2))), - MemoryWriterCommitMessage(1, Seq(InternalRow(3), InternalRow(4))), - MemoryWriterCommitMessage(2, Seq(InternalRow(6), InternalRow(7))) + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( Array( - MemoryWriterCommitMessage(3, Seq(InternalRow(11), InternalRow(22))), - MemoryWriterCommitMessage(0, Seq(InternalRow(33))) + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) )) assert(sink.latestBatchId.contains(19)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
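
Note: the sketch below is not part of the patch series; it only illustrates the writer-side contract that the patches above converge on. After these changes, a sink implements DataWriterFactory[InternalRow] directly (SupportsWriteInternalRow is gone and Spark no longer converts to Row), createDataWriter keeps the (partitionId, taskId, epochId) signature, and a writer that buffers rows must copy() each InternalRow because Spark reuses the same row instance. The BufferingWriterFactory, BufferingDataWriter, and BufferedRowsCommitMessage names are hypothetical; the API calls and the copy requirement come from the patches.

    import scala.collection.mutable

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

    // Hypothetical commit message carrying the rows buffered by one task.
    case class BufferedRowsCommitMessage(partitionId: Int, rows: Array[InternalRow])
      extends WriterCommitMessage

    // The factory is parameterized by InternalRow directly: no SupportsWriteInternalRow
    // mix-in and no Row-to-InternalRow conversion on the Spark side anymore.
    case object BufferingWriterFactory extends DataWriterFactory[InternalRow] {
      override def createDataWriter(
          partitionId: Int,
          taskId: Long,
          epochId: Long): DataWriter[InternalRow] = new BufferingDataWriter(partitionId)
    }

    class BufferingDataWriter(partitionId: Int) extends DataWriter[InternalRow] {
      private val buffer = mutable.Buffer[InternalRow]()

      // Spark may hand the writer the same row instance repeatedly, so a writer that keeps
      // references to rows must copy them before buffering (as PackedRowDataWriter now does).
      override def write(row: InternalRow): Unit = buffer.append(row.copy())

      override def commit(): WriterCommitMessage =
        BufferedRowsCommitMessage(partitionId, buffer.toArray)

      override def abort(): Unit = buffer.clear()
    }

A writer that immediately serializes each row elsewhere, like the Kafka writer, which projects the input row into its own buffer, does not need the defensive copy; the copy is only required when the writer retains references to the rows it receives.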