@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization


/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
@@ -70,13 +69,16 @@ object OffsetSeq {
* bound the lateness of data that will be processed. Time unit: milliseconds
* @param batchTimestampMs: The current batch processing timestamp.
* Time unit: milliseconds
* @param conf: Additional confs to be persisted across batches, e.g. number of shuffle partitions.
*/
case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) {
case class OffsetSeqMetadata(
batchWatermarkMs: Long = 0,
batchTimestampMs: Long = 0,
conf: Map[String, String] = Map.empty) {
def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
}

object OffsetSeqMetadata {
private implicit val format = Serialization.formats(NoTypeHints)
def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
}
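
As a side illustration (not part of this diff, and relying on json4s filling missing fields with the declared defaults, which the OffsetSeqLogSuite changes below verify), a standalone sketch of how the extended metadata round-trips through JSON, and how a pre-patch entry without the new conf field still deserializes:

    import org.json4s.NoTypeHints
    import org.json4s.jackson.Serialization

    // Standalone copy of the case class above, for illustration only.
    case class OffsetSeqMetadata(
        batchWatermarkMs: Long = 0,
        batchTimestampMs: Long = 0,
        conf: Map[String, String] = Map.empty)

    object OffsetSeqMetadataJsonDemo {
      private implicit val format = Serialization.formats(NoTypeHints)

      def main(args: Array[String]): Unit = {
        // New-style entry: the shuffle partition count is persisted under `conf`.
        val withConf = OffsetSeqMetadata(0, 1489180207737L,
          Map("spark.sql.shuffle.partitions" -> "10"))
        val json = Serialization.write(withConf)
        assert(Serialization.read[OffsetSeqMetadata](json) == withConf)

        // Old-style (Spark 2.1) entry: no `conf` field, so it deserializes to an empty map,
        // and StreamExecution falls back to the session's current conf value.
        val legacy = Serialization.read[OffsetSeqMetadata](
          """{"batchWatermarkMs":0,"batchTimestampMs":1489180207737}""")
        assert(legacy.conf.isEmpty)
      }
    }

The second read mirrors the v2.1-format offset log entries added as test resources further down, which is exactly what the backward-compatibility fallback in StreamExecution has to handle.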

@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}

@@ -117,7 +118,9 @@ class StreamExecution(
}

/** Metadata associated with the offset seq of a batch in the query. */
protected var offsetSeqMetadata = OffsetSeqMetadata()
protected var offsetSeqMetadata =
OffsetSeqMetadata(conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString))

override val id: UUID = UUID.fromString(streamMetadata.id)

@@ -256,6 +259,10 @@
updateStatusMessage("Initializing sources")
// force initialization of the logical plan so that the sources can be created
logicalPlan

// Isolated spark session to run the batches with.
val sparkSessionToRunBatches = sparkSession.cloneSession()

if (state.compareAndSet(INITIALIZING, ACTIVE)) {
// Unblock `awaitInitialization`
initializationLatch.countDown()
@@ -276,7 +283,7 @@
if (dataAvailable) {
currentStatus = currentStatus.copy(isDataAvailable = true)
updateStatusMessage("Processing new data")
runBatch()
runBatch(sparkSessionToRunBatches)
}
}

@@ -387,7 +394,29 @@
logInfo(s"Resuming streaming query, starting with batch $batchId")
currentBatchId = batchId
availableOffsets = nextOffsets.toStreamProgress(sources)
offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata())

// initialize metadata
val shufflePartitionsSparkSession: Int = sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)
offsetSeqMetadata = {
if (nextOffsets.metadata.isEmpty) {
OffsetSeqMetadata(
batchWatermarkMs = 0,
batchTimestampMs = 0,
conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsSparkSession.toString))
} else {
val metadata = nextOffsets.metadata.get
val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, {
// For backward compatibility, if # partitions was not recorded in the offset log,
// then ensure it is not missing. The new value is picked up from the conf.
logWarning("Number of shuffle partitions from previous run not found in checkpoint. "
+ s"Using the value from the conf, $shufflePartitionsSparkSession partitions.")
shufflePartitionsSparkSession
})
OffsetSeqMetadata(metadata.batchWatermarkMs, metadata.batchTimestampMs,
metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString))
}
}

logDebug(s"Found possibly unprocessed offsets $availableOffsets " +
s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}")

@@ -444,25 +473,27 @@
}
}
if (hasNewData) {
// Current batch timestamp in milliseconds
offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis()
var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs
// Update the eventTime watermark if we find one in the plan.
if (lastExecution != null) {
lastExecution.executedPlan.collect {
case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
logDebug(s"Observed event time stats: ${e.eventTimeStats.value}")
e.eventTimeStats.value.max - e.delayMs
}.headOption.foreach { newWatermarkMs =>
if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) {
if (newWatermarkMs > batchWatermarkMs) {
logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
offsetSeqMetadata.batchWatermarkMs = newWatermarkMs
batchWatermarkMs = newWatermarkMs
} else {
logDebug(
s"Event time didn't move: $newWatermarkMs < " +
s"${offsetSeqMetadata.batchWatermarkMs}")
s"$batchWatermarkMs")
}
}
}
offsetSeqMetadata = offsetSeqMetadata.copy(
batchWatermarkMs = batchWatermarkMs,
batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds
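
(Illustrative only, not part of the diff.) A simplified restatement of the watermark rule the block above implements: the candidate watermark is the max observed event time minus the allowed delay, and it is accepted only if it moves the watermark forward.

    object WatermarkRuleDemo {
      // The watermark never moves backwards.
      def advanceWatermark(currentWatermarkMs: Long, maxEventTimeMs: Long, delayMs: Long): Long =
        math.max(currentWatermarkMs, maxEventTimeMs - delayMs)

      def main(args: Array[String]): Unit = {
        // With a 2 s delay and a max observed event time of 5 s, the watermark advances to 3 s ...
        assert(advanceWatermark(currentWatermarkMs = 1000L, maxEventTimeMs = 5000L, delayMs = 2000L) == 3000L)
        // ... but a smaller candidate never pulls an already-advanced watermark back.
        assert(advanceWatermark(currentWatermarkMs = 4000L, maxEventTimeMs = 5000L, delayMs = 2000L) == 4000L)
      }
    }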

updateStatusMessage("Writing offsets to log")
reportTimeTaken("walCommit") {
@@ -505,8 +536,9 @@

/**
* Processes any data available between `availableOffsets` and `committedOffsets`.
* @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
*/
private def runBatch(): Unit = {
private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
// Request unprocessed data from all sources.
newData = reportTimeTaken("getBatch") {
availableOffsets.flatMap {
@@ -549,9 +581,15 @@
cd.dataType, cd.timeZoneId)
}

// Reset confs to disallow change in number of partitions

Member: Why do we need to set the confs for every batch? You can set them after recovering offsetSeqMetadata.

Contributor Author: Good point, changed.

sparkSessionToRunBatch.conf.set(
SQLConf.SHUFFLE_PARTITIONS.key,
offsetSeqMetadata.conf(SQLConf.SHUFFLE_PARTITIONS.key))
sparkSessionToRunBatch.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")

reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSession,
sparkSessionToRunBatch,
triggerLogicalPlan,
outputMode,
checkpointFile("state"),
@@ -561,7 +599,7 @@
}

val nextBatch =
new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema))
new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))

reportTimeTaken("addBatch") {
sink.addBatch(currentBatchId, nextBatch)
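
For background on the isolated session used by this method: cloneSession() is internal to org.apache.spark.sql, so the sketch below (illustrative only; hypothetical app name) uses the public newSession() just to show the property being relied on: runtime SQL conf changes such as spark.sql.shuffle.partitions and spark.sql.adaptive.enabled made on one session do not leak into another session sharing the same SparkContext.

    import org.apache.spark.sql.SparkSession

    object SessionConfIsolationDemo {
      def main(args: Array[String]): Unit = {
        val original = SparkSession.builder()
          .master("local[2]")
          .appName("session-conf-isolation")  // hypothetical
          .getOrCreate()
        original.conf.set("spark.sql.shuffle.partitions", "5")

        // A second session shares the SparkContext but has its own runtime SQL conf.
        val isolated = original.newSession()
        isolated.conf.set("spark.sql.shuffle.partitions", "10")
        isolated.conf.set("spark.sql.adaptive.enabled", "false")

        // Changes made on the isolated session are not visible on the original one.
        assert(original.conf.get("spark.sql.shuffle.partitions") == "5")
        assert(isolated.conf.get("spark.sql.shuffle.partitions") == "10")

        original.stop()
      }
    }

This isolation is what lets the streaming batches pin the shuffle partition count and disable adaptive execution without affecting other queries running on the original session.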
@@ -0,0 +1 @@
{"id":"dddc5e7f-1e71-454c-8362-de184444fb5a"}
@@ -0,0 +1,3 @@
v1
{"batchWatermarkMs":0,"batchTimestampMs":1489180207737}
0
@@ -0,0 +1,3 @@
v1
{"batchWatermarkMs":0,"batchTimestampMs":1489180209261}
2
Binary files not shown (20 files).
@@ -21,6 +21,7 @@ import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
@@ -29,12 +30,32 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
case class StringOffset(override val json: String) extends Offset

test("OffsetSeqMetadata - deserialization") {
assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}"""))
assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(
OffsetSeqMetadata(1, 2) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
val key = SQLConf.SHUFFLE_PARTITIONS.key

def getConfWith(shufflePartitions: Int): Map[String, String] = {
Map(key -> shufflePartitions.toString)
}

// None set
assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}"""))

// One set
assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) ===
OffsetSeqMetadata(s"""{"conf": {"$key":2}}"""))

// Two set
assert(OffsetSeqMetadata(1, 2, Map.empty) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}"""))
assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}"""))

// All set
assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}"""))

Member: nit: Could you add a test to verify that unknown fields don't break deserialization? Such as

    assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) ===
      OffsetSeqMetadata(
        s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1"""))

Contributor Author: Added.

}

test("OffsetSeqLog - serialization - deserialization") {
@@ -17,17 +17,20 @@

package org.apache.spark.sql.streaming

import java.io.{InterruptedIOException, IOException}
import java.io.{File, InterruptedIOException, IOException}
import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}

import scala.reflect.ClassTag
import scala.util.control.ControlThrowable

import org.apache.commons.io.FileUtils

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

@@ -389,6 +392,102 @@ class StreamSuite extends StreamTest {
query.stop()
assert(query.exception.isEmpty)
}

test("SPARK-19873: streaming aggregation with change in number of partitions") {
val inputData = MemoryStream[(Int, Int)]
val agg = inputData.toDS().groupBy("_1").count()

testStream(agg, OutputMode.Complete())(
AddData(inputData, (1, 0), (2, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")),
CheckAnswer((1, 1), (2, 1)),
StopStream,
AddData(inputData, (3, 0), (2, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")),
CheckAnswer((1, 1), (2, 2), (3, 1)),
StopStream,
AddData(inputData, (3, 0), (1, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")),
CheckAnswer((1, 2), (2, 2), (3, 2)))
}

test("recover from a Spark v2.1 checkpoint") {
var inputData: MemoryStream[Int] = null
var query: DataStreamWriter[Row] = null

def prepareMemoryStream(): Unit = {
inputData = MemoryStream[Int]
inputData.addData(1, 2, 3, 4)
inputData.addData(3, 4, 5, 6)
inputData.addData(5, 6, 7, 8)

query = inputData
.toDF()
.groupBy($"value")
.agg(count("*"))
.writeStream
.outputMode("complete")
.format("memory")
}

// Get an existing checkpoint generated by Spark v2.1.
// v2.1 does not record # shuffle partitions in the offset metadata.
val resourceUri =

Contributor: Can you add a comment saying that the query is started with existing checkpoints generated by 2.1, which do not have shuffle partitions recorded?

Contributor Author: Added more comments.

this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI
val checkpointDir = new File(resourceUri)

// 1 - Test if recovery from the checkpoint is successful.
prepareMemoryStream()
withTempDir { dir =>
// Copy the checkpoint to a temp dir to prevent changes to the original.
// Not doing this will lead to the test passing on the first run, but fail subsequent runs.
FileUtils.copyDirectory(checkpointDir, dir)

// Checkpoint data was generated by a query with 10 shuffle partitions.
// In order to test reading from the checkpoint, the checkpoint must have two or more batches,
// since the last batch may be rerun.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
var streamingQuery: StreamingQuery = null
try {
streamingQuery =
query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start()
streamingQuery.processAllAvailable()
inputData.addData(9)
streamingQuery.processAllAvailable()

QueryTest.checkAnswer(spark.table("counts").toDF(),
Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
} finally {
if (streamingQuery ne null) {
streamingQuery.stop()
}
}
}
}

Contributor: You don't seem to stop the query? It would be good to put a try..finally within the withSQLConf to stop the query; otherwise it can lead to cascading failures.

Contributor Author: Added try..finally.


// 2 - Check recovery with wrong num shuffle partitions
prepareMemoryStream()
withTempDir { dir =>
FileUtils.copyDirectory(checkpointDir, dir)

// Since the configured number of partitions (15) is greater than the 10 used when the
// checkpoint state was written, and the v2.1 checkpoint does not record that original value,
// recovery runs with 15 partitions and the query should throw an exception.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
var streamingQuery: StreamingQuery = null
try {
intercept[StreamingQueryException] {

Contributor: What is the error message?

streamingQuery =
query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start()
streamingQuery.processAllAvailable()
}
} finally {
if (streamingQuery ne null) {
streamingQuery.stop()
}
}
}
}
}
}

abstract class FakeSource extends StreamSourceProvider {
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration._

import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter

@@ -370,21 +371,22 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
.option("checkpointLocation", checkpointLocationURI.toString)
.trigger(ProcessingTime(10.seconds))
.start()
q.processAllAvailable()
q.stop()

verify(LastOptions.mockStreamSourceProvider).createSource(
spark.sqlContext,
s"$checkpointLocationURI/sources/0",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
any(),
meq(s"$checkpointLocationURI/sources/0"),
meq(None),
meq("org.apache.spark.sql.streaming.test"),
meq(Map.empty))

verify(LastOptions.mockStreamSourceProvider).createSource(
spark.sqlContext,
s"$checkpointLocationURI/sources/1",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
any(),
meq(s"$checkpointLocationURI/sources/1"),
meq(None),
meq("org.apache.spark.sql.streaming.test"),
meq(Map.empty))
}
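
A note on the matcher change above: since the batches now run on a cloned session, the SQLContext handed to createSource presumably no longer equals spark.sqlContext, so it is matched with any(); Mockito then requires every argument of a verified call to be a matcher, which is why the remaining literals are wrapped in meq(...). A minimal standalone sketch (hypothetical Greeter trait; same org.mockito.Matchers import as this file):

    import org.mockito.Matchers.{any, eq => meq}
    import org.mockito.Mockito._

    object MatcherMixDemo {
      // Hypothetical interface, used only to demonstrate the matcher rule.
      trait Greeter {
        def greet(prefix: String, name: String): String
      }

      def main(args: Array[String]): Unit = {
        val greeter = mock(classOf[Greeter])
        greeter.greet("Hello", "world")

        // Once one argument uses a matcher (any()), Mockito requires all of them to be
        // matchers, so the literal is wrapped in meq(...) instead of being passed directly.
        verify(greeter).greet(any(), meq("world"))
      }
    }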

private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath