From ac7eb2f3cf4cca8ee5d64f90f71c6c0d14931c52 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Mon, 11 Jun 2018 17:38:38 -0700 Subject: [PATCH 01/16] Add in logic to determine the max rows a sink can have --- .../sql/execution/streaming/memory.scala | 27 +++++++++++++++++++ .../streaming/sources/memoryV2.scala | 10 +++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..9a87b7d63c884 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -228,6 +229,32 @@ trait MemorySinkBase extends BaseStreamingSink { def latestBatchId: Option[Long] } +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1L + val MAX_MEMORY_SINK_BYTES = "" + val MAX_MEMORY_SINK_BYTES_DEFAULT = -1L + + def getMaxRows(schema: StructType, options: DataSourceOptions): Option[Long] = { + val maxBytes = options.getLong(MAX_MEMORY_SINK_BYTES, MAX_MEMORY_SINK_BYTES_DEFAULT) + val maxRows = options.getLong(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + val sizePerRow = EstimationUtils.getSizePerRow(schema.toAttributes).longValue() + if (maxBytes >= 0 && maxRows >= 0) { + Some(math.min(maxRows, maxBytes / sizePerRow)) + } else if (maxBytes >= 0) { + Some(maxBytes / sizePerRow) + } else if (maxRows >= 0) { + Some(maxRows) + } else { + None + } + } +} + + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
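As a quick, hypothetical sketch of how the new helper is meant to be driven (the option map and the 1000-row limit below are illustrative only, and the option key constants are still empty placeholders at this point; they are given real names in PATCH 02/16), a caller would wrap the write options in DataSourceOptions and ask MemorySinkBase for the effective capacity:

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.sources.v2.DataSourceOptions
    import org.apache.spark.sql.types.{IntegerType, StructType}

    // Illustrative only: derive the capacity a memory sink would enforce when
    // the caller supplies a row limit through the write options.
    val schema = new StructType().add("value", IntegerType)
    val options = new DataSourceOptions(
      Map(MemorySinkBase.MAX_MEMORY_SINK_ROWS -> "1000").asJava)

    // Only the row limit is set, so the byte limit keeps its -1 default and the
    // effective capacity is Some(1000); with both limits set, the smaller of the
    // row limit and (byte limit / estimated size per row) wins.
    val capacity: Option[Long] = MemorySinkBase.getMaxRows(schema, options)
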
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..761f540d2755e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, schema, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -134,9 +134,15 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + schema: StructType, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val maxRowsInSink = MemorySinkBase.getMaxRows(schema, options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { From 8dc89cca9129b25ad8f5f4cda856e5b594f53e52 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 11:55:32 -0700 Subject: [PATCH 02/16] Make MemorySink and MemorySinkV2 respect row and byte limits --- .../sql/execution/streaming/memory.scala | 49 ++++++++++++++----- .../streaming/sources/memoryV2.scala | 27 +++++++--- .../sql/streaming/DataStreamWriter.scala | 4 +- 3 files changed, 59 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 9a87b7d63c884..7b11a4982d994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -233,19 +233,27 @@ trait MemorySinkBase extends BaseStreamingSink { * Companion object to MemorySinkBase. */ object MemorySinkBase { - val MAX_MEMORY_SINK_ROWS = "" - val MAX_MEMORY_SINK_ROWS_DEFAULT = -1L - val MAX_MEMORY_SINK_BYTES = "" + val MAX_MEMORY_SINK_ROWS = "maxMemorySinkRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + val MAX_MEMORY_SINK_BYTES = "maxMemorySinkBytes" val MAX_MEMORY_SINK_BYTES_DEFAULT = -1L - def getMaxRows(schema: StructType, options: DataSourceOptions): Option[Long] = { + /** + * Gets the max number of rows a MemorySink should store. This number is based on the lesser of + * the memory sink row limit or the memory sink byte limit, if either is set. If not, there is + * no limit. + * @param schema The row schema, for use in computing size per row. + * @param options Options for writing from which we get the max rows or bytes. + * @return The maximum number of rows a memorySink should store, or None for no limit. 
+ */ + def getMaxRows(schema: StructType, options: DataSourceOptions): Option[Int] = { val maxBytes = options.getLong(MAX_MEMORY_SINK_BYTES, MAX_MEMORY_SINK_BYTES_DEFAULT) - val maxRows = options.getLong(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) val sizePerRow = EstimationUtils.getSizePerRow(schema.toAttributes).longValue() if (maxBytes >= 0 && maxRows >= 0) { - Some(math.min(maxRows, maxBytes / sizePerRow)) + Some(math.min(maxRows, (maxBytes / sizePerRow).asInstanceOf[Int])) } else if (maxBytes >= 0) { - Some(maxBytes / sizePerRow) + Some((maxBytes / sizePerRow).asInstanceOf[Int]) } else if (maxRows >= 0) { Some(maxRows) } else { @@ -259,8 +267,8 @@ object MemorySinkBase { * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -268,6 +276,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -300,14 +314,24 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + val newRows = data.collect() + synchronized { + val rowsToAdd = + if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get - numRows) else newRows + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + val newRows = data.collect() synchronized { + val rowsToAdd = + if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get) else newRows + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -321,6 +345,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 761f540d2755e..9ae53156ff2e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,8 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row], sinkCapacity: Option[Int]) + : Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +93,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get - numRows) else newRows + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = + if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get) else newRows + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +122,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -126,7 +139,7 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, None) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -141,7 +154,7 @@ class MemoryStreamWriter( options: DataSourceOptions) extends StreamWriter { - val maxRowsInSink = MemorySinkBase.getMaxRows(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -149,7 +162,7 @@ class MemoryStreamWriter( val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..4dad8b58c4cbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } From 8ddf566259016e4ce727eabb3206fd65303c5580 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 12:20:44 -0700 Subject: [PATCH 03/16] Make tests compile --- .../spark/sql/execution/streaming/MemorySinkSuite.scala | 9 +++++---- .../sql/execution/streaming/MemorySinkV2Suite.scala | 7 ++++++- .../org/apache/spark/sql/streaming/StreamTest.scala | 4 +++- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..b90a3726a0da9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +37,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -70,7 +71,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +105,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -211,7 +212,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..a6be5c372d18b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -21,7 +21,10 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +43,9 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + var schema = new StructType().add("value", IntegerType) + val writer = + new MemoryStreamWriter(sink, schema, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..e41b4534ed51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath From d82c7d5ee84b25e968f705aded2f2c04edc5c140 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 13:26:56 -0700 Subject: [PATCH 04/16] Make microbatch memory writer work with limits --- .../sql/execution/streaming/sources/memoryV2.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 9ae53156ff2e2..67c20e14dfb3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -130,16 +130,23 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: 
MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + schema: StructType, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows, None) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { From 7fefe877b03fe4ad522275780a64425b58bf5bb0 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 13:27:03 -0700 Subject: [PATCH 05/16] Test MemorySinkV2 with limits --- .../streaming/MemorySinkV2Suite.scala | 133 +++++++++++++++++- 1 file changed, 129 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index a6be5c372d18b..013309668c695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row @@ -67,16 +69,17 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( - Array( + var schema = new StructType().add("value", IntegerType) + new MemoryWriter(sink, 0, schema, OutputMode.Append(), DataSourceOptions.empty()) + .commit(Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( - Array( + new MemoryWriter(sink, 19, schema, OutputMode.Append(), DataSourceOptions.empty()) + .commit(Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) )) @@ -85,4 +88,126 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + var schema = new StructType().add("value", IntegerType) + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val writer = new MemoryStreamWriter(sink, schema, OutputMode.Append(), options) + writer.commit(0, + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + )) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + writer.commit(19, + Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, 
Seq(Row(33))) + )) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + var schema = new StructType().add("value", IntegerType) + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, schema, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, schema, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } + + test("microbatch writer with byte limit") { + val sink = new MemorySinkV2 + var schema = new StructType().add("value", IntegerType) + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 60.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, schema, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, schema, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new 
MemoryWriter(sink, 28, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } + + test("microbatch writer with row and byte limit") { + val sink = new MemorySinkV2 + var schema = new StructType().add("value", IntegerType) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 3.toString()) + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 400.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + new MemoryWriter(sink, 25, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3)) + + optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 10.toString()) + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 36.toString()) + options = new DataSourceOptions(optionsMap.toMap.asJava) + new MemoryWriter(sink, 26, schema, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5, 6, 7)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(5, 6, 7)) + + } } From 58c5044ca2e62ca825df3a4e88c4b4f6d697461e Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 15:08:49 -0700 Subject: [PATCH 06/16] Add MemorySink test with limit --- .../execution/streaming/MemorySinkSuite.scala | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index b90a3726a0da9..193fcec4b5dab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter @@ -69,6 +70,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + 
checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) From 392f05f4c1d008493220f59ff7a4d4b948fdfc4b Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 12 Jun 2018 15:23:27 -0700 Subject: [PATCH 07/16] rename method --- .../org/apache/spark/sql/execution/streaming/memory.scala | 4 ++-- .../spark/sql/execution/streaming/sources/memoryV2.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7b11a4982d994..616fa5bc9ee38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -246,7 +246,7 @@ object MemorySinkBase { * @param options Options for writing from which we get the max rows or bytes. * @return The maximum number of rows a memorySink should store, or None for no limit. */ - def getMaxRows(schema: StructType, options: DataSourceOptions): Option[Int] = { + def getMemorySinkCapacity(schema: StructType, options: DataSourceOptions): Option[Int] = { val maxBytes = options.getLong(MAX_MEMORY_SINK_BYTES, MAX_MEMORY_SINK_BYTES_DEFAULT) val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) val sizePerRow = EstimationUtils.getSizePerRow(schema.toAttributes).longValue() @@ -280,7 +280,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo private var numRows = 0 /** The capacity in rows of this sink. */ - val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 67c20e14dfb3c..0c96b5edaedc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -138,7 +138,7 @@ class MemoryWriter( options: DataSourceOptions) extends DataSourceWriter with Logging { - val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -161,7 +161,7 @@ class MemoryStreamWriter( options: DataSourceOptions) extends StreamWriter { - val sinkCapacity: Option[Int] = MemorySinkBase.getMaxRows(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) From 9097dd52bf654d7de059a0a0eaca961bd424f3cd Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 13:36:08 -0700 Subject: [PATCH 08/16] Don't use byte limit, and log if we truncate rows --- .../sql/execution/streaming/memory.scala | 50 +++++++++---------- .../streaming/sources/memoryV2.scala | 30 +++++++---- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 616fa5bc9ee38..889201737c0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -235,30 +235,16 @@ trait MemorySinkBase extends BaseStreamingSink { object MemorySinkBase { val MAX_MEMORY_SINK_ROWS = "maxMemorySinkRows" val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 - val MAX_MEMORY_SINK_BYTES = "maxMemorySinkBytes" - val MAX_MEMORY_SINK_BYTES_DEFAULT = -1L /** - * Gets the max number of rows a MemorySink should store. This number is based on the lesser of - * the memory sink row limit or the memory sink byte limit, if either is set. If not, there is - * no limit. - * @param schema The row schema, for use in computing size per row. - * @param options Options for writing from which we get the max rows or bytes. + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit if it is set. If not, there is no limit. + * @param options Options for writing from which we get the max rows option * @return The maximum number of rows a memorySink should store, or None for no limit. 
*/ - def getMemorySinkCapacity(schema: StructType, options: DataSourceOptions): Option[Int] = { - val maxBytes = options.getLong(MAX_MEMORY_SINK_BYTES, MAX_MEMORY_SINK_BYTES_DEFAULT) + def getMemorySinkCapacity(options: DataSourceOptions): Option[Int] = { val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) - val sizePerRow = EstimationUtils.getSizePerRow(schema.toAttributes).longValue() - if (maxBytes >= 0 && maxRows >= 0) { - Some(math.min(maxRows, (maxBytes / sizePerRow).asInstanceOf[Int])) - } else if (maxBytes >= 0) { - Some((maxBytes / sizePerRow).asInstanceOf[Int]) - } else if (maxRows >= 0) { - Some(maxRows) - } else { - None - } + if (maxRows >= 0) Some(maxRows) else None } } @@ -280,7 +266,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo private var numRows = 0 /** The capacity in rows of this sink. */ - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { @@ -314,20 +300,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val newRows = data.collect() + var rowsToAdd = data.collect() synchronized { - val rowsToAdd = - if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get - numRows) else newRows + if (sinkCapacity.isDefined) { + val rowsRemaining = sinkCapacity.get - numRows + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, rowsRemaining, batchId) + } val rows = AddedData(batchId, rowsToAdd) batches += rows numRows += rowsToAdd.length } case Complete => - val newRows = data.collect() + var rowsToAdd = data.collect() synchronized { - val rowsToAdd = - if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get) else newRows + if (sinkCapacity.isDefined) { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get, batchId) + } val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows @@ -348,6 +337,15 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo numRows = 0 } + def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { + if (rows.length > maxRows) { + logWarning(s"Truncating batch $batchId to $maxRows rows") + rows.take(maxRows) + } else { + rows + } + } + override def toString(): String = "MemorySink" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 0c96b5edaedc8..b7316f9cef412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, schema, mode, options) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -94,8 +94,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB outputMode match { case Append | Update => synchronized { - val rowsToAdd = - if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get - numRows) 
else newRows + var rowsToAdd = newRows + if (sinkCapacity.isDefined) { + val rowsRemaining = sinkCapacity.get - numRows + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, rowsRemaining, batchId) + } val rows = AddedData(batchId, rowsToAdd) batches += rows numRows += rowsToAdd.length @@ -103,8 +106,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case Complete => synchronized { - val rowsToAdd = - if (sinkCapacity.isDefined) newRows.take(sinkCapacity.get) else newRows + var rowsToAdd = newRows + if (sinkCapacity.isDefined) { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get, batchId) + } val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows @@ -125,6 +130,15 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB numRows = 0 } + def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { + if (rows.length > maxRows) { + logWarning(s"Truncating batch $batchId to $maxRows rows") + rows.take(maxRows) + } else { + rows + } + } + override def toString(): String = "MemorySinkV2" } @@ -133,12 +147,11 @@ case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends Wri class MemoryWriter( sink: MemorySinkV2, batchId: Long, - schema: StructType, outputMode: OutputMode, options: DataSourceOptions) extends DataSourceWriter with Logging { - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -156,12 +169,11 @@ class MemoryWriter( class MemoryStreamWriter( val sink: MemorySinkV2, - schema: StructType, outputMode: OutputMode, options: DataSourceOptions) extends StreamWriter { - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(schema, options) + val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) From a28fb38053395c04a72b5d79f1f12a3aa5d49972 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 13:36:21 -0700 Subject: [PATCH 09/16] Update tests --- .../execution/streaming/MemorySinkSuite.scala | 26 +++++ .../streaming/MemorySinkV2Suite.scala | 110 +++++------------- 2 files changed, 53 insertions(+), 83 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 193fcec4b5dab..b2fd6ba27ebb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -164,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // 
Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 013309668c695..97f9748a92545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -45,9 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) - val writer = - new MemoryStreamWriter(sink, schema, OutputMode.Append(), DataSourceOptions.empty()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -69,8 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) - new MemoryWriter(sink, 0, schema, OutputMode.Append(), DataSourceOptions.empty()) + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()) .commit(Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -78,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, schema, OutputMode.Append(), DataSourceOptions.empty()) + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()) .commit(Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -91,123 +88,70 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer with row limit") { val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) val optionsMap = new scala.collection.mutable.HashMap[String, String] optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) val options = new DataSourceOptions(optionsMap.toMap.asJava) - val writer = new MemoryStreamWriter(sink, schema, OutputMode.Append(), options) - writer.commit(0, - Array( + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, - Array( + appendWriter.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) + MemoryWriterCommitMessage(0, Seq(Row(33))))) assert(sink.latestBatchId.contains(19)) 
assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) } test("microbatch writer with row limit") { val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) val optionsMap = new scala.collection.mutable.HashMap[String, String] optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) val options = new DataSourceOptions(optionsMap.toMap.asJava) - new MemoryWriter(sink, 25, schema, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) - assert(sink.latestBatchId.contains(25)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - new MemoryWriter(sink, 26, schema, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), - MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) - assert(sink.latestBatchId.contains(26)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) - - new MemoryWriter(sink, 27, schema, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), - MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) - assert(sink.latestBatchId.contains(27)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - new MemoryWriter(sink, 28, schema, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), - MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) - assert(sink.latestBatchId.contains(28)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - } - - test("microbatch writer with byte limit") { - val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 60.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - - new MemoryWriter(sink, 25, schema, OutputMode.Append(), options).commit(Array( + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) assert(sink.latestBatchId.contains(25)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - new 
MemoryWriter(sink, 26, schema, OutputMode.Append(), options).commit(Array( + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) assert(sink.latestBatchId.contains(26)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) - new MemoryWriter(sink, 27, schema, OutputMode.Complete(), options).commit(Array( + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) assert(sink.latestBatchId.contains(27)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - new MemoryWriter(sink, 28, schema, OutputMode.Complete(), options).commit(Array( + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) assert(sink.latestBatchId.contains(28)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) } - - test("microbatch writer with row and byte limit") { - val sink = new MemorySinkV2 - var schema = new StructType().add("value", IntegerType) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 3.toString()) - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 400.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - new MemoryWriter(sink, 25, schema, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) - assert(sink.latestBatchId.contains(25)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3)) - - optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 10.toString()) - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_BYTES, 36.toString()) - options = new DataSourceOptions(optionsMap.toMap.asJava) - new MemoryWriter(sink, 26, schema, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), - MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) - assert(sink.latestBatchId.contains(26)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5, 6, 7)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(5, 6, 7)) - - } } From f981cb818ffc95ddce2b59fcd64142615037b6a3 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 13:50:43 -0700 Subject: [PATCH 10/16] minor refactor --- .../org/apache/spark/sql/execution/streaming/memory.scala | 3 +-- .../spark/sql/execution/streaming/sources/memoryV2.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 889201737c0e1..5f8cee0924ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -303,8 +303,7 @@ class MemorySink(val schema: 
StructType, outputMode: OutputMode, options: DataSo var rowsToAdd = data.collect() synchronized { if (sinkCapacity.isDefined) { - val rowsRemaining = sinkCapacity.get - numRows - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, rowsRemaining, batchId) + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get - numRows, batchId) } val rows = AddedData(batchId, rowsToAdd) batches += rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index b7316f9cef412..1d0d335f410f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -96,8 +96,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB synchronized { var rowsToAdd = newRows if (sinkCapacity.isDefined) { - val rowsRemaining = sinkCapacity.get - numRows - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, rowsRemaining, batchId) + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get - numRows, batchId) } val rows = AddedData(batchId, rowsToAdd) batches += rows From 74d5b6b4203aadcbf7339304b63d69382da7bf57 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 15:00:49 -0700 Subject: [PATCH 11/16] fixed indenting a bit --- .../spark/sql/execution/streaming/MemorySinkV2Suite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 97f9748a92545..e539510e15755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -67,16 +67,16 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()) - .commit(Array( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( + Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()) - .commit(Array( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( + Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) )) From 4ab9bdaea895f6d0c76ee9ddd44c131f499eaec5 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 15:06:50 -0700 Subject: [PATCH 12/16] make helper methods private --- .../scala/org/apache/spark/sql/execution/streaming/memory.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/memoryV2.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 5f8cee0924ca7..ee52ce23868f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -336,7 +336,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo numRows = 0 } - def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { + private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { if (rows.length > maxRows) { logWarning(s"Truncating batch $batchId to $maxRows rows") rows.take(maxRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 1d0d335f410f8..189da639a8a62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -129,7 +129,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB numRows = 0 } - def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { + private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { if (rows.length > maxRows) { logWarning(s"Truncating batch $batchId to $maxRows rows") rows.take(maxRows) From b2ef59c40e58cdd6efdb0f5414f16ac5358bc99a Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 13 Jun 2018 16:45:32 -0700 Subject: [PATCH 13/16] Add additional safeguard that we don't call take() on a negative number --- .../scala/org/apache/spark/sql/execution/streaming/memory.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/memoryV2.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index ee52ce23868f4..da6ef585b32c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -337,7 +337,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo } private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { - if (rows.length > maxRows) { + if (rows.length > maxRows && maxRows >= 0) { logWarning(s"Truncating batch $batchId to $maxRows rows") rows.take(maxRows) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 189da639a8a62..9a7084722c280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -130,7 +130,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB } private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { - if (rows.length > maxRows) { + if (rows.length > maxRows && maxRows >= 0) { logWarning(s"Truncating batch $batchId to $maxRows rows") rows.take(maxRows) } else { From 25d6de1db8223975ebd9b69c7ca77c26e3d8674c Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Thu, 14 Jun 2018 11:50:36 -0700 Subject: [PATCH 14/16] Move truncate method to parent class --- .../sql/execution/streaming/memory.scala | 27 ++++++++++++------- 
.../streaming/sources/memoryV2.scala | 9 ------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index da6ef585b32c9..83356d08a9750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -222,11 +222,27 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param maxRows Number of rows to truncate to keep. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { + if (rows.length > maxRows && maxRows >= 0) { + logWarning(s"Truncating batch $batchId to $maxRows rows") + rows.take(maxRows) + } else { + rows + } + } } /** @@ -336,15 +352,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo numRows = 0 } - private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { - if (rows.length > maxRows && maxRows >= 0) { - logWarning(s"Truncating batch $batchId to $maxRows rows") - rows.take(maxRows) - } else { - rows - } - } - override def toString(): String = "MemorySink" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 9a7084722c280..c49a6e423547a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -129,15 +129,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB numRows = 0 } - private def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { - if (rows.length > maxRows && maxRows >= 0) { - logWarning(s"Truncating batch $batchId to $maxRows rows") - rows.take(maxRows) - } else { - rows - } - } - override def toString(): String = "MemorySinkV2" } From e5b6175f0b638dc7235e4d3b610284a761e01480 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Thu, 14 Jun 2018 14:35:22 -0700 Subject: [PATCH 15/16] Address Burak's comments --- .../sql/execution/streaming/memory.scala | 33 ++++++++++--------- .../streaming/sources/memoryV2.scala | 22 ++++++------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 83356d08a9750..f9369d8858f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -231,14 +231,19 @@ trait MemorySinkBase extends BaseStreamingSink with Logging { /** * Truncates the given rows to return 
at most maxRows rows. * @param rows The data that may need to be truncated. - * @param maxRows Number of rows to truncate to keep. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. * @param batchId The ID of the batch that sent these rows, for logging purposes. * @return Truncated rows. */ - protected def truncateRowsIfNeeded(rows: Array[Row], maxRows: Int, batchId: Long): Array[Row] = { - if (rows.length > maxRows && maxRows >= 0) { - logWarning(s"Truncating batch $batchId to $maxRows rows") - rows.take(maxRows) + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) } else { rows } @@ -249,7 +254,7 @@ trait MemorySinkBase extends BaseStreamingSink with Logging { * Companion object to MemorySinkBase. */ object MemorySinkBase { - val MAX_MEMORY_SINK_ROWS = "maxMemorySinkRows" + val MAX_MEMORY_SINK_ROWS = "maxRows" val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 /** @@ -258,13 +263,12 @@ object MemorySinkBase { * @param options Options for writing from which we get the max rows option * @return The maximum number of rows a memorySink should store, or None for no limit. */ - def getMemorySinkCapacity(options: DataSourceOptions): Option[Int] = { + def getMemorySinkCapacity(options: DataSourceOptions): Int = { val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) - if (maxRows >= 0) Some(maxRows) else None + if (maxRows >= 0) maxRows else Int.MaxValue - 10 } } - /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. @@ -282,7 +286,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo private var numRows = 0 /** The capacity in rows of this sink. */ - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { @@ -318,9 +322,8 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo case Append | Update => var rowsToAdd = data.collect() synchronized { - if (sinkCapacity.isDefined) { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get - numRows, batchId) - } + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) val rows = AddedData(batchId, rowsToAdd) batches += rows numRows += rowsToAdd.length @@ -329,9 +332,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo case Complete => var rowsToAdd = data.collect() synchronized { - if (sinkCapacity.isDefined) { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get, batchId) - } + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c49a6e423547a..47b482007822d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -84,8 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row], sinkCapacity: Option[Int]) - : Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -94,10 +97,8 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB outputMode match { case Append | Update => synchronized { - var rowsToAdd = newRows - if (sinkCapacity.isDefined) { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get - numRows, batchId) - } + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) val rows = AddedData(batchId, rowsToAdd) batches += rows numRows += rowsToAdd.length @@ -105,10 +106,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case Complete => synchronized { - var rowsToAdd = newRows - if (sinkCapacity.isDefined) { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity.get, batchId) - } + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows @@ -141,7 +139,7 @@ class MemoryWriter( options: DataSourceOptions) extends DataSourceWriter with Logging { - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -163,7 +161,7 @@ class MemoryStreamWriter( options: DataSourceOptions) extends StreamWriter { - val sinkCapacity: Option[Int] = MemorySinkBase.getMemorySinkCapacity(options) + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) From 0402b6042b6f0b773a17d2bc6d30eda1c46dd731 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Fri, 15 Jun 2018 09:59:35 -0700 Subject: [PATCH 16/16] fix 
documentation --- .../org/apache/spark/sql/execution/streaming/memory.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index f9369d8858f32..7fa13c4aa2c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -259,9 +259,10 @@ object MemorySinkBase { /** * Gets the max number of rows a MemorySink should store. This number is based on the memory - * sink row limit if it is set. If not, there is no limit. + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. * @param options Options for writing from which we get the max rows option - * @return The maximum number of rows a memorySink should store, or None for no limit. + * @return The maximum number of rows a memorySink should store. */ def getMemorySinkCapacity(options: DataSourceOptions): Int = { val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
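For readers following the series, here is a minimal end-to-end sketch (not part of any patch above) of how the new `maxRows` option is meant to be exercised. It assumes a running `SparkSession` named `spark` (as in `spark-shell`) and relies on the option being forwarded from `DataStreamWriter` to the sink, as wired up earlier in the series; the query name, rate-source settings, and sleep duration are illustrative only:

    // Sketch only: cap the memory sink at 100 rows using the option key
    // defined as MemorySinkBase.MAX_MEMORY_SINK_ROWS ("maxRows").
    val query = spark.readStream
      .format("rate")                  // built-in test source emitting (timestamp, value) rows
      .option("rowsPerSecond", "50")
      .load()
      .writeStream
      .format("memory")                // resolves to MemorySink / MemorySinkV2
      .queryName("limited")            // results are exposed as the temp view "limited"
      .option("maxRows", "100")        // the row-limit option added by this series
      .outputMode("append")
      .start()

    Thread.sleep(5000)                 // let a few batches commit (~250 input rows)
    // The sink keeps at most 100 rows in total; excess rows in a batch are
    // dropped and a truncation warning is logged by truncateRowsIfNeeded.
    assert(spark.table("limited").count() <= 100)
    query.stop()

When the option is left unset, `getMemorySinkCapacity` falls back to a very large capacity, so existing tests keep their current behavior and extreme data volumes truncate rather than driving the sink out of memory.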