Skip to content

Commit 07ef676

Browse files
Fix handling of repeated extension fields in PyProto JSON
Neither the serialize path nor the parse path worked on repeated extensions before these fixes. #22989 PiperOrigin-RevId: 803542611
1 parent eaedcdd commit 07ef676

File tree

2 files changed

+128
-47
lines changed

2 files changed

+128
-47
lines changed

python/google/protobuf/internal/json_format_test.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,46 @@ def testExtensionToDictAndBackWithScalar(self):
183183
json_format.ParseDict(message_dict, parsed_message)
184184
self.assertEqual(message, parsed_message)
185185

186+
def testScalarExtensionToDictAndBack(self):
187+
message = unittest_pb2.TestAllExtensions()
188+
message.Extensions[unittest_pb2.optional_int32_extension] = 7
189+
message.Extensions[unittest_pb2.optional_string_extension] = 'hello'
190+
message_dict = json_format.MessageToDict(message)
191+
self.assertEqual(
192+
message_dict,
193+
{
194+
'[proto2_unittest.optional_int32_extension]': 7,
195+
'[proto2_unittest.optional_string_extension]': 'hello',
196+
},
197+
)
198+
parsed_message = unittest_pb2.TestAllExtensions()
199+
json_format.ParseDict(message_dict, parsed_message)
200+
self.assertEqual(message, parsed_message)
201+
202+
def testRepeatedScalarExtensionToDictAndBack(self):
203+
message = unittest_pb2.TestAllExtensions()
204+
ext1 = unittest_pb2.repeated_int32_extension
205+
message.Extensions[ext1].extend([1, 2, 3])
206+
message_dict = json_format.MessageToDict(message)
207+
self.assertIn('[proto2_unittest.repeated_int32_extension]', message_dict)
208+
parsed_message = unittest_pb2.TestAllExtensions()
209+
json_format.ParseDict(message_dict, parsed_message)
210+
self.assertEqual(message, parsed_message)
211+
212+
def testRepeatedMessageExtensionToDictAndBack(self):
213+
message = unittest_pb2.TestAllExtensions()
214+
ext1 = unittest_pb2.repeated_nested_message_extension
215+
sub = unittest_pb2.TestAllTypes.NestedMessage()
216+
sub.bb = 1
217+
message.Extensions[ext1].append(sub)
218+
message_dict = json_format.MessageToDict(message)
219+
self.assertIn(
220+
'[proto2_unittest.repeated_nested_message_extension]', message_dict
221+
)
222+
parsed_message = unittest_pb2.TestAllExtensions()
223+
json_format.ParseDict(message_dict, parsed_message)
224+
self.assertEqual(message, parsed_message)
225+
186226
def testJsonParseDictToAnyDoesNotAlterInput(self):
187227
orig_dict = {
188228
'int32Value': 20,
@@ -1041,7 +1081,10 @@ def testParseErrorForUnknownEnumValue_ScalarWithoutIgnore_Proto2(self):
10411081
self.assertRaisesRegex(
10421082
json_format.ParseError,
10431083
'Invalid enum value',
1044-
json_format.Parse, '{"a": "UNKNOWN_STRING_VALUE"}', message)
1084+
json_format.Parse,
1085+
'{"a": "UNKNOWN_STRING_VALUE"}',
1086+
message,
1087+
)
10451088

10461089
def testParseUnknownEnumStringValue_Repeated_Proto2(self):
10471090
message = json_format_pb2.TestRepeatedEnum()
@@ -1066,8 +1109,9 @@ def testParseUnknownEnumStringValue_ExtensionField_Proto2(self):
10661109
"""
10671110
json_format.Parse(text, message, ignore_unknown_fields=True)
10681111

1069-
self.assertFalse(json_format_pb2.TestExtension.enum_ext in
1070-
message.Extensions)
1112+
self.assertFalse(
1113+
json_format_pb2.TestExtension.enum_ext in message.Extensions
1114+
)
10711115

10721116
def testParseUnknownEnumStringValue_ExtensionFieldWithoutIgnore_Proto2(self):
10731117
message = json_format_pb2.TestMessageWithExtension()
@@ -1077,7 +1121,10 @@ def testParseUnknownEnumStringValue_ExtensionFieldWithoutIgnore_Proto2(self):
10771121
self.assertRaisesRegex(
10781122
json_format.ParseError,
10791123
'Invalid enum value',
1080-
json_format.Parse, text, message)
1124+
json_format.Parse,
1125+
text,
1126+
message,
1127+
)
10811128

10821129
def testParseUnknownEnumStringValue_Scalar_Proto3(self):
10831130
message = json_format_proto3_pb2.TestMessage()
@@ -1092,8 +1139,9 @@ def testParseUnknownEnumStringValue_Repeated_Proto3(self):
10921139
json_format.Parse(text, message, ignore_unknown_fields=True)
10931140

10941141
self.assertEqual(len(message.repeated_enum_value), 1)
1095-
self.assertTrue(message.repeated_enum_value[0] ==
1096-
json_format_proto3_pb2.FOO)
1142+
self.assertTrue(
1143+
message.repeated_enum_value[0] == json_format_proto3_pb2.FOO
1144+
)
10971145

10981146
def testParseUnknownEnumStringValue_Map_Proto3(self):
10991147
message = json_format_proto3_pb2.MapOfEnums()
@@ -1155,6 +1203,28 @@ def testDuplicateField(self):
11551203
'Failed to load JSON: duplicate key int32Value.',
11561204
)
11571205

1206+
def testDuplicateFieldAlternateNames(self):
1207+
# Note: this behavior is non-spec and an oversight bug in the
1208+
# implementation, but would be a breaking change to fix. The duplicate field
1209+
# checker intends to reject inputs with duplicate key names, but it only
1210+
# catches keys that are exact matches and not alternate spellings that
1211+
# correspond to the same field.
1212+
parsed_message = json_format_proto3_pb2.TestMessage()
1213+
json_format.Parse('{"int32Value": 1,"int32_value":2}', parsed_message)
1214+
self.assertEqual(parsed_message.int32_value, 2)
1215+
1216+
def testDuplicateFieldAlternateNamesMap(self):
1217+
# Note: this behavior is non-spec and an oversight bug in the
1218+
# implementation, but would be a breaking change to fix. The duplicate field
1219+
# checker intends to reject inputs with duplicate key names, but it only
1220+
# catches keys that are exact matches and not alternate spellings that
1221+
# correspond to the same field.
1222+
parsed_message = json_format_proto3_pb2.TestMap()
1223+
json_format.Parse(
1224+
'{"int32Map": {"1": 2}, "int32_map": {"3": 4}}', parsed_message
1225+
)
1226+
self.assertEqual(parsed_message.int32_map, {3: 4})
1227+
11581228
def testInvalidBoolValue(self):
11591229
self.CheckError(
11601230
'{"boolValue": 1}',
@@ -1458,7 +1528,13 @@ def testFieldMaskInvalidStringValue(self):
14581528
def testInvalidAny(self):
14591529
message = any_pb2.Any()
14601530
text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}'
1461-
self.assertRaisesRegex(json_format.ParseError, 'KeyError: \'value\'', json_format.Parse, text, message)
1531+
self.assertRaisesRegex(
1532+
json_format.ParseError,
1533+
"KeyError: 'value'",
1534+
json_format.Parse,
1535+
text,
1536+
message,
1537+
)
14621538
text = '{"value": 1234}'
14631539
self.assertRaisesRegex(
14641540
json_format.ParseError,
@@ -1739,5 +1815,6 @@ def testManyRecursionsRaisesParseError(self):
17391815
with self.assertRaises(json_format.ParseError):
17401816
json_format.Parse(text, json_format_proto3_pb2.TestMessage())
17411817

1818+
17421819
if __name__ == '__main__':
17431820
unittest.main()

python/google/protobuf/json_format.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,13 @@ def _RegularMessageToJsonObject(self, message, js):
224224

225225
try:
226226
for field, value in fields:
227-
if self.preserving_proto_field_name:
227+
if field.is_extension:
228+
name = '[%s]' % field.full_name
229+
elif self.preserving_proto_field_name:
228230
name = field.name
229231
else:
230232
name = field.json_name
233+
231234
if _IsMapEntry(field):
232235
# Convert a map field.
233236
v_field = field.message_type.fields_by_name['value']
@@ -245,9 +248,6 @@ def _RegularMessageToJsonObject(self, message, js):
245248
elif field.is_repeated:
246249
# Convert a repeated field.
247250
js[name] = [self._FieldToJsonObject(field, k) for k in value]
248-
elif field.is_extension:
249-
name = '[%s]' % field.full_name
250-
js[name] = self._FieldToJsonObject(field, value)
251251
else:
252252
js[name] = self._FieldToJsonObject(field, value)
253253

@@ -562,6 +562,25 @@ def _ConvertFieldValuePair(self, js, message, path):
562562
fields_by_json_name = dict(
563563
(f.json_name, f) for f in message_descriptor.fields
564564
)
565+
566+
def _ClearFieldOrExtension(message, field):
567+
if field.is_extension:
568+
message.ClearExtension(field)
569+
else:
570+
message.ClearField(field.name)
571+
572+
def _GetFieldOrExtension(message, field):
573+
if field.is_extension:
574+
return message.Extensions[field]
575+
else:
576+
return getattr(message, field.name)
577+
578+
def _SetFieldOrExtension(message, field, value):
579+
if field.is_extension:
580+
message.Extensions[field] = value
581+
else:
582+
setattr(message, field.name, value)
583+
565584
for name in js:
566585
try:
567586
field = fields_by_json_name.get(name, None)
@@ -625,25 +644,25 @@ def _ConvertFieldValuePair(self, js, message, path):
625644
field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE
626645
and field.message_type.full_name == 'google.protobuf.Value'
627646
):
628-
sub_message = getattr(message, field.name)
647+
sub_message = _GetFieldOrExtension(message, field)
629648
sub_message.null_value = 0
630649
elif (
631650
field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM
632651
and field.enum_type.full_name == 'google.protobuf.NullValue'
633652
):
634-
setattr(message, field.name, 0)
653+
_SetFieldOrExtension(message, field, 0)
635654
else:
636-
message.ClearField(field.name)
655+
_ClearFieldOrExtension(message, field)
637656
continue
638657

639658
# Parse field value.
640659
if _IsMapEntry(field):
641-
message.ClearField(field.name)
660+
_ClearFieldOrExtension(message, field)
642661
self._ConvertMapFieldValue(
643662
value, message, field, '{0}.{1}'.format(path, name)
644663
)
645664
elif field.is_repeated:
646-
message.ClearField(field.name)
665+
_ClearFieldOrExtension(message, field)
647666
if not isinstance(value, _LIST_LIKE):
648667
raise ParseError(
649668
'repeated field {0} must be in [] which is {1} at {2}'.format(
@@ -653,7 +672,7 @@ def _ConvertFieldValuePair(self, js, message, path):
653672
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
654673
# Repeated message field.
655674
for index, item in enumerate(value):
656-
sub_message = getattr(message, field.name).add()
675+
sub_message = _GetFieldOrExtension(message, field).add()
657676
# None is a null_value in Value.
658677
if (
659678
item is None
@@ -683,21 +702,13 @@ def _ConvertFieldValuePair(self, js, message, path):
683702
message, field, item, '{0}.{1}[{2}]'.format(path, name, index)
684703
)
685704
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
686-
if field.is_extension:
687-
sub_message = message.Extensions[field]
688-
else:
689-
sub_message = getattr(message, field.name)
705+
sub_message = _GetFieldOrExtension(message, field)
690706
sub_message.SetInParent()
691707
self.ConvertMessage(value, sub_message, '{0}.{1}'.format(path, name))
692708
else:
693-
if field.is_extension:
694-
self._ConvertAndSetScalarExtension(
695-
message, field, value, '{0}.{1}'.format(path, name)
696-
)
697-
else:
698-
self._ConvertAndSetScalar(
699-
message, field, value, '{0}.{1}'.format(path, name)
700-
)
709+
self._ConvertAndSetScalar(
710+
message, field, value, '{0}.{1}'.format(path, name)
711+
)
701712
except ParseError as e:
702713
if field and field.containing_oneof is None:
703714
raise ParseError(
@@ -855,34 +866,27 @@ def _ConvertMapFieldValue(self, value, message, field, path):
855866
path='{0}[{1}]'.format(path, key_value),
856867
)
857868

858-
def _ConvertAndSetScalarExtension(
859-
self, message, extension_field, js_value, path
860-
):
861-
"""Convert scalar from js_value and assign it to message.Extensions[extension_field]."""
862-
try:
863-
message.Extensions[extension_field] = _ConvertScalarFieldValue(
864-
js_value, extension_field, path
865-
)
866-
except EnumStringValueParseError:
867-
if not self.ignore_unknown_fields:
868-
raise
869-
870869
def _ConvertAndSetScalar(self, message, field, js_value, path):
871870
"""Convert scalar from js_value and assign it to message.field."""
872871
try:
873-
setattr(
874-
message, field.name, _ConvertScalarFieldValue(js_value, field, path)
875-
)
872+
value = _ConvertScalarFieldValue(js_value, field, path)
873+
if field.is_extension:
874+
message.Extensions[field] = value
875+
else:
876+
setattr(message, field.name, value)
876877
except EnumStringValueParseError:
877878
if not self.ignore_unknown_fields:
878879
raise
879880

880881
def _ConvertAndAppendScalar(self, message, repeated_field, js_value, path):
881882
"""Convert scalar from js_value and append it to message.repeated_field."""
882883
try:
883-
getattr(message, repeated_field.name).append(
884-
_ConvertScalarFieldValue(js_value, repeated_field, path)
885-
)
884+
if repeated_field.is_extension:
885+
repeated = message.Extensions[repeated_field]
886+
else:
887+
repeated = getattr(message, repeated_field.name)
888+
value = _ConvertScalarFieldValue(js_value, repeated_field, path)
889+
repeated.append(value)
886890
except EnumStringValueParseError:
887891
if not self.ignore_unknown_fields:
888892
raise

0 commit comments

Comments
 (0)