diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 234b0c3ed02d..bd06518a0c76 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -368,6 +368,11 @@ "The change log writer version cannot be ." ] }, + "INVALID_CHECKPOINT_LINEAGE" : { + "message" : [ + "Invalid checkpoint lineage: . " + ] + }, "KEY_ROW_FORMAT_VALIDATION_FAILURE" : { "message" : [ "" @@ -5168,6 +5173,12 @@ ], "sqlState" : "42802" }, + "STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED" : { + "message" : [ + "" + ], + "sqlState" : "KD002" + }, "STATE_STORE_CHECKPOINT_LOCATION_NOT_EMPTY" : { "message" : [ "The checkpoint location should be empty on batch 0", @@ -5413,6 +5424,14 @@ }, "sqlState" : "42616" }, + "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED" : { + "message" : [ + "Reading state across different checkpoint format versions is not supported.", + "startBatchId=, endBatchId=.", + "startFormatVersion=, endFormatVersion=." + ], + "sqlState" : "KD002" + }, "STDS_NO_PARTITION_DISCOVERED_IN_STATE_STORE" : { "message" : [ "The state does not have any partition. Please double check that the query points to the valid state. options: " diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index d3bda545e1c9..527fb7d370e7 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1916,6 +1916,30 @@ def conf(cls): return cfg +class TransformWithStateInPandasWithCheckpointV2TestsMixin(TransformWithStateInPandasTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + # TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId + def test_transform_with_value_state_metadata(self): + pass + + +class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + # TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId + def test_transform_with_value_state_metadata(self): + pass + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass @@ -1924,6 +1948,18 @@ class TransformWithStateInPySparkTests(TransformWithStateInPySparkTestsMixin, Re pass +class TransformWithStateInPandasWithCheckpointV2Tests( + TransformWithStateInPandasWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + +class TransformWithStateInPySparkWithCheckpointV2Tests( + TransformWithStateInPySparkWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_transform_with_state import * # noqa: F401 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ba229a2e746c..67bb80403b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2740,6 +2740,17 @@ private[sql] object 
QueryExecutionErrors extends QueryErrorsBase with ExecutionE ) } + def invalidCheckpointLineage(lineage: String, message: String): Throwable = { + new SparkException( + errorClass = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + messageParameters = Map( + "lineage" -> lineage, + "message" -> message + ), + cause = null + ) + } + def notEnoughMemoryToLoadStore( stateStoreId: String, stateStoreProviderName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 828c06ab834a..54d3c45d237b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -371,7 +371,8 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, - operatorStateUniqueIds: Option[Array[Array[String]]] = None) { + startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, + endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { @@ -576,29 +577,52 @@ object StateSourceOptions extends DataSourceOptions { batchId.get } - val operatorStateUniqueIds = getOperatorStateUniqueIds( + val endBatchId = if (readChangeFeedOptions.isDefined) { + readChangeFeedOptions.get.changeEndBatchId + } else { + batchId.get + } + + val startOperatorStateUniqueIds = getOperatorStateUniqueIds( sparkSession, startBatchId, operatorId, resolvedCpLocation) - if (operatorStateUniqueIds.isDefined) { + val endOperatorStateUniqueIds = if (startBatchId == endBatchId) { + startOperatorStateUniqueIds + } else { + getOperatorStateUniqueIds( + sparkSession, + endBatchId, + operatorId, + resolvedCpLocation) + } + + if (startOperatorStateUniqueIds.isDefined != endOperatorStateUniqueIds.isDefined) { + val startFormatVersion = if (startOperatorStateUniqueIds.isDefined) 2 else 1 + val endFormatVersion = if (endOperatorStateUniqueIds.isDefined) 2 else 1 + throw StateDataSourceErrors.mixedCheckpointFormatVersionsNotSupported( + startBatchId, + endBatchId, + startFormatVersion, + endFormatVersion + ) + } + + if (startOperatorStateUniqueIds.isDefined) { if (fromSnapshotOptions.isDefined) { throw StateDataSourceErrors.invalidOptionValue( SNAPSHOT_START_BATCH_ID, "Snapshot reading is currently not supported with checkpoint v2.") } - if (readChangeFeedOptions.isDefined) { - throw StateDataSourceErrors.invalidOptionValue( - READ_CHANGE_FEED, - "Read change feed is currently not supported with checkpoint v2.") - } } StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds) + stateVarName, readRegisteredTimers, flattenCollectionTypes, + startOperatorStateUniqueIds, endOperatorStateUniqueIds) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala index b6883a98f3ed..74ab308131f2 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala @@ -81,6 +81,18 @@ object StateDataSourceErrors { sourceOptions: StateSourceOptions): StateDataSourceException = { new StateDataSourceNoPartitionDiscoveredInStateStore(sourceOptions) } + + def mixedCheckpointFormatVersionsNotSupported( + startBatchId: Long, + endBatchId: Long, + startFormatVersion: Int, + endFormatVersion: Int): StateDataSourceException = { + new StateDataSourceMixedCheckpointFormatVersionsNotSupported( + startBatchId, + endBatchId, + startFormatVersion, + endFormatVersion) + } } abstract class StateDataSourceException( @@ -172,3 +184,18 @@ class StateDataSourceReadOperatorMetadataFailure( "STDS_FAILED_TO_READ_OPERATOR_METADATA", Map("checkpointLocation" -> checkpointLocation, "batchId" -> batchId.toString), cause = null) + +class StateDataSourceMixedCheckpointFormatVersionsNotSupported( + startBatchId: Long, + endBatchId: Long, + startFormatVersion: Int, + endFormatVersion: Int) + extends StateDataSourceException( + "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED", + Map( + "startBatchId" -> startBatchId.toString, + "endBatchId" -> endBatchId.toString, + "startFormatVersion" -> startFormatVersion.toString, + "endFormatVersion" -> endFormatVersion.toString + ), + cause = null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index ebef6e3dac55..7180fe483fcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -96,11 +96,20 @@ abstract class StatePartitionReaderBase( schema, "value").asInstanceOf[StructType] } - protected val getStoreUniqueId : Option[String] = { + protected def getStoreUniqueId( + operatorStateUniqueIds: Option[Array[Array[String]]]) : Option[String] = { SymmetricHashJoinStateManager.getStateStoreCheckpointId( storeName = partition.sourceOptions.storeName, partitionId = partition.partition, - stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds) + stateStoreCkptIds = operatorStateUniqueIds) + } + + protected def getStartStoreUniqueId: Option[String] = { + getStoreUniqueId(partition.sourceOptions.startOperatorStateUniqueIds) + } + + protected def getEndStoreUniqueId: Option[String] = { + getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds) } protected lazy val provider: StateStoreProvider = { @@ -123,7 +132,7 @@ abstract class StatePartitionReaderBase( if (useColFamilies) { val store = provider.getStore( partition.sourceOptions.batchId + 1, - getStoreUniqueId) + getEndStoreUniqueId) require(stateStoreColFamilySchemaOpt.isDefined) val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined) @@ -182,9 +191,11 @@ class StatePartitionReader( private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { case None => + assert(getStartStoreUniqueId == getEndStoreUniqueId, + "Start and end store unique IDs must be the same when not reading from snapshot") provider.getReadStore( partition.sourceOptions.batchId + 1, - getStoreUniqueId + getStartStoreUniqueId ) case 
Some(fromSnapshotOptions) => @@ -261,7 +272,8 @@ class StateStoreChangeDataPartitionReader( .getStateStoreChangeDataReader( partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1, - colFamilyNameOpt) + colFamilyNameOpt, + getEndStoreUniqueId) } override lazy val iter: Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 0f8a3b3b609f..bf0e8968789c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -76,21 +76,22 @@ class StreamStreamJoinStatePartitionReader( partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId) - private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( + private val startStateStoreCheckpointIds = + SymmetricHashJoinStateManager.getStateStoreCheckpointIds( partition.partition, - partition.sourceOptions.operatorStateUniqueIds, + partition.sourceOptions.startOperatorStateUniqueIds, usesVirtualColumnFamilies) private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) { - stateStoreCheckpointIds.left.keyToNumValues + startStateStoreCheckpointIds.left.keyToNumValues } else { - stateStoreCheckpointIds.right.keyToNumValues + startStateStoreCheckpointIds.right.keyToNumValues } private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) { - stateStoreCheckpointIds.left.keyWithIndexToValue + startStateStoreCheckpointIds.left.keyWithIndexToValue } else { - stateStoreCheckpointIds.right.keyWithIndexToValue + startStateStoreCheckpointIds.right.keyWithIndexToValue } /* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index 1b967af38b6d..f8390b7d878f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -389,7 +389,7 @@ case class TransformWithStateInPySparkExec( store.abort() } } - setStoreMetrics(store) + setStoreMetrics(store, isStreaming) setOperatorMetrics() }).map { row => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala index cc8d354a0393..0634a2f05b41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala @@ -430,14 +430,14 @@ trait StateStoreWriter * Set the SQL metrics related to the state store. * This should be called in that task after the store has been updated. 
*/ - protected def setStoreMetrics(store: StateStore): Unit = { + protected def setStoreMetrics(store: StateStore, setCheckpointInfo: Boolean = true): Unit = { val storeMetrics = store.metrics longMetric("numTotalStateRows") += storeMetrics.numKeys longMetric("stateMemory") += storeMetrics.memoryUsedBytes setStoreCustomMetrics(storeMetrics.customMetrics) setStoreInstanceMetrics(storeMetrics.instanceMetrics) - if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) { + if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf) && setCheckpointInfo) { // Set the state store checkpoint information for the driver to collect val ssInfo = store.getStateStoreCheckpointInfo() setStateStoreCheckpointInfo( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index 20e2c32015d8..52a0d470c266 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -346,7 +346,7 @@ case class TransformWithStateExec( store.abort() } } - setStoreMetrics(store) + setStoreMetrics(store, isStreaming) setOperatorMetrics() closeStatefulProcessor() statefulProcessor.setHandle(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ac7f1a021960..f37a26012e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -292,9 +292,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with /** Get the state store for making updates to create a new `version` of the store. 
*/ override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { if (uniqueId.isDefined) { - throw QueryExecutionErrors.cannotLoadStore(new SparkException( + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + - "but a state store checkpointID is passed in")) + "but a state store checkpointID is passed in") } val newMap = getLoadedMapForStore(version) logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} " + @@ -369,10 +369,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false, stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { - assert( - !storeConf.enableStateStoreCheckpointIds, - "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2 " + - s"checkpointFormatVersion ${storeConf.stateStoreCheckpointFormatVersion}") + if (storeConf.enableStateStoreCheckpointIds) { + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1") + } this.stateStoreId_ = stateStoreId this.keySchema = keySchema @@ -1064,8 +1064,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): StateStoreChangeDataReader = { + + if (endVersionStateStoreCkptId.isDefined) { + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + + "but a state store checkpointID is passed in") + } + // Multiple column families are not supported with HDFSBackedStateStoreProvider if (colFamilyNameOpt.isDefined) { throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) @@ -1099,7 +1107,7 @@ class HDFSBackedStateStoreChangeDataReader( extends StateStoreChangeDataReader( fm, stateLocation, startVersion, endVersion, compressionCodec) { - override protected var changelogSuffix: String = "delta" + override protected val changelogSuffix: String = "delta" override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { val reader = currentChangelogReader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 85e2d72ec163..1e65b737e2bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,63 @@ class RocksDB( currLineage } + /** + * Construct the full lineage from startVersion to endVersion (inclusive) by + * walking backwards using lineage information embedded in changelog files. 
+ */ + def getFullLineage( + startVersion: Long, + endVersion: Long, + endVersionStateStoreCkptId: Option[String]): Array[LineageItem] = { + assert(startVersion <= endVersion, + s"startVersion $startVersion should be less than or equal to endVersion $endVersion") + assert(endVersionStateStoreCkptId.isDefined, + "endVersionStateStoreCkptId should be defined") + + // A buffer to collect the lineage information, the entries should be decreasing in version + val buf = mutable.ArrayBuffer[LineageItem]() + buf.append(LineageItem(endVersion, endVersionStateStoreCkptId.get)) + + while (buf.last.version > startVersion) { + val prevSmallestVersion = buf.last.version + val lineage = getLineageFromChangelogFile(buf.last.version, Some(buf.last.checkpointUniqueId)) + // lineage array is sorted in increasing order, we need to make it decreasing + val lineageSortedDecreasing = lineage.filter(_.version >= startVersion).sortBy(-_.version) + // append to the buffer in reverse order, so the buffer is always decreasing in version + buf.appendAll(lineageSortedDecreasing) + + // to prevent infinite loop if we make no progress, throw an exception + if (buf.last.version == prevSmallestVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(buf.reverse.toArray), + s"Cannot find version smaller than ${buf.last.version} in lineage.") + } + } + + // we return the lineage in increasing order + val ret = buf.reverse.toArray + + // Sanity checks + if (ret.head.version != startVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + s"Lineage does not start with startVersion: $startVersion.") + } + if (ret.last.version != endVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + s"Lineage does not end with endVersion: $endVersion.") + } + // Verify that the lineage versions are increasing by one + // We do this by checking that each entry is one version higher than the previous one + val increasingByOne = ret.sliding(2).forall { + case Array(prev, next) => prev.version + 1 == next.version + case _ => true + } + if (!increasingByOne) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + "Lineage versions are not increasing by one.") + } + + ret + } /** * Load the given version of data in a native RocksDB instance. 
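[Reviewer note] The backward walk that getFullLineage performs can be read in isolation. Below is a minimal sketch under stated assumptions: a hypothetical readLineage(version, uniqueId) callback stands in for getLineageFromChangelogFile, and a local LineageItem mirrors the store's case class; this is an illustration of the traversal in this hunk, not the provider code itself.

  case class LineageItem(version: Long, checkpointUniqueId: String)

  def fullLineage(
      startVersion: Long,
      endVersion: Long,
      endId: String)(readLineage: (Long, String) => Array[LineageItem]): Array[LineageItem] = {
    require(startVersion <= endVersion)
    // Collect lineage entries in decreasing version order, starting from the end version.
    val buf = scala.collection.mutable.ArrayBuffer(LineageItem(endVersion, endId))
    while (buf.last.version > startVersion) {
      val prev = buf.last.version
      // Each changelog embeds the lineage of the versions preceding it (ascending);
      // keep only what is still needed and append it in descending order.
      val older = readLineage(buf.last.version, buf.last.checkpointUniqueId)
        .filter(_.version >= startVersion)
        .sortBy(-_.version)
      buf.appendAll(older)
      if (buf.last.version == prev) {
        // No progress: the changelog did not reference any older version.
        throw new IllegalStateException(s"No version below $prev found in lineage")
      }
    }
    buf.reverse.toArray // ascending from startVersion to endVersion
  }

For example, with changelog 5 carrying lineage [(4, i4), (3, i3)] and changelog 3 carrying [(2, i2), (1, i1)], fullLineage(1, 5, "i5")(readLineage) resolves versions 1 through 5, matching the multi-hop case exercised in the new RocksDBLineageSuite.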
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7098fd41f402..a1052c95e199 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -877,15 +877,18 @@ private[sql] class RocksDBStateStoreProvider override def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): StateStoreChangeDataReader = { val statePath = stateStoreId.storeCheckpointLocation() val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) new RocksDBStateStoreChangeDataReader( CheckpointFileManager.create(statePath, hadoopConf), + rocksDB, statePath, startVersion, endVersion, + endVersionStateStoreCkptId, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keyValueEncoderMap, colFamilyNameOpt) @@ -1224,9 +1227,11 @@ object RocksDBStateStoreProvider { /** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */ class RocksDBStateStoreChangeDataReader( fm: CheckpointFileManager, + rocksDB: RocksDB, stateLocation: Path, startVersion: Long, endVersion: Long, + endVersionStateStoreCkptId: Option[String], compressionCodec: CompressionCodec, keyValueEncoderMap: ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short)], @@ -1234,7 +1239,20 @@ class RocksDBStateStoreChangeDataReader( extends StateStoreChangeDataReader( fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) { - override protected var changelogSuffix: String = "changelog" + override protected val versionsAndUniqueIds: Array[(Long, Option[String])] = + if (endVersionStateStoreCkptId.isDefined) { + val fullVersionLineage = rocksDB.getFullLineage( + startVersion, + endVersion, + endVersionStateStoreCkptId) + fullVersionLineage + .sortBy(_.version) + .map(item => (item.version, Some(item.checkpointUniqueId))) + } else { + (startVersion to endVersion).map((_, None)).toArray + } + + override protected val changelogSuffix: String = "changelog" override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 604a27866f62..f94eecd1dd42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -834,7 +834,8 @@ trait SupportsFineGrainedReplay { def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 4c5dea63baea..792f22cc574d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -447,6 +447,7 @@ abstract class StateStoreChangelogReader( Serialization.read[Array[LineageItem]](lineageStr) } + // The array contains lineage information from [mostRecentSnapShotVersion, version - 1] inclusive lazy val lineage: Array[LineageItem] = readLineage() def version: Short @@ -632,27 +633,41 @@ abstract class StateStoreChangeDataReader( * Iterator that iterates over the changelog files in the state store. */ private class ChangeLogFileIterator extends Iterator[Path] { + val versionsAndUniqueIds: Iterator[(Long, Option[String])] = + StateStoreChangeDataReader.this.versionsAndUniqueIds.iterator private var currentVersion = StateStoreChangeDataReader.this.startVersion - 1 + private var currentUniqueId: Option[String] = None /** returns the version of the changelog returned by the latest [[next]] function call */ def getVersion: Long = currentVersion - override def hasNext: Boolean = currentVersion < StateStoreChangeDataReader.this.endVersion + override def hasNext: Boolean = versionsAndUniqueIds.hasNext override def next(): Path = { - currentVersion += 1 - getChangelogPath(currentVersion) + val nextTuple = versionsAndUniqueIds.next() + currentVersion = nextTuple._1 + currentUniqueId = nextTuple._2 + getChangelogPath(currentVersion, currentUniqueId) } - private def getChangelogPath(version: Long): Path = - new Path( - StateStoreChangeDataReader.this.stateLocation, - s"$version.${StateStoreChangeDataReader.this.changelogSuffix}") + private def getChangelogPath(version: Long, checkpointUniqueId: Option[String]): Path = + if (checkpointUniqueId.isDefined) { + new Path( + StateStoreChangeDataReader.this.stateLocation, + s"${version}_${checkpointUniqueId.get}." 
+ + s"${StateStoreChangeDataReader.this.changelogSuffix}") + } else { + new Path( + StateStoreChangeDataReader.this.stateLocation, + s"$version.${StateStoreChangeDataReader.this.changelogSuffix}") + } } /** file format of the changelog files */ - protected var changelogSuffix: String + protected val changelogSuffix: String + protected val versionsAndUniqueIds: Array[(Long, Option[String])] = + (startVersion to endVersion).map((_, None)).toArray private lazy val fileIterator = new ChangeLogFileIterator private var changelogReader: StateStoreChangelogReader = null @@ -671,11 +686,10 @@ abstract class StateStoreChangeDataReader( return null } - changelogReader = if (colFamilyNameOpt.isDefined) { - new StateStoreChangelogReaderV2(fm, fileIterator.next(), compressionCodec) - } else { - new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) - } + val changelogFile = fileIterator.next() + changelogReader = + new StateStoreChangelogReaderFactory(fm, changelogFile, compressionCodec) + .constructChangelogReader() } changelogReader } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 43682de03446..8a44f5c28456 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -259,6 +259,10 @@ object StateStoreErrors { QueryExecutionErrors.cannotLoadStore(e) } } + + def stateStoreCheckpointIdsNotSupported(msg: String): StateStoreCheckpointIdsNotSupported = { + new StateStoreCheckpointIdsNotSupported(msg) + } } trait ConvertableToCannotLoadStoreError { @@ -545,6 +549,12 @@ class StateStoreOperationOutOfOrder(errorMsg: String) messageParameters = Map("errorMsg" -> errorMsg) ) +class StateStoreCheckpointIdsNotSupported(msg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED", + messageParameters = Map("msg" -> msg) + ) + class StateStoreCommitValidationFailed( batchId: Long, expectedCommits: Int, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index be19981dc8a8..a1be83627f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.execution.datasources.v2.state +import java.io.File +import java.sql.Timestamp import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.Assertions import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,6 +50,14 @@ class RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite extends } } 
+class RocksDBWithCheckpointV2StateDataSourceChangeDataReaderSuite extends + RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + } +} + abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestBase with Assertions { @@ -124,6 +135,39 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } } + test("ERROR: mixed checkpoint format versions not supported") { + withTempDir { tempDir => + val commitLog = new CommitLog(spark, + new File(tempDir.getAbsolutePath, "commits").getAbsolutePath) + + // Start version: treated as v1 (no operator unique ids) + val startMetadata = CommitMetadata(0, None) + assert(commitLog.add(0, startMetadata)) + + // End version: treated as v2 (operator 0 has unique ids) + val endMetadata = CommitMetadata(0, + Some(Map[Long, Array[Array[String]]](0L -> Array(Array("uid"))))) + assert(commitLog.add(1, endMetadata)) + + val exc = intercept[StateDataSourceMixedCheckpointFormatVersionsNotSupported] { + spark.read.format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load() + } + + checkError(exc, "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED", "KD002", + Map( + "startBatchId" -> "0", + "endBatchId" -> "1", + "startFormatVersion" -> "1", + "endFormatVersion" -> "2" + )) + } + } + test("ERROR: joinSide option is used together with readChangeFeed") { withTempDir { tempDir => val exc = intercept[StateDataSourceConflictOptions] { @@ -139,11 +183,16 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } test("getChangeDataReader of state store provider") { + val versionToCkptId = scala.collection.mutable.Map[Long, Option[String]]() + def withNewStateStore(provider: StateStoreProvider, version: Int)(f: StateStore => Unit): Unit = { - val stateStore = provider.getStore(version) + val stateStore = provider.getStore(version, versionToCkptId.getOrElse(version, None)) f(stateStore) stateStore.commit() + + val ssInfo = stateStore.getStateStoreCheckpointInfo() + versionToCkptId(ssInfo.batchVersion) = ssInfo.stateStoreCkptId } withTempDir { tempDir => @@ -158,7 +207,8 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB stateStore.remove(dataToKeyRow("b", 2)) } val reader = - provider.asInstanceOf[SupportsFineGrainedReplay].getStateStoreChangeDataReader(1, 4) + provider.asInstanceOf[SupportsFineGrainedReplay] + .getStateStoreChangeDataReader(1, 4, None, versionToCkptId.getOrElse(4, None)) assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("a", 1), dataToValueRow(1), 0L)) assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("b", 2), dataToValueRow(2), 1L)) @@ -322,4 +372,133 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB checkAnswer(keyToNumValuesDf, keyToNumValuesDfExpectedDf) } } + + test("read change feed past multiple snapshots") { + withSQLConf("spark.sql.streaming.stateStore.minDeltasForSnapshot" -> "2") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().groupBy("value").count() + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4, 1), + 
ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 5) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(0L, "update", Row(3), Row(1), 1), + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(0L, "update", Row(4), Row(1), 2), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(0L, "update", Row(1), Row(2), 3), + Row(3L, "update", Row(1), Row(4), 3), + Row(4L, "update", Row(1), Row(6), 3), + Row(5L, "update", Row(1), Row(8), 3), + Row(0L, "update", Row(2), Row(1), 4), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf, expectedDf) + + val stateDf2 = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 3) + .load(tempDir.getAbsolutePath) + + val expectedDf2 = Seq( + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(3L, "update", Row(1), Row(4), 3), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf2, expectedDf2) + + val stateDf3 = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 2) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 4) + .load(tempDir.getAbsolutePath) + + val expectedDf3 = Seq( + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(2L, "update", Row(4), Row(3), 2), + Row(3L, "update", Row(1), Row(4), 3), + Row(4L, "update", Row(1), Row(6), 3), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf3, expectedDf3) + } + } + } + + test("read change feed with delete entries") { + withTempDir { tempDir => + val inputData = MemoryStream[(Int, Timestamp)] + val df = inputData.toDF() + .selectExpr("_1 as key", "_2 as ts") + .withWatermark("ts", "1 second") + .groupBy(window(col("ts"), "1 second")) + .count() + + val ts0 = Timestamp.valueOf("2025-01-01 00:00:00") + val ts1 = Timestamp.valueOf("2025-01-01 00:00:01") + val ts2 = Timestamp.valueOf("2025-01-01 00:00:02") + val ts3 = Timestamp.valueOf("2025-01-01 00:00:03") + val ts4 = Timestamp.valueOf("2025-01-01 00:00:04") + + testStream(df, OutputMode.Append)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, (1, ts0), (2, ts0)), + ProcessAllAvailable(), + AddData(inputData, (3, ts2)), + ProcessAllAvailable(), + AddData(inputData, (4, ts3)), + ProcessAllAvailable(), + StopStream + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + 
Row(0L, "update", Row(Row(ts0, ts1)), Row(2), 4), + Row(1L, "update", Row(Row(ts2, ts3)), Row(1), 1), + Row(2L, "delete", Row(Row(ts0, ts1)), null, 4), + Row(2L, "update", Row(Row(ts3, ts4)), Row(1), 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index d744304afb42..59c67973a328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -690,21 +690,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR Map( "optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID, "message" -> "Snapshot reading is currently not supported with checkpoint v2.")) - - // Verify reading change feed throws error with checkpoint v2 - val exc2 = intercept[StateDataSourceInvalidOptionValue] { - val stateDf = spark.read.format("statestore") - .option(StateSourceOptions.READ_CHANGE_FEED, value = true) - .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) - .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) - .load(tmpDir.getAbsolutePath) - stateDf.collect() - } - - checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616", - Map( - "optionName" -> StateSourceOptions.READ_CHANGE_FEED, - "message" -> "Read change feed is currently not supported with checkpoint v2.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 1800319fb8b4..2061cf645a03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -1013,6 +1013,8 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest * the state data. */ testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") { + // TODO(SPARK-53332): Remove this line once snapshotStartBatchId is supported for V2 format + assume(SQLConf.get.stateStoreCheckpointFormatVersion == 1) class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] { @transient protected var _countState: ValueState[Long] = _ @@ -1150,3 +1152,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } } + +class StateDataSourceTransformWithStateSuiteCheckpointV2 extends + StateDataSourceTransformWithStateSuite { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala new file mode 100644 index 000000000000..48ef4158266b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +class RocksDBLineageSuite extends SharedSparkSession { + private def newDB(root: String, enableCheckpointIds: Boolean): RocksDB = { + val conf = RocksDBConf().copy(enableChangelogCheckpointing = true) + new RocksDB( + root, + conf, + localRootDir = Utils.createTempDir(), + hadoopConf = new Configuration, + useColumnFamilies = false, + enableStateStoreCheckpointIds = enableCheckpointIds) + } + + private def writeChangelogWithLineage( + db: RocksDB, + version: Long, + uniqueId: String, + lineage: Array[LineageItem]): Unit = { + val writer = db.fileManager.getChangeLogWriter( + version, + useColumnFamilies = false, + checkpointUniqueId = Some(uniqueId), + stateStoreCheckpointIdLineage = Some(lineage)) + writer.commit() + } + + test("getFullLineage: single changelog covers full range") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 3L + val end = 5L + val id3 = "i3" + val id4 = "i4" + val id5 = "i5" + writeChangelogWithLineage(db, end, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: multi-hop across changelog files") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 1L + val end = 5L + val id1 = "i1"; val id2 = "i2"; val id3 = "i3"; val id4 = "i4"; val id5 = "i5" + writeChangelogWithLineage(db, 3, id3, Array(LineageItem(2, id2), LineageItem(1, id1))) + writeChangelogWithLineage(db, 5, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(1L, 2L, 3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id1, id2, id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: multiple lineages exist for the same version") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 1L + val end = 5L + val id1 = "i1"; val id2 = "i2"; val id3 = "i3"; val id4 = "i4"; val id5 = "i5" + writeChangelogWithLineage(db, 3, id3, Array(LineageItem(2, id2), LineageItem(1, id1))) + writeChangelogWithLineage(db, 5, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + // Insert a bad lineage for version 5 + // We should not use this lineage since we call 
getFullLineage with id5 + val badId4 = id4 + "bad" + val badId5 = id5 + "bad" + writeChangelogWithLineage(db, 5, badId5, Array(LineageItem(4, badId4))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(1L, 2L, 3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id1, id2, id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: start equals end returns single item") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val result = db.getFullLineage(7, 7, Some("i7")) + assert(result.map(_.version).sameElements(Array(7L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array("i7"))) + } finally { + db.close() + } + } + } + + test("getFullLineage: missing intermediate version triggers validation error") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + writeChangelogWithLineage(db, 5, "i5", Array(LineageItem(3, "i3"))) + val ex = intercept[SparkException] { + db.getFullLineage(3, 5, Some("i5")) + } + checkError( + ex, + condition = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + parameters = Map( + "lineage" -> "3:i3 5:i5", + "message" -> "Lineage versions are not increasing by one." + ) + ) + } finally { + db.close() + } + } + } + + test("getFullLineage: no progress in lineage triggers guard error") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + writeChangelogWithLineage(db, 5, "i5", Array.empty) + val ex = intercept[SparkException] { + db.getFullLineage(3, 5, Some("i5")) + } + checkError( + ex, + condition = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + parameters = Map( + "lineage" -> "5:i5", + "message" -> "Cannot find version smaller than 5 in lineage." 
+ ) + ) + } finally { + db.close() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 4d1e789a70b0..0b1483241b92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1400,7 +1400,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] val hadoopConf = new Configuration() hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) - val e = intercept[AssertionError] { + val e = intercept[StateStoreCheckpointIdsNotSupported] { provider.init( StateStoreId(newDir(), Random.nextInt(), 0), keySchema, @@ -1411,7 +1411,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] hadoopConf) } assert(e.getMessage.contains( - "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2")) + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1")) } override def newStoreProvider(): HDFSBackedStateStoreProvider = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 5f4de279724a..1c8c567b73fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -755,3 +755,11 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest } } } + +class TransformWithStateInitialStateSuiteCheckpointV2 + extends TransformWithStateInitialStateSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + } +}
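[Reviewer note] Net effect of the reader-side changes: READ_CHANGE_FEED now works when the upstream query ran with checkpoint format v2 (per-version store unique IDs are resolved through the changelog lineage), while snapshotStartBatchId remains blocked pending SPARK-53332. A hedged usage sketch mirroring the new suites; checkpointDir is a placeholder for the query's checkpoint location and the option constants come from StateSourceOptions in this patch.

  import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions

  // The streaming query that produced this state was started with
  // spark.sql.streaming.stateStore.checkpointFormatVersion=2; the reader
  // detects the format from the commit log and walks the lineage per version.
  val changes = spark.read
    .format("statestore")
    .option(StateSourceOptions.READ_CHANGE_FEED, value = true)
    .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
    .option(StateSourceOptions.CHANGE_END_BATCH_ID, 3)
    .load(checkpointDir) // checkpointDir: placeholder checkpoint path

  // A range whose start and end batches were written with different checkpoint
  // format versions now fails fast with
  // STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED.
  changes.show()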