Skip to content

Commit 8e5e6ed

Browse files
committed
stuff
1 parent 519f6c9 commit 8e5e6ed

1 file changed

Lines changed: 22 additions & 13 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -892,33 +892,43 @@ class AvroStateEncoder(
892892
valueAvroType) // Defining Avro writer for this struct type
893893
writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array
894894
encoder.flush()
895-
val bytesToEncode = out.toByteArray
896-
// prepend version byte
895+
prependVersionByte(out.toByteArray)
896+
}
897+
898+
private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = {
897899
val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
898900
Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
899901
// Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
900902
Platform.copyMemory(
901-
bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
903+
bytesToEncode, 0,
902904
encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
903905
bytesToEncode.length)
904906
encodedBytes
905907
}
906908

909+
private def removeVersionByte(bytes: Array[Byte]): Array[Byte] = {
910+
val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
911+
Platform.copyMemory(
912+
bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET,
913+
resultBytes, 0, resultBytes.length
914+
)
915+
resultBytes
916+
}
917+
907918
/**
908919
* This method takes a byte array written using Avro encoding, and
909920
* deserializes to an UnsafeRow using the Avro deserializer
910921
*/
911922
def decodeFromAvroToUnsafeRow(
912-
valueBytes: Array[Byte],
923+
b: Array[Byte],
913924
avroDeserializer: AvroDeserializer,
914925
valueAvroType: Schema,
915926
valueProj: UnsafeProjection): UnsafeRow = {
916-
if (valueBytes != null) {
927+
if (b != null) {
928+
val valueBytes = removeVersionByte(b)
917929
val reader = new GenericDatumReader[Any](valueAvroType)
918930
val decoder = DecoderFactory.get().binaryDecoder(
919-
valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION,
920-
valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION,
921-
null)
931+
valueBytes, 0, valueBytes.length, null)
922932
// bytes -> Avro.GenericDataRecord
923933
val genericData = reader.read(null, decoder)
924934
// Avro.GenericDataRecord -> InternalRow
@@ -943,17 +953,16 @@ class AvroStateEncoder(
943953
* @return The deserialized UnsafeRow, or null if input bytes are null
944954
*/
945955
def decodeFromAvroToUnsafeRow(
946-
valueBytes: Array[Byte],
956+
b: Array[Byte],
947957
avroDeserializer: AvroDeserializer,
948958
writerSchema: Schema,
949959
readerSchema: Schema,
950960
valueProj: UnsafeProjection): UnsafeRow = {
951-
if (valueBytes != null) {
961+
if (b != null) {
962+
val valueBytes = removeVersionByte(b)
952963
val reader = new GenericDatumReader[Any](writerSchema, readerSchema)
953964
val decoder = DecoderFactory.get().binaryDecoder(
954-
valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION,
955-
valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION,
956-
null)
965+
valueBytes, 0, valueBytes.length, null)
957966
// bytes -> Avro.GenericDataRecord
958967
val genericData = reader.read(null, decoder)
959968
// Avro.GenericDataRecord -> InternalRow

0 commit comments

Comments
 (0)