Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
Expand Down Expand Up @@ -60,15 +60,92 @@ sealed trait RocksDBValueStateEncoder {
* by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow,
* but the actual data provided by the caller does.
*/
/** Interface for encoding and decoding state store data between UnsafeRow and raw bytes.
*
* @note All encode methods expect non-null input rows. Handling of null values is left to the
* implementing classes.
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this in the right place?

trait DataEncoder {
/**
 * Encodes a complete key row into bytes. Used as the primary key for state lookups.
 *
 * @param row An UnsafeRow containing all key columns as defined in the keySchema
 * @return Serialized byte array representation of the key
 */
def encodeKey(row: UnsafeRow): Array[Byte]

/**
 * Encodes the non-prefix portion of a key row. Used with prefix scan and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Text should start at the second line, not the first in scaladoc

 * range scan state lookups where the key is split into prefix and remaining portions.
 *
 * For prefix scans: Encodes columns after the prefix columns
 * For range scans: Encodes columns not included in the ordering columns
 *
 * @param row An UnsafeRow containing only the remaining key columns
 * @return Serialized byte array of the remaining key portion
 * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
 */
def encodeRemainingKey(row: UnsafeRow): Array[Byte]

/**
 * Encodes key columns used for range scanning, ensuring proper sort order in RocksDB.
 *
 * This method handles special encoding for numeric types to maintain correct sort order:
 *  - Adds sign byte markers for numeric types
 *  - Flips bits for negative floating point values
 *  - Preserves null ordering
 *
 * @param row An UnsafeRow containing the columns needed for range scan
 *            (specified by orderingOrdinals)
 * @return Serialized bytes that will maintain correct sort order in RocksDB
 * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
 */
def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]

/**
 * Encodes a value row into bytes.
 *
 * @param row An UnsafeRow containing the value columns as defined in the valueSchema
 * @return Serialized byte array representation of the value
 */
def encodeValue(row: UnsafeRow): Array[Byte]

/**
 * Decodes a complete key from its serialized byte form.
 *
 * For NoPrefixKeyStateEncoder: Decodes the entire key
 * For PrefixKeyScanStateEncoder: Decodes only the prefix portion
 *
 * @param bytes Serialized byte array containing the encoded key
 * @return UnsafeRow containing the decoded key columns
 * @throws UnsupportedOperationException for unsupported encoder types
 */
def decodeKey(bytes: Array[Byte]): UnsafeRow

/**
 * Decodes the remaining portion of a split key from its serialized form.
 *
 * For PrefixKeyScanStateEncoder: Decodes columns after the prefix
 * For RangeKeyScanStateEncoder: Decodes non-ordering columns
 *
 * @param bytes Serialized byte array containing the encoded remaining key portion
 * @return UnsafeRow containing the decoded remaining key columns
 * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
 */
def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow

/**
 * Decodes range scan key bytes back into an UnsafeRow, preserving proper ordering.
 *
 * This method reverses the special encoding done by encodePrefixKeyForRangeScan:
 *  - Interprets sign byte markers
 *  - Reverses bit flipping for negative floating point values
 *  - Handles null values
 *
 * @param bytes Serialized byte array containing the encoded range scan key
 * @return UnsafeRow containing the decoded range scan columns
 * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
 */
def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow

/**
 * Decodes a value from its serialized byte form.
 *
 * @param bytes Serialized byte array containing the encoded value
 * @return UnsafeRow containing the decoded value columns
 */
def decodeValue(bytes: Array[Byte]): UnsafeRow
}

Expand Down Expand Up @@ -347,10 +424,10 @@ class UnsafeRowDataEncoder(

class AvroStateEncoder(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType,
avroEncoder: AvroEncoder) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
with Logging {

private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
// Avro schema used by the avro encoders
private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema)
private lazy val keyProj = UnsafeProjection.create(keySchema)
Expand Down Expand Up @@ -402,6 +479,80 @@ class AvroStateEncoder(

private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uber nit: delete extra empty lines


/**
 * Builds an AvroSerializer for the given Catalyst schema. The Avro type handed to the
 * serializer is derived from that same schema, and the serializer is created with
 * `nullable = false`.
 *
 * @param schema Catalyst schema of the rows to serialize
 * @return A serializer bound to `schema` and its converted Avro type
 */
private def getAvroSerializer(schema: StructType): AvroSerializer =
  new AvroSerializer(schema, SchemaConverters.toAvroType(schema), nullable = false)

/**
 * Builds an AvroDeserializer for the given Catalyst schema. The Avro type is derived from
 * the schema, and the remaining reader settings are read from an AvroOptions instance
 * constructed from an empty option map.
 *
 * @param schema Catalyst schema of the rows to produce
 * @return A deserializer that reads the converted Avro type into `schema`
 */
private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
  val convertedAvroType = SchemaConverters.toAvroType(schema)
  val emptyOptions = AvroOptions(Map.empty)
  new AvroDeserializer(
    convertedAvroType,
    schema,
    emptyOptions.datetimeRebaseModeInRead,
    emptyOptions.useStableIdForUnionType,
    emptyOptions.stableIdPrefixForUnionType,
    emptyOptions.recursiveFieldMaxDepth)
}

/**
 * Creates an AvroEncoder that bundles the Avro serializers and deserializers for the keys
 * and values handled by this encoder.
 *
 * The schema used for the key serializer/deserializer depends on the encoder spec:
 *  - NoPrefixKeyStateEncoderSpec: the full key schema
 *  - PrefixKeyScanStateEncoderSpec: only the first `numColsPrefixKey` columns (the prefix)
 *  - RangeKeyScanStateEncoderSpec: only the columns whose ordinals are NOT listed in
 *    `orderingOrdinals` (the non-ordering, "remaining" columns)
 *
 * For the prefix scan spec, an additional serializer/deserializer pair is created for the
 * columns after the prefix; for every other spec that pair is None.
 *
 * @param keyStateEncoderSpec Specification for how to encode keys
 * @param valueSchema Schema for the values to be encoded
 * @return An AvroEncoder holding the key, value and optional suffix-key (de)serializers
 */
private def createAvroEnc(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType
): AvroEncoder = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please fix the indent

// Value (de)serializers are built directly from the full value schema.
val valueSerializer = getAvroSerializer(valueSchema)
val valueDeserializer = getAvroDeserializer(valueSchema)

// Pick the subset of key columns that the key (de)serializer is built for.
// NOTE(review): this local appears to shadow the class-level `keySchema` referenced
// elsewhere in AvroStateEncoder - confirm the shadowing is intended.
val keySchema = keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(schema) =>
schema
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
// Prefix scan: only the leading prefix columns.
StructType(schema.take(numColsPrefixKey))
case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
// Range scan: every column that is not one of the ordering columns, in schema order.
val remainingSchema = {
0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
schema(ordinal)
}
}
StructType(remainingSchema)
}

// Only the prefix scan spec has a suffix schema (the columns after the prefix).
val suffixKeySchema = keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
Some(StructType(schema.drop(numColsPrefixKey)))
case _ =>
None
}

val keySerializer = getAvroSerializer(keySchema)
val keyDeserializer = getAvroDeserializer(keySchema)

// Create the AvroEncoder with all components; the suffix pair is None unless a suffix
// schema was derived above.
AvroEncoder(
keySerializer,
keyDeserializer,
valueSerializer,
valueDeserializer,
suffixKeySchema.map(getAvroSerializer),
suffixKeySchema.map(getAvroDeserializer)
)
}

/**
* This method takes an UnsafeRow, and serializes to a byte array using Avro encoding.
*/
Expand Down Expand Up @@ -789,44 +940,58 @@ abstract class RocksDBKeyStateEncoderBase(
}
}

/** Factory object for creating state encoders used by RocksDB state store.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move text to next line

*
* The encoders created by this object handle serialization and deserialization of state data,
* supporting both key and value encoding with various access patterns
* (e.g., prefix scan, range scan).
*/
object RocksDBStateEncoder extends Logging {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add scaladoc for all the methods below?


/**
 * Creates a key encoder based on the specified encoding strategy and configuration.
 *
 * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
 * @param keyStateEncoderSpec Specification defining the key encoding strategy
 *                            (no prefix, prefix scan, or range scan)
 * @param useColumnFamilies Whether to use RocksDB column families for storage
 * @param virtualColFamilyId Optional column family identifier when column families are enabled
 * @return A configured RocksDBKeyStateEncoder instance
 */
def getKeyEncoder(
dataEncoder: RocksDBDataEncoder,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short] = None,
avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
// Return the key state encoder based on the requested type
keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(keySchema) =>
new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId)

case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
useColumnFamilies, virtualColFamilyId)

case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
useColumnFamilies, virtualColFamilyId)

case _ =>
throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
s"$keyStateEncoderSpec")
}
virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = {
keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, virtualColFamilyId)
}

/**
 * Creates a value encoder that supports either single or multiple values per key.
 *
 * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
 * @param valueSchema Schema defining the structure of values to be encoded
 * @param useMultipleValuesPerKey If true, creates a MultiValuedStateEncoder; if false,
 *                                creates a SingleValueStateEncoder
 * @return A configured RocksDBValueStateEncoder instance
 */
def getValueEncoder(
dataEncoder: RocksDBDataEncoder,
valueSchema: StructType,
useMultipleValuesPerKey: Boolean,
avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = {
if (useMultipleValuesPerKey) {
new MultiValuedStateEncoder(dataEncoder, valueSchema)
} else {
new SingleValueStateEncoder(dataEncoder, valueSchema)
}
}

/** Encodes a virtual column family ID into a byte array suitable for RocksDB.
*
* This method creates a fixed-size byte array prefixed with the virtual column family ID,
* which is used to partition data within RocksDB.
*
* @param virtualColFamilyId The column family identifier to encode
* @return A byte array containing the encoded column family ID
*/
def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId)
Expand Down Expand Up @@ -871,18 +1036,6 @@ class PrefixKeyScanStateEncoder(
UnsafeProjection.create(refs)
}

// Prefix Key schema and projection definitions used by the Avro Serializers
// and Deserializers
private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)

// Remaining Key schema and projection definitions used by the Avro Serializers
// and Deserializers
private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)

// This is quite simple to do - just bind sequentially, as we don't change the order.
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)

Expand Down Expand Up @@ -1056,22 +1209,6 @@ class RangeKeyScanStateEncoder(
UnsafeProjection.create(refs)
}

private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))

private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)

private val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)

// Existing remainder key schema stuff
private val remainingKeySchema = StructType(
0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_))
)

private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)

private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)

// Reusable objects
private val joinedRowOnKey = new JoinedRow()

Expand Down
Loading
Loading