Skip to content

Commit 5f6c013

Browse files
anandoleecopybara-github
authored andcommitted
Add construction support for repeated Timestamp/Duration/Struct/ListValue.
-Users can now assign list datetime/temedelta to repeated Timestamp/Duration in construction. -Can assign list dictionary to repeated Struct. -Can assign list of list to repeated ListValue. #21541 PiperOrigin-RevId: 804523563
1 parent 86188d9 commit 5f6c013

File tree

7 files changed

+210
-122
lines changed

7 files changed

+210
-122
lines changed

python/google/protobuf/internal/duration_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ def test_duration_construction(self):
6767
)
6868
self.assertEqual(expected_td, message.optional_duration.ToTimedelta())
6969

70+
def test_repeated_duration_construction(self):
71+
td0 = datetime.timedelta(microseconds=123)
72+
td1 = datetime.timedelta(microseconds=456)
73+
dr = duration_pb2.Duration()
74+
message = well_known_types_test_pb2.WKTMessage(repeated_td=[td0, td1, dr])
75+
self.assertEqual(td0, duration.to_timedelta(message.repeated_td[0]))
76+
self.assertEqual(td1, duration.to_timedelta(message.repeated_td[1]))
77+
self.assertEqual(dr, message.repeated_td[2])
78+
7079
def test_duration_sub_annotation(self):
7180
dt = datetime.datetime.now()
7281
dr = duration_pb2.Duration()

python/google/protobuf/internal/python_message.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,35 @@ def _GetIntegerEnumValue(enum_type, value):
498498
return value
499499

500500
def init(self, **kwargs):
501+
502+
def init_wkt_or_merge(field, msg, value):
503+
if isinstance(value, message_mod.Message):
504+
msg.MergeFrom(value)
505+
elif (
506+
isinstance(value, dict)
507+
and field.message_type.full_name == _StructFullTypeName
508+
):
509+
msg.Clear()
510+
if len(value) == 1 and 'fields' in value:
511+
try:
512+
msg.update(value)
513+
except:
514+
msg.Clear()
515+
msg.__init__(**value)
516+
else:
517+
msg.update(value)
518+
elif hasattr(msg, '_internal_assign'):
519+
msg._internal_assign(value)
520+
else:
521+
raise TypeError(
522+
'Message field {0}.{1} must be initialized with a '
523+
'dict or instance of same class, got {2}.'.format(
524+
message_descriptor.name,
525+
field.name,
526+
type(value).__name__,
527+
)
528+
)
529+
501530
self._cached_byte_size = 0
502531
self._cached_byte_size_dirty = len(kwargs) > 0
503532
self._fields = {}
@@ -534,10 +563,13 @@ def init(self, **kwargs):
534563
field_copy.update(field_value)
535564
else:
536565
for val in field_value:
537-
if isinstance(val, dict):
566+
if isinstance(val, dict) and (
567+
field.message_type.full_name != _StructFullTypeName
568+
):
538569
field_copy.add(**val)
539570
else:
540-
field_copy.add().MergeFrom(val)
571+
new_msg = field_copy.add()
572+
init_wkt_or_merge(field, new_msg, val)
541573
else: # Scalar
542574
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
543575
field_value = [_GetIntegerEnumValue(field.enum_type, val)
@@ -546,38 +578,14 @@ def init(self, **kwargs):
546578
self._fields[field] = field_copy
547579
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
548580
field_copy = field._default_constructor(self)
549-
new_val = None
550-
if isinstance(field_value, message_mod.Message):
551-
new_val = field_value
552-
elif isinstance(field_value, dict):
553-
if field.message_type.full_name == _StructFullTypeName:
554-
field_copy.Clear()
555-
if len(field_value) == 1 and 'fields' in field_value:
556-
try:
557-
field_copy.update(field_value)
558-
except:
559-
# Fall back to init normal message field
560-
field_copy.Clear()
561-
new_val = field.message_type._concrete_class(**field_value)
562-
else:
563-
field_copy.update(field_value)
564-
else:
565-
new_val = field.message_type._concrete_class(**field_value)
566-
elif hasattr(field_copy, '_internal_assign'):
567-
field_copy._internal_assign(field_value)
581+
if isinstance(field_value, dict) and (
582+
field.message_type.full_name != _StructFullTypeName
583+
):
584+
new_val = field.message_type._concrete_class(**field_value)
585+
field_copy.MergeFrom(new_val)
568586
else:
569-
raise TypeError(
570-
'Message field {0}.{1} must be initialized with a '
571-
'dict or instance of same class, got {2}.'.format(
572-
message_descriptor.name,
573-
field_name,
574-
type(field_value).__name__,
575-
)
576-
)
577-
578-
if new_val != None:
579587
try:
580-
field_copy.MergeFrom(new_val)
588+
init_wkt_or_merge(field, field_copy, field_value)
581589
except TypeError:
582590
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
583591
self._fields[field] = field_copy

python/google/protobuf/internal/timestamp_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ def test_timstamp_construction(self):
6666
optional_timestamp=datetime.datetime.today()
6767
)
6868

69+
def test_repeated_timestamp_construction(self):
70+
message = well_known_types_test_pb2.WKTMessage(
71+
repeated_ts=[
72+
datetime.datetime(2025, 1, 1),
73+
datetime.datetime(1970, 1, 1),
74+
timestamp_pb2.Timestamp(),
75+
]
76+
)
77+
self.assertEqual(len(message.repeated_ts), 3)
78+
self.assertEqual(
79+
datetime.datetime(2025, 1, 1),
80+
timestamp.to_datetime((message.repeated_ts[0])),
81+
)
82+
self.assertEqual(
83+
datetime.datetime(1970, 1, 1),
84+
timestamp.to_datetime((message.repeated_ts[1])),
85+
)
86+
self.assertEqual(timestamp_pb2.Timestamp(), message.repeated_ts[2])
87+
6988
def test_timestamp_sub_annotation(self):
7089
t1 = timestamp_pb2.Timestamp()
7190
t2 = timestamp_pb2.Timestamp()

python/google/protobuf/internal/well_known_types_test.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,9 @@ message WKTMessage {
2626
Duration optional_duration = 2;
2727
Struct optional_struct = 3;
2828
ListValue optional_list_value = 4;
29+
30+
repeated Timestamp repeated_ts = 5;
31+
repeated Duration repeated_td = 6;
32+
repeated Struct repeated_struct = 7;
33+
repeated ListValue repeated_list = 8;
2934
}

python/google/protobuf/internal/well_known_types_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,30 @@ def testSpecialStructConstruct(self):
945945
msg3 = well_known_types_test_pb2.WKTMessage(optional_struct=dictionary3)
946946
self.assertEqual(msg3.optional_struct, {'key1': 5.0})
947947

948+
def testRepeatedStructConstruct(self):
949+
dict0 = {'key1': 6.0}
950+
dict1 = {
951+
'key1': 'abc',
952+
'key2': {'subkey': 11.0, 'k': True},
953+
}
954+
value_msg = struct_pb2.Value(number_value=5.0)
955+
dict2 = {'fields': {'key1': value_msg}}
956+
msg = well_known_types_test_pb2.WKTMessage(
957+
repeated_struct=[dict0, dict1, dict2]
958+
)
959+
self.assertEqual(len(msg.repeated_struct), 3)
960+
self.assertEqual(msg.repeated_struct[0], dict0)
961+
self.assertEqual(msg.repeated_struct[1], dict1)
962+
self.assertEqual(msg.repeated_struct[2], {'key1': 5.0})
963+
964+
def testRepeatedListValueConstruct(self):
965+
list0 = [6, 'seven', True, False]
966+
list1 = [None, {'key': 1.2}]
967+
msg = well_known_types_test_pb2.WKTMessage(repeated_list=[list0, list1])
968+
self.assertEqual(len(msg.repeated_list), 2)
969+
self.assertEqual(msg.repeated_list[0], list0)
970+
self.assertEqual(msg.repeated_list[1], list1)
971+
948972
def testMergeFrom(self):
949973
struct = struct_pb2.Struct()
950974
struct_class = struct.__class__

python/google/protobuf/pyext/message.cc

Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,53 @@ int DeleteRepeatedField(CMessage* self, const FieldDescriptor* field_descriptor,
986986
return 0;
987987
}
988988

989+
int InitWKTOrMerge(const Descriptor* descriptor, PyObject* py_message,
990+
PyObject* value) {
991+
CMessage* cmessage = reinterpret_cast<CMessage*>(py_message);
992+
AssureWritable(cmessage);
993+
if (PyObject_TypeCheck(value, CMessage_Type)) {
994+
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
995+
if (merged == nullptr) {
996+
return -1;
997+
}
998+
return 0;
999+
}
1000+
if (PyDict_Check(value) &&
1001+
(descriptor->well_known_type() == Descriptor::WELLKNOWNTYPE_STRUCT)) {
1002+
ScopedPyObjectPtr ok(PyObject_CallMethod(py_message, "update", "O", value));
1003+
if (ok.get() == nullptr && PyDict_Size(value) == 1) {
1004+
ScopedPyObjectPtr fields_str(PyUnicode_FromString("fields"));
1005+
if (PyDict_Contains(value, fields_str.get())) {
1006+
// Fallback to init as normal message field.
1007+
PyErr_Clear();
1008+
PyObject* tmp = Clear(cmessage);
1009+
Py_DECREF(tmp);
1010+
if (InitAttributes(cmessage, nullptr, value) < 0) {
1011+
return -1;
1012+
}
1013+
}
1014+
}
1015+
return 0;
1016+
}
1017+
1018+
if (descriptor->well_known_type() != Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
1019+
PyObject_HasAttrString(py_message, "_internal_assign")) {
1020+
ScopedPyObjectPtr ok(
1021+
PyObject_CallMethod(py_message, "_internal_assign", "O", value));
1022+
if (ok.get() == nullptr) {
1023+
return -1;
1024+
}
1025+
return 0;
1026+
}
1027+
1028+
PyErr_Format(PyExc_TypeError,
1029+
"Parameter to initialize message field must be "
1030+
"dict or instance of same class: expected %s got %s.",
1031+
std::string(descriptor->full_name()).c_str(),
1032+
Py_TYPE(value)->tp_name);
1033+
return -1;
1034+
}
1035+
9891036
// Initializes fields of a message. Used in constructors.
9901037
int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
9911038
if (args != nullptr && PyTuple_Size(args) != 0) {
@@ -1084,20 +1131,22 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
10841131
}
10851132
ScopedPyObjectPtr next;
10861133
while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
1087-
PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : nullptr);
1088-
ScopedPyObjectPtr new_msg(
1089-
repeated_composite_container::Add(rc_container, nullptr, kwargs));
1090-
if (new_msg == nullptr) {
1091-
return -1;
1092-
}
1093-
if (kwargs == nullptr) {
1094-
// next was not a dict, it's a message we need to merge
1095-
ScopedPyObjectPtr merged(MergeFrom(
1096-
reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1097-
if (merged.get() == nullptr) {
1134+
if ((PyDict_Check(next.get())) &&
1135+
(descriptor->message_type()->well_known_type() !=
1136+
Descriptor::WELLKNOWNTYPE_STRUCT)) {
1137+
ScopedPyObjectPtr new_msg(repeated_composite_container::Add(
1138+
rc_container, nullptr, next.get()));
1139+
if (new_msg == nullptr) {
10981140
return -1;
10991141
}
1142+
continue;
11001143
}
1144+
ScopedPyObjectPtr new_msg(repeated_composite_container::Add(
1145+
rc_container, nullptr, nullptr));
1146+
if (new_msg == nullptr) {
1147+
return -1;
1148+
}
1149+
InitWKTOrMerge(descriptor->message_type(), new_msg.get(), next.get());
11011150
}
11021151
if (PyErr_Occurred()) {
11031152
// Check to see how PyIter_Next() exited.
@@ -1142,53 +1191,19 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
11421191
return -1;
11431192
}
11441193
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1145-
if (PyDict_Check(value)) {
1194+
if (PyDict_Check(value) &&
1195+
(descriptor->message_type()->well_known_type() !=
1196+
Descriptor::WELLKNOWNTYPE_STRUCT)) {
11461197
// Make the message exist even if the dict is empty.
11471198
AssureWritable(cmessage);
1148-
if (descriptor->message_type()->well_known_type() ==
1149-
Descriptor::WELLKNOWNTYPE_STRUCT) {
1150-
ScopedPyObjectPtr ok(PyObject_CallMethod(
1151-
reinterpret_cast<PyObject*>(cmessage), "update", "O", value));
1152-
if (ok.get() == nullptr && PyDict_Size(value) == 1) {
1153-
ScopedPyObjectPtr fields_str(PyUnicode_FromString("fields"));
1154-
if (PyDict_Contains(value, fields_str.get())) {
1155-
// Fallback to init as normal message field.
1156-
PyErr_Clear();
1157-
PyObject* tmp = Clear(cmessage);
1158-
Py_DECREF(tmp);
1159-
if (InitAttributes(cmessage, nullptr, value) < 0) {
1160-
return -1;
1161-
}
1162-
}
1163-
}
1164-
} else {
1165-
if (InitAttributes(cmessage, nullptr, value) < 0) {
1166-
return -1;
1167-
}
1168-
}
1169-
} else if (PyObject_TypeCheck(value, CMessage_Type)) {
1170-
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1171-
if (merged == nullptr) {
1199+
if (InitAttributes(cmessage, nullptr, value) < 0) {
11721200
return -1;
11731201
}
1174-
} else if (descriptor->message_type()->well_known_type() !=
1175-
Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
1176-
PyObject_HasAttrString(reinterpret_cast<PyObject*>(cmessage),
1177-
"_internal_assign")) {
1178-
AssureWritable(cmessage);
1179-
ScopedPyObjectPtr ok(
1180-
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
1181-
"_internal_assign", "O", value));
1182-
if (ok.get() == nullptr) {
1202+
} else {
1203+
if (InitWKTOrMerge(descriptor->message_type(), message.get(), value) <
1204+
0) {
11831205
return -1;
11841206
}
1185-
} else {
1186-
PyErr_Format(PyExc_TypeError,
1187-
"Parameter to initialize message field must be "
1188-
"dict or instance of same class: expected %s got %s.",
1189-
std::string(descriptor->full_name()).c_str(),
1190-
Py_TYPE(value)->tp_name);
1191-
return -1;
11921207
}
11931208
} else {
11941209
ScopedPyObjectPtr new_val;

0 commit comments

Comments
 (0)