[SPARK-48772][SS][SQL] State Data Source Change Feed Reader Mode #47188
Changes from 6 commits
Changes to `StateDataSource` and `StateSourceOptions`:

@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -94,10 +94,21 @@ class StateDataSource extends TableProvider with DataSourceRegister {
           manager.readSchemaFile()
       }
 
-      new StructType()
-        .add("key", keySchema)
-        .add("value", valueSchema)
-        .add("partition_id", IntegerType)
+      if (sourceOptions.readChangeFeed) {
+        new StructType()
+          .add("key", keySchema)
+          .add("value", valueSchema)
+          .add("change_type", StringType)
+          .add("batch_id", LongType)
+          .add("partition_id", IntegerType)
+      } else {
+        new StructType()
+          .add("key", keySchema)
+          .add("value", valueSchema)
+          .add("partition_id", IntegerType)
+      }
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
@@ -132,7 +143,10 @@ case class StateSourceOptions(
     storeName: String,
     joinSide: JoinSideValues,
     snapshotStartBatchId: Option[Long],
-    snapshotPartitionId: Option[Int]) {
+    snapshotPartitionId: Option[Int],
+    readChangeFeed: Boolean,
+    changeStartBatchId: Option[Long],
+    changeEndBatchId: Option[Long]) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
 
   override def toString: String = {
@@ -151,6 +165,9 @@ object StateSourceOptions extends DataSourceOptions {
   val JOIN_SIDE = newOption("joinSide")
   val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
   val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")
+  val READ_CHANGE_FEED = newOption("readChangeFeed")
+  val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
+  val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -231,9 +248,45 @@ object StateSourceOptions extends DataSourceOptions {
       throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
     }
 
+    val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
+
+    val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
+    var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)
+
+    if (readChangeFeed) {
+      if (joinSide != JoinSideValues.none) {
+        throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED))
+      }
+      if (changeStartBatchId.isEmpty) {
+        throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID)
+      }
+      changeEndBatchId = Option(changeEndBatchId.getOrElse(batchId))
+
+      // changeStartBatchId and changeEndBatchId must all be defined at this point
+      if (changeStartBatchId.get < 0) {
+        throw StateDataSourceErrors.invalidOptionValueIsNegative(CHANGE_START_BATCH_ID)
+      }
+      if (changeEndBatchId.get < changeStartBatchId.get) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID,
+          s"$CHANGE_END_BATCH_ID cannot be smaller than $CHANGE_START_BATCH_ID. " +
+          s"Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " +
+          s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.")
+      }
+    } else {
+      if (changeStartBatchId.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_START_BATCH_ID,
+          s"Only specify this option when $READ_CHANGE_FEED is set to true.")
+      }
+      if (changeEndBatchId.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID,
+          s"Only specify this option when $READ_CHANGE_FEED is set to true.")
+      }
+    }
+
     StateSourceOptions(
       resolvedCpLocation, batchId, operatorId, storeName,
-      joinSide, snapshotStartBatchId, snapshotPartitionId)
+      joinSide, snapshotStartBatchId, snapshotPartitionId,
+      readChangeFeed, changeStartBatchId, changeEndBatchId)
   }
 
   private def resolvedCheckpointLocation(
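For orientation, here is a minimal usage sketch of the new reader mode. It is not part of this diff; the short name `statestore` and the checkpoint path are assumptions for illustration. Per the validation above, `changeStartBatchId` is required when `readChangeFeed` is enabled, and `changeEndBatchId` falls back to the last committed batch when omitted.

```scala
import org.apache.spark.sql.SparkSession

object ChangeFeedReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("state-change-feed-read")
      .getOrCreate()

    // Read state changes between batches 5 and 10 (inclusive) of an existing
    // streaming query checkpoint. The path below is a placeholder.
    val changes = spark.read
      .format("statestore")                  // assumed short name of the state data source
      .option("readChangeFeed", "true")      // switches to change feed mode
      .option("changeStartBatchId", "5")     // required when readChangeFeed is true
      .option("changeEndBatchId", "10")      // optional; defaults to the last committed batch
      .load("/tmp/streaming-query/checkpoint")

    // With this change the output schema is: key, value, change_type, batch_id, partition_id.
    changes.printSchema()
    changes.show(truncate = false)

    spark.stop()
  }
}
```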
Changes to `StatePartitionReaderFactory` and `StatePartitionReader` (adding `StateStoreChangeDataPartitionReader`):

@@ -23,7 +23,9 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par
 import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -37,8 +39,14 @@ class StatePartitionReaderFactory(
     stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    new StatePartitionReader(storeConf, hadoopConf,
-      partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
+    val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
+    if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
+      new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
+        partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
+    } else {
+      new StatePartitionReader(storeConf, hadoopConf,
+        partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
+    }
   }
 }
@@ -57,7 +65,7 @@ class StatePartitionReader(
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
 
-  private lazy val provider: StateStoreProvider = {
+  protected lazy val provider: StateStoreProvider = {
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
@@ -104,11 +112,11 @@ class StatePartitionReader(
     }
   }
 
-  private lazy val iter: Iterator[InternalRow] = {
+  protected lazy val iter: Iterator[InternalRow] = {
     store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
   }
 
-  private var current: InternalRow = _
+  protected var current: InternalRow = _
 
   override def next(): Boolean = {
     if (iter.hasNext) {
@@ -136,3 +144,48 @@ class StatePartitionReader(
     row
   }
 }
+
+/**
+ * An implementation of [[PartitionReader]] for the readChangeFeed mode of State Data Source.
+ * It reads the change of state over batches of a particular partition.
+ */
+class StateStoreChangeDataPartitionReader(
+    storeConf: StateStoreConf,
+    hadoopConf: SerializableConfiguration,
+    partition: StateStoreInputPartition,
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
+  extends StatePartitionReader(storeConf, hadoopConf, partition, schema, stateStoreMetadata) {
+
+  private lazy val changeDataReader: StateStoreChangeDataReader = {
+    if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
+      throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+        provider.getClass.toString)
+    }
+    provider.asInstanceOf[SupportsFineGrainedReplay]
+      .getStateStoreChangeDataReader(
+        partition.sourceOptions.changeStartBatchId.get + 1,
+        partition.sourceOptions.changeEndBatchId.get + 1)
+  }
+
+  override protected lazy val iter: Iterator[InternalRow] = {
+    changeDataReader.iterator.map(unifyStateChangeDataRow)
+  }
+
+  override def close(): Unit = {
+    current = null
+    changeDataReader.close()
+    provider.close()
+  }
+
+  private def unifyStateChangeDataRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)):
+    InternalRow = {
+    val result = new GenericInternalRow(5)
+    result.update(0, row._2)
+    result.update(1, row._3)
+    result.update(2, UTF8String.fromString(getRecordTypeAsString(row._1)))
+    result.update(3, row._4)
+    result.update(4, partition.partition)
+    result
+  }
+}
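As a hedged illustration of how the unified change rows might be consumed downstream (not part of the PR), the sketch below continues from the `changes` DataFrame in the earlier example. The `change_type` string literals are assumptions about what `getRecordTypeAsString` emits and should be verified against the actual provider output.

```scala
import org.apache.spark.sql.functions.col

// Rows are unified as (key, value, change_type, batch_id, partition_id), so callers can
// slice the feed by change type or by batch. The literals "update" and "delete" are
// assumptions about getRecordTypeAsString's output, not guaranteed by this diff.
val deletions = changes
  .filter(col("change_type") === "delete")
  .select("batch_id", "partition_id", "key")

val updatesPerBatch = changes
  .filter(col("change_type") === "update")
  .groupBy("batch_id")
  .count()
```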
Changes to `StateTable`:

@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.Jo
 import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
@@ -76,6 +76,9 @@ class StateTable(
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 
   private def isValidSchema(schema: StructType): Boolean = {
+    if (sourceOptions.readChangeFeed) {
+      return isValidChangeDataSchema(schema)
+    }
     if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {

Review discussion on the schema validation:

> **Contributor:** My proposal could handle both non-CDF and CDF altogether in a single flow; this version still needs a divergence, and every column has its own `if` or `else if`. Have you tried my proposal?

> **Author:** Sorry, I overlooked the code. It is indeed more elegant. Thanks for the suggestion.
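To make the discussion concrete, one possible shape of the "single flow" idea is sketched below. This is an assumption for illustration only, not the code that was merged; the object name `SchemaCheckSketch` and the standalone `(schema, readChangeFeed)` signature are hypothetical.

```scala
import org.apache.spark.sql.types._

object SchemaCheckSketch {
  // Build the expected (column name, type check) pairs from the read mode, then
  // validate column order and column types in a single pass.
  def isValidSchema(schema: StructType, readChangeFeed: Boolean): Boolean = {
    val keyValue = Seq[(String, DataType => Boolean)](
      "key" -> ((dt: DataType) => dt.isInstanceOf[StructType]),
      "value" -> ((dt: DataType) => dt.isInstanceOf[StructType]))
    val changeFeedOnly = Seq[(String, DataType => Boolean)](
      "change_type" -> ((dt: DataType) => dt == StringType),
      "batch_id" -> ((dt: DataType) => dt == LongType))
    val partitionId = Seq[(String, DataType => Boolean)](
      "partition_id" -> ((dt: DataType) => dt == IntegerType))

    val expected =
      if (readChangeFeed) keyValue ++ changeFeedOnly ++ partitionId
      else keyValue ++ partitionId

    // Both the column order and every column's type must match.
    schema.fields.toSeq.map(f => (f.name, f.dataType)).corresponds(expected) {
      case ((name, dataType), (expectedName, typeOk)) =>
        name == expectedName && typeOk(dataType)
    }
  }
}
```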
@@ -89,6 +92,25 @@ class StateTable(
     }
   }
 
+  private def isValidChangeDataSchema(schema: StructType): Boolean = {
+    if (schema.fieldNames.toImmutableArraySeq !=
+        Seq("key", "value", "change_type", "batch_id", "partition_id")) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
+      false
+    } else {
+      true
+    }
+  }
+
   override def metadataColumns(): Array[MetadataColumn] = Array.empty
 }
Review discussion on the output column ordering:

> I'd expect `change_type` and `batch_id` to come earlier in the schema, and batch ID to be placed before change type (`batch_id`, `change_type`). Given the characteristics of a change feed, the output is expected to be ordered by batch ID (and across partition IDs, which may not be easy); even if the data source does not do so, users should be able to do so easily, because they very likely will.

> Makes sense.
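A small illustration of the point above (user-side code, assumed rather than taken from the PR), again continuing from the `changes` DataFrame in the earlier sketch: whatever order the source emits, callers can sort the feed explicitly.

```scala
import org.apache.spark.sql.functions.col

// Order the change feed by batch first, then by partition, for a replay-friendly view.
val orderedChanges = changes.orderBy(col("batch_id"), col("partition_id"))
```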