@@ -371,7 +371,8 @@ case class StateSourceOptions(
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean,
operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -576,29 +577,46 @@ object StateSourceOptions extends DataSourceOptions {
batchId.get
}

val operatorStateUniqueIds = getOperatorStateUniqueIds(
val endBatchId = if (readChangeFeedOptions.isDefined) {
readChangeFeedOptions.get.changeEndBatchId
} else {
batchId.get
}
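// e.g. a plain batch read of batch 10 yields startBatchId == endBatchId == 10,
// while a readChangeFeed read from batch 5 to batch 10 yields 5 and 10.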

val startOperatorStateUniqueIds = getOperatorStateUniqueIds(
sparkSession,
startBatchId,
operatorId,
resolvedCpLocation)

if (operatorStateUniqueIds.isDefined) {
val endOperatorStateUniqueIds = if (startBatchId == endBatchId) {
startOperatorStateUniqueIds
} else {
getOperatorStateUniqueIds(
sparkSession,
endBatchId,
operatorId,
resolvedCpLocation)
}

if (startOperatorStateUniqueIds.isDefined != endOperatorStateUniqueIds.isDefined) {
throw StateDataSourceErrors.internalError(
Reviewer: Just to confirm: this is backed by an error class, correct?

Author: Yes, it is. But I changed it to be backed by a new error class, STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED.

"Reading source across different checkpoint format versions is not supported.")
}

if (startOperatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
if (readChangeFeedOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
READ_CHANGE_FEED,
"Read change feed is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
stateVarName, readRegisteredTimers, flattenCollectionTypes,
startOperatorStateUniqueIds, endOperatorStateUniqueIds)
}

private def resolvedCheckpointLocation(
@@ -96,11 +96,20 @@ abstract class StatePartitionReaderBase(
schema, "value").asInstanceOf[StructType]
}

protected val getStoreUniqueId : Option[String] = {
protected def getStoreUniqueId(
operatorStateUniqueIds: Option[Array[Array[String]]]) : Option[String] = {
SymmetricHashJoinStateManager.getStateStoreCheckpointId(
storeName = partition.sourceOptions.storeName,
partitionId = partition.partition,
stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds)
stateStoreCkptIds = operatorStateUniqueIds)
}

protected def getStartStoreUniqueId: Option[String] = {
getStoreUniqueId(partition.sourceOptions.startOperatorStateUniqueIds)
}

protected def getEndStoreUniqueId: Option[String] = {
getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds)
}

protected lazy val provider: StateStoreProvider = {
@@ -123,7 +132,7 @@ abstract class StatePartitionReaderBase(
if (useColFamilies) {
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId)
getEndStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
@@ -182,9 +191,11 @@ class StatePartitionReader(
private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
case None =>
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when not reading from snapshot")
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId
getStartStoreUniqueId
)

case Some(fromSnapshotOptions) =>
@@ -261,7 +272,8 @@ class StateStoreChangeDataPartitionReader(
.getStateStoreChangeDataReader(
partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1,
partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1,
colFamilyNameOpt)
colFamilyNameOpt,
getEndStoreUniqueId)
}

override lazy val iter: Iterator[InternalRow] = {
@@ -76,21 +76,22 @@ class StreamStreamJoinStatePartitionReader(
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)

private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
private val startStateStoreCheckpointIds =
SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
partition.sourceOptions.startOperatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyToNumValues
startStateStoreCheckpointIds.left.keyToNumValues
} else {
stateStoreCheckpointIds.right.keyToNumValues
startStateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyWithIndexToValue
startStateStoreCheckpointIds.left.keyWithIndexToValue
} else {
stateStoreCheckpointIds.right.keyWithIndexToValue
startStateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
@@ -389,7 +389,7 @@ case class TransformWithStateInPySparkExec(
store.abort()
}
}
setStoreMetrics(store)
setStoreMetrics(store, isStreaming)
setOperatorMetrics()
}).map { row =>
numOutputRows += 1
@@ -430,14 +430,14 @@ trait StateStoreWriter
* Set the SQL metrics related to the state store.
* This should be called in that task after the store has been updated.
*/
protected def setStoreMetrics(store: StateStore): Unit = {
protected def setStoreMetrics(store: StateStore, setCheckpointInfo: Boolean = true): Unit = {
Reviewer: Hmm, why do we need this change?

Author: In setStoreMetrics we call store.getStateStoreCheckpointInfo(). If we call this in the store.abort() case in TransformWithStateExec or TransformWithStateInPySparkExec, it will throw an exception, since the checkpoint info does not exist when we never committed. https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala#L343
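To make the failure mode concrete, here is a toy, self-contained sketch; ToyStore and every name in it are invented for illustration, not the Spark API:

```scala
// A store that only has checkpoint info after a successful commit, mirroring
// the behavior the author describes for getStateStoreCheckpointInfo().
class ToyStore {
  private var committed = false
  def commit(): Unit = committed = true
  def abort(): Unit = ()
  def checkpointInfo: String =
    if (committed) "ckpt-info"
    else throw new IllegalStateException("no checkpoint info: store never committed")
}

def setStoreMetrics(store: ToyStore, setCheckpointInfo: Boolean = true): Unit = {
  // ... record size/memory metrics here ...
  if (setCheckpointInfo) {
    println(store.checkpointInfo) // throws if the store was aborted
  }
}

val store = new ToyStore
store.abort()                                     // batch path: never committed
setStoreMetrics(store, setCheckpointInfo = false) // safe: skips checkpoint info
```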

val storeMetrics = store.metrics
longMetric("numTotalStateRows") += storeMetrics.numKeys
longMetric("stateMemory") += storeMetrics.memoryUsedBytes
setStoreCustomMetrics(storeMetrics.customMetrics)
setStoreInstanceMetrics(storeMetrics.instanceMetrics)

if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) {
if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf) && setCheckpointInfo) {
// Set the state store checkpoint information for the driver to collect
val ssInfo = store.getStateStoreCheckpointInfo()
setStateStoreCheckpointInfo(
@@ -346,7 +346,7 @@ case class TransformWithStateExec(
store.abort()
}
}
setStoreMetrics(store)
setStoreMetrics(store, isStreaming)
setOperatorMetrics()
closeStatefulProcessor()
statefulProcessor.setHandle(null)
@@ -1064,8 +1064,15 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
override def getStateStoreChangeDataReader(
startVersion: Long,
endVersion: Long,
colFamilyNameOpt: Option[String] = None):
colFamilyNameOpt: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None):
StateStoreChangeDataReader = {

if (endVersionStateStoreCkptId.isDefined) {
throw QueryExecutionErrors.cannotLoadStore(new SparkException(
Reviewer: Can we make a new error condition for this (and change the other place where we do this)?

Author: Added error class STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED and used it here and in the other places.

"HDFSBackedStateStoreProvider does not support endVersionStateStoreCkptId"))
}

// Multiple column families are not supported with HDFSBackedStateStoreProvider
if (colFamilyNameOpt.isDefined) {
throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
@@ -1099,7 +1106,7 @@ class HDFSBackedStateStoreChangeDataReader(
extends StateStoreChangeDataReader(
fm, stateLocation, startVersion, endVersion, compressionCodec) {

override protected var changelogSuffix: String = "delta"
override protected val changelogSuffix: String = "delta"

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
val reader = currentChangelogReader()
@@ -342,6 +342,51 @@ class RocksDB(
currLineage
}

/**
* Construct the full lineage from startVersion to endVersion (inclusive) by
* walking backwards using lineage information embedded in changelog files.
*/
def getFullLineage(
Reviewer: Can we add some unit tests for this new function? The logic seems quite complicated, and we want to make sure we can test all the edge cases, particularly the error cases.

Author: Added RocksDBLineageSuite.scala, which covers the main error cases.

startVersion: Long,
endVersion: Long,
endVersionStateStoreCkptId: Option[String]): Array[LineageItem] = {
assert(startVersion <= endVersion,
s"startVersion $startVersion should be less than or equal to endVersion $endVersion")

// A buffer to collect the lineage information; the entries should be decreasing in version
val buf = mutable.ArrayBuffer[LineageItem]()
buf.append(LineageItem(endVersion, endVersionStateStoreCkptId.get))

while (buf.last.version > startVersion) {
val prevSmallestVersion = buf.last.version
val lineage = getLineageFromChangelogFile(buf.last.version, Some(buf.last.checkpointUniqueId))
// lineage array is sorted in increasing order, we need to reverse it
val lineageSorted = lineage.filter(_.version >= startVersion).sortBy(_.version).reverse
Reviewer: Can we just sort descending with a negative key or an Ordering param?

Author: Changed it to .sortBy(-_.version).

// append to the buffer in reverse order, so the buffer is always decreasing in version
buf.appendAll(lineageSorted)

// To prevent an infinite loop when we make no progress, throw an exception
if (buf.last.version == prevSmallestVersion) {
throw new IllegalStateException("Lineage is not complete")
Reviewer: Can we create an error class for this?

Author: Done. Created INVALID_CHECKPOINT_LINEAGE.

}
}

// we return the lineage in increasing order
val ret = buf.reverse.toArray

// Sanity checks
assert(ret.head.version == startVersion,
s"Expected first lineage version to be $startVersion, but got ${ret.head.version}")
assert(ret.last.version == endVersion,
s"Expected last lineage version to be $endVersion, but got ${ret.last.version}")
// Assert that the lineage versions are consecutive, i.e. each step increases by exactly 1
assert(ret.sliding(2).forall {
Reviewer: Maybe move this to an error class as well?

Author: Made these also use INVALID_CHECKPOINT_LINEAGE.

case Array(prev, next) => prev.version + 1 == next.version
case _ => true
}, s"Lineage array is not strictly increasing in version")

ret
}
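To make the traversal easier to follow, here is a minimal, self-contained sketch of the same backward walk. LineageItem, fullLineage, and the readLineage callback are illustrative stand-ins, not the Spark API:

```scala
import scala.collection.mutable

case class LineageItem(version: Long, checkpointUniqueId: String)

// Walk backwards from (end, endId), reading the lineage embedded in each
// changelog until startVersion is reached, mirroring getFullLineage above.
def fullLineage(
    start: Long,
    end: Long,
    endId: String,
    readLineage: (Long, String) => Array[LineageItem]): Array[LineageItem] = {
  val buf = mutable.ArrayBuffer(LineageItem(end, endId))
  while (buf.last.version > start) {
    val prevSmallest = buf.last.version
    val older = readLineage(buf.last.version, buf.last.checkpointUniqueId)
      .filter(_.version >= start)
      .sortBy(-_.version) // descending, so buf stays decreasing in version
    buf.appendAll(older)
    if (buf.last.version == prevSmallest) {
      // No older version was discovered: the lineage chain is broken.
      throw new IllegalStateException("Lineage is not complete")
    }
  }
  buf.reverse.toArray
}

// Toy data: the changelog at version 5 embeds lineage [3, 5); version 3 embeds [1, 3).
val lineages = Map(
  (5L, "e") -> Array(LineageItem(3, "c"), LineageItem(4, "d")),
  (3L, "c") -> Array(LineageItem(1, "a"), LineageItem(2, "b")))
val result = fullLineage(1, 5, "e", (v, id) => lineages((v, id)))
assert(result.map(_.version).sameElements(1L to 5L))
```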

/**
* Load the given version of data in a native RocksDB instance.
@@ -877,15 +877,18 @@ private[sql] class RocksDBStateStoreProvider
override def getStateStoreChangeDataReader(
startVersion: Long,
endVersion: Long,
colFamilyNameOpt: Option[String] = None):
colFamilyNameOpt: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None):
StateStoreChangeDataReader = {
val statePath = stateStoreId.storeCheckpointLocation()
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
new RocksDBStateStoreChangeDataReader(
CheckpointFileManager.create(statePath, hadoopConf),
rocksDB,
statePath,
startVersion,
endVersion,
endVersionStateStoreCkptId,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keyValueEncoderMap,
colFamilyNameOpt)
@@ -1224,17 +1227,32 @@ object RocksDBStateStoreProvider {
/** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */
class RocksDBStateStoreChangeDataReader(
fm: CheckpointFileManager,
rocksDB: RocksDB,
Reviewer: Hmm, it seems a little strange to me that we are passing in RocksDB here in its entirety just so we can use getFullLineage. Is there a way to abstract out the getFullLineage functionality so we can reuse it a different way?

Author: I initially did some refactoring to move all the lineage-related methods into RocksDBFileManager and only pass that in here, but dropped it to reduce the amount of changes in this PR.

At a glance, all the lineage-related methods (getChangelogReader, getLineageFromChangelogFile) exist in either RocksDB or RocksDBFileManager. We should be able to abstract these methods out into something like ChangelogFileManager.scala, since the changelog lineage logic is not directly dependent on RocksDB-specific methods (a rough sketch of this idea appears at the end of this hunk). I am not sure if we want to do this refactoring in this PR.

Reviewer: Sure, that's fine. We don't have to do it in this PR.

stateLocation: Path,
startVersion: Long,
endVersion: Long,
endVersionStateStoreCkptId: Option[String],
compressionCodec: CompressionCodec,
keyValueEncoderMap:
ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short)],
colFamilyNameOpt: Option[String] = None)
extends StateStoreChangeDataReader(
fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) {

override protected var changelogSuffix: String = "changelog"
override protected val versionsAndUniqueIds: Array[(Long, Option[String])] =
if (endVersionStateStoreCkptId.isDefined) {
val fullVersionLineage = rocksDB.getFullLineage(
startVersion,
endVersion,
endVersionStateStoreCkptId)
fullVersionLineage
.sortBy(_.version)
.map(item => (item.version, Some(item.checkpointUniqueId)))
} else {
(startVersion to endVersion).map((_, None)).toArray
}

override protected val changelogSuffix: String = "changelog"

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null
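Following up on the review thread above, a rough sketch of the abstraction being discussed; it is entirely hypothetical, and the PR deliberately does not implement it:

```scala
// Hypothetical trait that would own changelog/lineage reads, so a change data
// reader could depend on it instead of on RocksDB itself. LineageItem here
// stands in for Spark's own lineage type.
case class LineageItem(version: Long, checkpointUniqueId: String)

trait ChangelogFileManager {
  def getLineageFromChangelogFile(
      version: Long,
      checkpointUniqueId: Option[String]): Array[LineageItem]
  def getFullLineage(
      startVersion: Long,
      endVersion: Long,
      endVersionStateStoreCkptId: Option[String]): Array[LineageItem]
}
```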
@@ -834,7 +834,8 @@ trait SupportsFineGrainedReplay {
def getStateStoreChangeDataReader(
startVersion: Long,
endVersion: Long,
colFamilyNameOpt: Option[String] = None):
colFamilyNameOpt: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None):
NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)]
}

@@ -447,6 +447,7 @@ abstract class StateStoreChangelogReader(
Serialization.read[Array[LineageItem]](lineageStr)
}

// The array contains lineage information from [snapShotVersion, version), i.e. the
// version this changelog was read at is not included
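// e.g. a changelog written at version 10 on top of a snapshot at version 7
// embeds lineage entries for versions 7, 8, and 9 only.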
Reviewer: Both left and right inclusive, right?

Author: I made a mistake: it is actually [snapShotVersion, version). The version that was used to get this array is not included. I updated the comment to make this clearer.

lazy val lineage: Array[LineageItem] = readLineage()

def version: Short
@@ -632,27 +633,41 @@ abstract class StateStoreChangeDataReader(
* Iterator that iterates over the changelog files in the state store.
*/
private class ChangeLogFileIterator extends Iterator[Path] {
val versionsAndUniqueIds: Iterator[(Long, Option[String])] =
StateStoreChangeDataReader.this.versionsAndUniqueIds.iterator

private var currentVersion = StateStoreChangeDataReader.this.startVersion - 1
private var currentUniqueId: Option[String] = None

/** returns the version of the changelog returned by the latest [[next]] function call */
def getVersion: Long = currentVersion

override def hasNext: Boolean = currentVersion < StateStoreChangeDataReader.this.endVersion
override def hasNext: Boolean = versionsAndUniqueIds.hasNext

override def next(): Path = {
currentVersion += 1
getChangelogPath(currentVersion)
val nextTuple = versionsAndUniqueIds.next()
currentVersion = nextTuple._1
currentUniqueId = nextTuple._2
getChangelogPath(currentVersion, currentUniqueId)
}

private def getChangelogPath(version: Long): Path =
new Path(
StateStoreChangeDataReader.this.stateLocation,
s"$version.${StateStoreChangeDataReader.this.changelogSuffix}")
private def getChangelogPath(version: Long, checkpointUniqueId: Option[String]): Path =
if (checkpointUniqueId.isDefined) {
new Path(
StateStoreChangeDataReader.this.stateLocation,
s"${version}_${checkpointUniqueId.get}." +
s"${StateStoreChangeDataReader.this.changelogSuffix}")
} else {
new Path(
StateStoreChangeDataReader.this.stateLocation,
s"$version.${StateStoreChangeDataReader.this.changelogSuffix}")
}
}
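For illustration, the two naming schemes produce paths like the following; the location, version, and unique ID below are made-up example values:

```scala
import org.apache.hadoop.fs.Path

val stateLocation = new Path("/checkpoint/state/0/3")           // assumed example location
val withoutId = new Path(stateLocation, "42.changelog")         // no checkpoint unique ID
val uniqueId = "a1b2c3d4-example"                               // assumed example unique ID
val withId = new Path(stateLocation, s"42_$uniqueId.changelog") // checkpoint v2 style
```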

/** file format of the changelog files */
protected var changelogSuffix: String
protected val changelogSuffix: String
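// Versions to read, paired with their checkpoint unique IDs when available.
// Default: every version in [startVersion, endVersion] with no unique ID,
// e.g. for versions 3 to 5: Array((3, None), (4, None), (5, None)).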
protected val versionsAndUniqueIds: Array[(Long, Option[String])] =
(startVersion to endVersion).map((_, None)).toArray
private lazy val fileIterator = new ChangeLogFileIterator
private var changelogReader: StateStoreChangelogReader = null

@@ -671,11 +686,10 @@ abstract class StateStoreChangeDataReader(
return null
}

changelogReader = if (colFamilyNameOpt.isDefined) {
new StateStoreChangelogReaderV2(fm, fileIterator.next(), compressionCodec)
} else {
new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec)
}
val changelogFile = fileIterator.next()
changelogReader =
new StateStoreChangelogReaderFactory(fm, changelogFile, compressionCodec)
.constructChangelogReader()
}
changelogReader
}