Skip to content

Commit 5ce64b3

Browse files
Wire type in enum
1 parent ffa4c3b commit 5ce64b3

File tree

7 files changed

+406
-157
lines changed

7 files changed

+406
-157
lines changed

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,29 @@ import kotlinx.serialization.modules.*
1212
import kotlinx.serialization.protobuf.*
1313

1414
internal typealias ProtoDesc = Long
15-
internal const val VARINT = 0
16-
internal const val i64 = 1
17-
internal const val SIZE_DELIMITED = 2
18-
internal const val i32 = 5
15+
16+
internal enum class ProtoWireType(val typeId: Int) {
17+
INVALID(-1),
18+
VARINT(0),
19+
i64(1),
20+
SIZE_DELIMITED(2),
21+
i32(5),
22+
;
23+
24+
companion object {
25+
fun fromTypeId(typeId: Int): ProtoWireType {
26+
return ProtoWireType.entries.find { it.typeId == typeId } ?: INVALID
27+
}
28+
}
29+
30+
fun wireIntWithTag(tag: Int): Int {
31+
return ((tag shl 3) or typeId)
32+
}
33+
34+
override fun toString(): String {
35+
return "${this.name}($typeId)"
36+
}
37+
}
1938

2039
internal const val ID_HOLDER_ONE_OF = -2
2140

@@ -104,7 +123,7 @@ internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedD
104123
return result
105124
}
106125

107-
internal class ProtobufDecodingException(message: String) : SerializationException(message)
126+
internal class ProtobufDecodingException(message: String, e: Throwable? = null) : SerializationException(message, e)
108127

109128
internal expect fun Int.reverseBytes(): Int
110129
internal expect fun Long.reverseBytes(): Long

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt

Lines changed: 141 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -122,41 +122,53 @@ internal open class ProtobufDecoder(
122122
}
123123

124124
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
125-
return when (descriptor.kind) {
126-
StructureKind.LIST -> {
127-
val tag = currentTagOrDefault
128-
return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
129-
val reader = makeDelimited(reader, tag)
130-
// repeated decoder expects the first tag to be read already
131-
reader.readTag()
132-
// all elements always have id = 1
133-
RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)
134-
135-
} else if (reader.currentType == SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
136-
val sliceReader = ProtobufReader(reader.objectInput())
137-
PackedArrayDecoder(proto, sliceReader, descriptor)
138-
139-
} else {
140-
RepeatedDecoder(proto, reader, tag, descriptor)
125+
return try {
126+
when (descriptor.kind) {
127+
StructureKind.LIST -> {
128+
val tag = currentTagOrDefault
129+
return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
130+
val reader = makeDelimited(reader, tag)
131+
// repeated decoder expects the first tag to be read already
132+
reader.readTag()
133+
// all elements always have id = 1
134+
RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)
135+
136+
} else if (reader.currentType == ProtoWireType.SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
137+
val sliceReader = ProtobufReader(reader.objectInput())
138+
PackedArrayDecoder(proto, sliceReader, descriptor)
139+
140+
} else {
141+
RepeatedDecoder(proto, reader, tag, descriptor)
142+
}
141143
}
142-
}
143-
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
144-
val tag = currentTagOrDefault
145-
// Do not create redundant copy
146-
if (tag == MISSING_TAG && this.descriptor == descriptor) return this
147-
if (tag.isOneOf) {
148-
// If a tag is annotated as oneof
149-
// [tag.protoId] here is overwritten with index-based default id in
150-
// [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters]
151-
// and restored the real id from index2IdMap, set by [decodeElementIndex]
152-
val rawIndex = tag.protoId - 1
153-
val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
154-
return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)
144+
145+
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
146+
val tag = currentTagOrDefault
147+
// Do not create redundant copy
148+
if (tag == MISSING_TAG && this.descriptor == descriptor) return this
149+
if (tag.isOneOf) {
150+
// If a tag is annotated as oneof
151+
// [tag.protoId] here is overwritten with index-based default id in
152+
// [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters]
153+
// and restored the real id from index2IdMap, set by [decodeElementIndex]
154+
val rawIndex = tag.protoId - 1
155+
val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
156+
return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)
157+
}
158+
return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
155159
}
156-
return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
160+
161+
StructureKind.MAP -> MapEntryReader(
162+
proto,
163+
makeDelimitedForced(reader, currentTagOrDefault),
164+
currentTagOrDefault,
165+
descriptor
166+
)
167+
168+
else -> throw SerializationException("Primitives are not supported at top-level")
157169
}
158-
StructureKind.MAP -> MapEntryReader(proto, makeDelimitedForced(reader, currentTagOrDefault), currentTagOrDefault, descriptor)
159-
else -> throw SerializationException("Primitives are not supported at top-level")
170+
} catch (e: ProtobufDecodingException) {
171+
throw ProtobufDecodingException("Fail to begin structure for ${descriptor.serialName} in ${this.descriptor.serialName} at proto number ${currentTagOrDefault.protoId}", e)
160172
}
161173
}
162174

@@ -173,41 +185,53 @@ internal open class ProtobufDecoder(
173185
override fun decodeTaggedByte(tag: ProtoDesc): Byte = decodeTaggedInt(tag).toByte()
174186
override fun decodeTaggedShort(tag: ProtoDesc): Short = decodeTaggedInt(tag).toShort()
175187
override fun decodeTaggedInt(tag: ProtoDesc): Int {
176-
return if (tag == MISSING_TAG) {
177-
reader.readInt32NoTag()
178-
} else {
179-
reader.readInt(tag.integerType)
188+
try {
189+
return if (tag == MISSING_TAG) {
190+
reader.readInt32NoTag()
191+
} else {
192+
reader.readInt(tag.integerType)
193+
}
194+
} catch (e: ProtobufDecodingException) {
195+
rethrowException(tag, e)
180196
}
181197
}
182198
override fun decodeTaggedLong(tag: ProtoDesc): Long {
183-
return if (tag == MISSING_TAG) {
184-
reader.readLongNoTag()
185-
} else {
186-
reader.readLong(tag.integerType)
199+
return decodeOrThrow(tag) {
200+
if (tag == MISSING_TAG) {
201+
reader.readLongNoTag()
202+
} else {
203+
reader.readLong(tag.integerType)
204+
}
187205
}
188206
}
189207

190208
override fun decodeTaggedFloat(tag: ProtoDesc): Float {
191-
return if (tag == MISSING_TAG) {
192-
reader.readFloatNoTag()
193-
} else {
194-
reader.readFloat()
209+
return decodeOrThrow(tag) {
210+
if (tag == MISSING_TAG) {
211+
reader.readFloatNoTag()
212+
} else {
213+
reader.readFloat()
214+
}
195215
}
196216
}
197217
override fun decodeTaggedDouble(tag: ProtoDesc): Double {
198-
return if (tag == MISSING_TAG) {
199-
reader.readDoubleNoTag()
200-
} else {
201-
reader.readDouble()
218+
return decodeOrThrow(tag) {
219+
if (tag == MISSING_TAG) {
220+
reader.readDoubleNoTag()
221+
} else {
222+
reader.readDouble()
223+
}
202224
}
203225
}
204226
override fun decodeTaggedChar(tag: ProtoDesc): Char = decodeTaggedInt(tag).toChar()
205227

206228
override fun decodeTaggedString(tag: ProtoDesc): String {
207-
return if (tag == MISSING_TAG) {
208-
reader.readStringNoTag()
209-
} else {
210-
reader.readString()
229+
return decodeOrThrow(tag) {
230+
if (tag == MISSING_TAG) {
231+
reader.readStringNoTag()
232+
} else {
233+
reader.readString()
234+
}
211235
}
212236
}
213237

@@ -218,22 +242,38 @@ internal open class ProtobufDecoder(
218242
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T = decodeSerializableValue(deserializer, null)
219243

220244
@Suppress("UNCHECKED_CAST")
221-
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = when {
222-
deserializer is MapLikeSerializer<*, *, *, *> -> {
223-
deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
245+
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = try {
246+
when {
247+
deserializer is MapLikeSerializer<*, *, *, *> -> {
248+
deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
249+
}
250+
251+
deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
252+
deserializer is AbstractCollectionSerializer<*, *, *> ->
253+
(deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)
254+
255+
else -> deserializer.deserialize(this)
224256
}
225-
deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
226-
deserializer is AbstractCollectionSerializer<*, *, *> ->
227-
(deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)
228-
else -> deserializer.deserialize(this)
257+
} catch (e: ProtobufDecodingException) {
258+
val currentTag = currentTagOrDefault
259+
val msg = if (descriptor != deserializer.descriptor) {
260+
// Decoding child element
261+
"Error while decoding ${deserializer.descriptor.serialName} at proto number ${currentTag.protoId} of ${descriptor.serialName}"
262+
} else {
263+
// Decoding self
264+
"Error while decoding ${descriptor.serialName}"
265+
}
266+
throw ProtobufDecodingException(msg, e)
229267
}
230268

231269
private fun deserializeByteArray(previousValue: ByteArray?): ByteArray {
232270
val tag = currentTagOrDefault
233-
val array = if (tag == MISSING_TAG) {
234-
reader.readByteArrayNoTag()
235-
} else {
236-
reader.readByteArray()
271+
val array = decodeOrThrow(tag) {
272+
if (tag == MISSING_TAG) {
273+
reader.readByteArrayNoTag()
274+
} else {
275+
reader.readByteArray()
276+
}
237277
}
238278
return if (previousValue == null) array else previousValue + array
239279
}
@@ -252,29 +292,33 @@ internal open class ProtobufDecoder(
252292
override fun SerialDescriptor.getTag(index: Int) = extractParameters(index)
253293

254294
override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
255-
while (true) {
256-
val protoId = reader.readTag()
257-
if (protoId == -1) { // EOF
258-
return elementMarker.nextUnmarkedIndex()
259-
}
260-
val index = getIndexByNum(protoId)
261-
if (index == -1) { // not found
262-
reader.skipElement()
263-
} else {
264-
if (descriptor.extractParameters(index).isOneOf) {
265-
/**
266-
* While decoding message with one-of field,
267-
* the proto id read from wire data cannot be easily found
268-
* in the properties of this type,
269-
* So the index of this one-of property and the id read from the wire
270-
* are saved in this map, then restored in [beginStructure]
271-
* and passed to [OneOfPolymorphicReader] to get the actual deserializer.
272-
*/
273-
index2IdMap?.put(index, protoId)
295+
try {
296+
while (true) {
297+
val protoId = reader.readTag()
298+
if (protoId == -1) { // EOF
299+
return elementMarker.nextUnmarkedIndex()
300+
}
301+
val index = getIndexByNum(protoId)
302+
if (index == -1) { // not found
303+
reader.skipElement()
304+
} else {
305+
if (descriptor.extractParameters(index).isOneOf) {
306+
/**
307+
* While decoding message with one-of field,
308+
* the proto id read from wire data cannot be easily found
309+
* in the properties of this type,
310+
* So the index of this one-of property and the id read from the wire
311+
* are saved in this map, then restored in [beginStructure]
312+
* and passed to [OneOfPolymorphicReader] to get the actual deserializer.
313+
*/
314+
index2IdMap?.put(index, protoId)
315+
}
316+
elementMarker.mark(index)
317+
return index
274318
}
275-
elementMarker.mark(index)
276-
return index
277319
}
320+
} catch (e: ProtobufDecodingException) {
321+
throw ProtobufDecodingException("Fail to get element index for ${descriptor.serialName} in ${this.descriptor.serialName}", e)
278322
}
279323
}
280324

@@ -296,6 +340,19 @@ internal open class ProtobufDecoder(
296340
}
297341
return false
298342
}
343+
344+
private inline fun <T> decodeOrThrow(tag: ProtoDesc, crossinline action: (tag: ProtoDesc) -> T): T {
345+
try {
346+
return action(tag)
347+
} catch (e: ProtobufDecodingException) {
348+
rethrowException(tag, e)
349+
}
350+
}
351+
352+
@Suppress("NOTHING_TO_INLINE")
353+
private inline fun rethrowException(tag: ProtoDesc, e: ProtobufDecodingException): Nothing {
354+
throw ProtobufDecodingException("Error while decoding proto number ${tag.protoId} of ${descriptor.serialName}", e)
355+
}
299356
}
300357

301358
private class RepeatedDecoder(

0 commit comments

Comments
 (0)