Skip to content

Commit 17838be

Browse files
Add recursion depth limits to pure python
PiperOrigin-RevId: 758382549
1 parent c9d4385 commit 17838be

File tree

4 files changed

+102
-19
lines changed

4 files changed

+102
-19
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,13 @@ def DecodeRepeatedField(
703703
if value is None:
704704
value = field_dict.setdefault(key, new_default(message))
705705
# Read sub-message.
706+
current_depth += 1
707+
if current_depth > _recursion_limit:
708+
raise _DecodeError(
709+
'Error parsing message: too many levels of nesting.'
710+
)
706711
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
712+
current_depth -= 1
707713
# Read end tag.
708714
new_pos = pos+end_tag_len
709715
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -722,7 +728,11 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
722728
if value is None:
723729
value = field_dict.setdefault(key, new_default(message))
724730
# Read sub-message.
731+
current_depth += 1
732+
if current_depth > _recursion_limit:
733+
raise _DecodeError('Error parsing message: too many levels of nesting.')
725734
pos = value._InternalParse(buffer, pos, end, current_depth)
735+
current_depth -= 1
726736
# Read end tag.
727737
new_pos = pos+end_tag_len
728738
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -755,13 +765,19 @@ def DecodeRepeatedField(
755765
if new_pos > end:
756766
raise _DecodeError('Truncated message.')
757767
# Read sub-message.
768+
current_depth += 1
769+
if current_depth > _recursion_limit:
770+
raise _DecodeError(
771+
'Error parsing message: too many levels of nesting.'
772+
)
758773
if (
759774
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
760775
!= new_pos
761776
):
762777
# The only reason _InternalParse would return early is if it
763778
# encountered an end-group tag.
764779
raise _DecodeError('Unexpected end-group tag.')
780+
current_depth -= 1
765781
# Predict that the next tag is another copy of the same repeated field.
766782
pos = new_pos + tag_len
767783
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -781,10 +797,14 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
781797
if new_pos > end:
782798
raise _DecodeError('Truncated message.')
783799
# Read sub-message.
800+
current_depth += 1
801+
if current_depth > _recursion_limit:
802+
raise _DecodeError('Error parsing message: too many levels of nesting.')
784803
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
785804
# The only reason _InternalParse would return early is if it encountered
786805
# an end-group tag.
787806
raise _DecodeError('Unexpected end-group tag.')
807+
current_depth -= 1
788808
return new_pos
789809

790810
return DecodeField
@@ -980,6 +1000,13 @@ def _DecodeFixed32(buffer, pos):
9801000

9811001
new_pos = pos + 4
9821002
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
1003+
DEFAULT_RECURSION_LIMIT = 100
1004+
_recursion_limit = DEFAULT_RECURSION_LIMIT
1005+
1006+
1007+
def SetRecursionLimit(new_limit):
1008+
global _recursion_limit
1009+
_recursion_limit = new_limit
9831010

9841011

9851012
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
@@ -1020,7 +1047,11 @@ def _DecodeUnknownField(
10201047
end_tag_bytes = encoder.TagBytes(
10211048
field_number, wire_format.WIRETYPE_END_GROUP
10221049
)
1050+
current_depth += 1
1051+
if current_depth >= _recursion_limit:
1052+
raise _DecodeError('Error parsing message: too many levels of nesting.')
10231053
data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, current_depth)
1054+
current_depth -= 1
10241055
# Check end tag.
10251056
if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
10261057
raise _DecodeError('Missing group end tag.')

python/google/protobuf/internal/decoder_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ def test_decode_unknown_group_field_nested(self):
8484
self.assertEqual(parsed[0].data[0].data[0].field_number, 3)
8585
self.assertEqual(parsed[0].data[0].data[0].data, 4)
8686

87+
def test_decode_unknown_group_field_too_many_levels(self):
88+
data = memoryview(b'\023' * 5_000_000)
89+
self.assertRaisesRegex(
90+
message.DecodeError,
91+
'Error parsing message',
92+
decoder._DecodeUnknownField,
93+
data,
94+
1,
95+
len(data),
96+
1,
97+
wire_format.WIRETYPE_START_GROUP,
98+
)
99+
87100
def test_decode_unknown_mismatched_end_group(self):
88101
self.assertRaisesRegex(
89102
message.DecodeError,

python/google/protobuf/internal/message_test.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.protobuf.internal import more_extensions_pb2
3737
from google.protobuf.internal import more_messages_pb2
3838
from google.protobuf.internal import packed_field_test_pb2
39+
from google.protobuf.internal import self_recursive_pb2
3940
from google.protobuf.internal import test_proto3_optional_pb2
4041
from google.protobuf.internal import test_util
4142
from google.protobuf.internal import testing_refleaks
@@ -1431,6 +1432,52 @@ def testMessageClassName(self, message_module):
14311432
)
14321433

14331434

1435+
@testing_refleaks.TestCase
1436+
class TestRecursiveGroup(unittest.TestCase):
1437+
1438+
def _MakeRecursiveGroupMessage(self, n):
1439+
msg = self_recursive_pb2.SelfRecursive()
1440+
sub = msg
1441+
for _ in range(n):
1442+
sub = sub.sub_group
1443+
sub.i = 1
1444+
return msg.SerializeToString()
1445+
1446+
def testRecursiveGroups(self):
1447+
recurse_msg = self_recursive_pb2.SelfRecursive()
1448+
data = self._MakeRecursiveGroupMessage(100)
1449+
recurse_msg.ParseFromString(data)
1450+
self.assertTrue(recurse_msg.HasField('sub_group'))
1451+
1452+
def testRecursiveGroupsException(self):
1453+
if api_implementation.Type() != 'python':
1454+
api_implementation._c_module.SetAllowOversizeProtos(False)
1455+
recurse_msg = self_recursive_pb2.SelfRecursive()
1456+
data = self._MakeRecursiveGroupMessage(300)
1457+
with self.assertRaises(message.DecodeError) as context:
1458+
recurse_msg.ParseFromString(data)
1459+
self.assertIn('Error parsing message', str(context.exception))
1460+
if api_implementation.Type() == 'python':
1461+
self.assertIn('too many levels of nesting', str(context.exception))
1462+
1463+
def testRecursiveGroupsUnknownFields(self):
1464+
if api_implementation.Type() != 'python':
1465+
api_implementation._c_module.SetAllowOversizeProtos(False)
1466+
test_msg = unittest_pb2.TestAllTypes()
1467+
data = self._MakeRecursiveGroupMessage(300) # unknown to test_msg
1468+
with self.assertRaises(message.DecodeError) as context:
1469+
test_msg.ParseFromString(data)
1470+
self.assertIn(
1471+
'Error parsing message',
1472+
str(context.exception),
1473+
)
1474+
if api_implementation.Type() == 'python':
1475+
self.assertIn('too many levels of nesting', str(context.exception))
1476+
decoder.SetRecursionLimit(310)
1477+
test_msg.ParseFromString(data)
1478+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
1479+
1480+
14341481
# Class to test proto2-only features (required, extensions, etc.)
14351482
@testing_refleaks.TestCase
14361483
class Proto2Test(unittest.TestCase):
@@ -1905,7 +1952,7 @@ def testProto3Optional(self):
19051952
if field.name.startswith('optional_'):
19061953
self.assertTrue(field.has_presence)
19071954
for field in unittest_pb2.TestAllTypes.DESCRIPTOR.fields:
1908-
if field.is_repeated:
1955+
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
19091956
self.assertFalse(field.has_presence)
19101957
else:
19111958
self.assertTrue(field.has_presence)
@@ -2699,20 +2746,6 @@ def testMapFindInitializationErrorsSmokeTest(self):
26992746
msg.map_string_foreign_message['foo'].c = 5
27002747
self.assertEqual(0, len(msg.FindInitializationErrors()))
27012748

2702-
def testMapStubReferenceSubMessageDestructor(self):
2703-
msg = map_unittest_pb2.TestMapSubmessage()
2704-
# A reference on map stub in sub message
2705-
map_ref = msg.test_map.map_int32_int32
2706-
# Make sure destructor after Clear the original message not crash
2707-
msg.Clear()
2708-
2709-
def testRepeatedStubReferenceSubMessageDestructor(self):
2710-
msg = unittest_pb2.NestedTestAllTypes()
2711-
# A reference on repeated stub in sub message
2712-
repeated_ref = msg.payload.repeated_int32
2713-
# Make sure destructor after Clear the original message not crash
2714-
msg.Clear()
2715-
27162749
@unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
27172750
def testStrictUtf8Check(self):
27182751
# Test u'\ud801' is rejected at parser in both python2 and python3.
@@ -2873,8 +2906,6 @@ def testUnpackedFields(self):
28732906
self.assertEqual(golden_data, message.SerializeToString())
28742907

28752908

2876-
@unittest.skipIf(api_implementation.Type() == 'python',
2877-
'explicit tests of the C++ implementation')
28782909
@testing_refleaks.TestCase
28792910
class OversizeProtosTest(unittest.TestCase):
28802911

@@ -2891,16 +2922,23 @@ def testSucceedOkSizedProto(self):
28912922
msg.ParseFromString(self.GenerateNestedProto(100))
28922923

28932924
def testAssertOversizeProto(self):
2894-
api_implementation._c_module.SetAllowOversizeProtos(False)
2925+
if api_implementation.Type() != 'python':
2926+
api_implementation._c_module.SetAllowOversizeProtos(False)
28952927
msg = unittest_pb2.TestRecursiveMessage()
28962928
with self.assertRaises(message.DecodeError) as context:
28972929
msg.ParseFromString(self.GenerateNestedProto(101))
28982930
self.assertIn('Error parsing message', str(context.exception))
28992931

29002932
def testSucceedOversizeProto(self):
2901-
api_implementation._c_module.SetAllowOversizeProtos(True)
2933+
2934+
if api_implementation.Type() == 'python':
2935+
decoder.SetRecursionLimit(310)
2936+
else:
2937+
api_implementation._c_module.SetAllowOversizeProtos(True)
2938+
29022939
msg = unittest_pb2.TestRecursiveMessage()
29032940
msg.ParseFromString(self.GenerateNestedProto(101))
2941+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
29042942

29052943

29062944
if __name__ == '__main__':

python/google/protobuf/internal/self_recursive.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ package google.protobuf.python.internal;
1212
message SelfRecursive {
1313
SelfRecursive sub = 1;
1414
int32 i = 2;
15+
SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];
1516
}
1617

1718
message IndirectRecursive {

0 commit comments

Comments
 (0)