Skip to content

Commit 49d9e2d

Browse files
anandoleecopybara-github
authored andcommitted
Change proto_api work with custom pool for upb and pure python.
PiperOrigin-RevId: 761644107
1 parent f59b84a commit 49d9e2d

File tree

2 files changed

+138
-53
lines changed

2 files changed

+138
-53
lines changed

python/descriptor_pool.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ PyObject* PyUpb_DescriptorPool_Get(const upb_DefPool* symtab) {
8080
}
8181

8282
static void PyUpb_DescriptorPool_Dealloc(PyUpb_DescriptorPool* self) {
83+
#if PY_VERSION_HEX >= 0x030C0000
84+
PyObject_ClearWeakRefs((PyObject*)self);
85+
#endif
8386
PyObject_GC_UnTrack(self);
8487
PyUpb_DescriptorPool_Clear(self);
8588
upb_DefPool_Free(self->symtab);
@@ -721,7 +724,11 @@ static PyType_Spec PyUpb_DescriptorPool_Spec = {
721724
PYUPB_MODULE_NAME ".DescriptorPool",
722725
sizeof(PyUpb_DescriptorPool),
723726
0, // tp_itemsize
727+
#if PY_VERSION_HEX >= 0x030C0000
728+
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_MANAGED_WEAKREF,
729+
#else
724730
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
731+
#endif
725732
PyUpb_DescriptorPool_Slots,
726733
};
727734

python/google/protobuf/pyext/message_module.cc

Lines changed: 131 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,9 @@ namespace {
3434

3535
class ProtoAPIDescriptorDatabase : public google::protobuf::DescriptorDatabase {
3636
public:
37-
ProtoAPIDescriptorDatabase() {
38-
PyObject* descriptor_pool =
39-
PyImport_ImportModule("google.protobuf.descriptor_pool");
40-
if (descriptor_pool == nullptr) {
41-
ABSL_LOG(ERROR)
42-
<< "Failed to import google.protobuf.descriptor_pool module.";
43-
}
37+
ProtoAPIDescriptorDatabase(PyObject* py_pool) : pool_(py_pool) {};
4438

45-
pool_ = PyObject_CallMethod(descriptor_pool, "Default", nullptr);
46-
if (pool_ == nullptr) {
47-
ABSL_LOG(ERROR) << "Failed to get python Default pool.";
48-
}
49-
Py_DECREF(descriptor_pool);
50-
};
51-
52-
~ProtoAPIDescriptorDatabase() {
53-
// Objects of this class are meant to be `static`ally initialized and
54-
// never destroyed. This is a commonly used approach, because the order
55-
// in which destructors of static objects run is unpredictable. In
56-
// particular, it is possible that the Python interpreter may have been
57-
// finalized already.
58-
ABSL_DLOG(ERROR) << "MEANT TO BE UNREACHABLE.";
59-
};
39+
~ProtoAPIDescriptorDatabase() {};
6040

6141
bool FindFileByName(StringViewArg filename,
6242
google::protobuf::FileDescriptorProto* output) override {
@@ -112,52 +92,150 @@ class ProtoAPIDescriptorDatabase : public google::protobuf::DescriptorDatabase {
11292
PyObject* pool_;
11393
};
11494

95+
struct DescriptorPoolState {
96+
// clang-format off
97+
PyObject_HEAD
98+
99+
std::unique_ptr<google::protobuf::DescriptorPool> pool;
100+
std::unique_ptr<ProtoAPIDescriptorDatabase> database;
101+
};
102+
103+
void DeallocDescriptorPoolState (DescriptorPoolState* self) {
104+
self->database.reset();
105+
self->pool.reset();
106+
Py_TYPE(self)->tp_free((PyObject *)self);
107+
}
108+
109+
PyTypeObject PyDescriptorPoolState_Type = {
110+
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
111+
".DescriptorPoolState", // tp_name
112+
sizeof(DescriptorPoolState), // tp_basicsize
113+
0, // tp_itemsize
114+
(destructor)DeallocDescriptorPoolState, // tp_dealloc
115+
0, // tp_vectorcall_offset
116+
nullptr, // tp_getattr
117+
nullptr, // tp_setattr
118+
nullptr, // tp_compare
119+
nullptr, // tp_repr
120+
nullptr, // tp_as_number
121+
nullptr, // tp_as_sequence
122+
nullptr, // tp_as_mapping
123+
PyObject_HashNotImplemented, // tp_hash
124+
nullptr, // tp_call
125+
nullptr, // tp_str
126+
nullptr, // tp_getattro
127+
nullptr, // tp_setattro
128+
nullptr, // tp_as_buffer
129+
Py_TPFLAGS_DEFAULT, // tp_flags
130+
"DescriptorPoolState", // tp_doc
131+
};
132+
133+
PyObject* PyDescriptorPoolState_New(PyObject* pyfile_pool) {
134+
PyObject* pypool_state = PyType_GenericAlloc(&PyDescriptorPoolState_Type, 0);
135+
if (pypool_state == nullptr) {
136+
PyErr_SetString(PyExc_MemoryError,
137+
"Fail to new PyDescriptorPoolState_Type");
138+
return nullptr;
139+
}
140+
DescriptorPoolState* pool_state =
141+
reinterpret_cast<DescriptorPoolState*>(pypool_state);
142+
pool_state->database =
143+
std::make_unique<ProtoAPIDescriptorDatabase>(pyfile_pool);
144+
pool_state->pool =
145+
std::make_unique<google::protobuf::DescriptorPool>(pool_state->database.get());
146+
return pypool_state;
147+
}
148+
149+
PyObject* InitAndGetPoolMap() {
150+
#if PY_VERSION_HEX >= 0x030C0000
151+
// Returns a WeakKeyDictionary. The key will be a python pool and
152+
// the value will be PyDescriptorPoolState_Type.
153+
// PyDescriptorPoolState_Type should be ready for the usage.
154+
if (PyType_Ready(&PyDescriptorPoolState_Type) < 0) {
155+
return nullptr;
156+
}
157+
PyObject* weakref = PyImport_ImportModule("weakref");
158+
PyObject* pypool_map =
159+
PyObject_CallMethod(weakref, "WeakKeyDictionary", NULL);
160+
Py_DECREF(weakref);
161+
return pypool_map;
162+
#else
163+
return PyDict_New();
164+
#endif
165+
}
166+
115167
absl::StatusOr<const google::protobuf::Descriptor*> FindMessageDescriptor(
116168
PyObject* pyfile, const char* descriptor_full_name) {
117-
static auto* database = new ProtoAPIDescriptorDatabase();
118-
static auto* pool = new google::protobuf::DescriptorPool(database);
119-
PyObject* pyfile_name = PyObject_GetAttrString(pyfile, "name");
169+
static PyObject* pypool_map = InitAndGetPoolMap();
170+
if (pypool_map == nullptr) {
171+
return absl::InternalError("Fail to create pypool_map");
172+
}
173+
PyObject* pyfile_name = nullptr;
174+
PyObject* pyfile_pool = nullptr;
175+
PyObject* pypool_state = nullptr;
176+
google::protobuf::DescriptorPool* pool;
177+
DescriptorPoolState* pool_state;
178+
const char* pyfile_name_char_ptr;
179+
const google::protobuf::FileDescriptor* file_descriptor;
180+
absl::StatusOr<const google::protobuf::Descriptor*> ret;
181+
182+
pyfile_name = PyObject_GetAttrString(pyfile, "name");
120183
if (pyfile_name == nullptr) {
121-
return absl::InvalidArgumentError("FileDescriptor has no attribute 'name'");
184+
ret = absl::InvalidArgumentError("FileDescriptor has no attribute 'name'");
185+
goto err;
122186
}
123-
PyObject* pyfile_pool = PyObject_GetAttrString(pyfile, "pool");
187+
pyfile_pool = PyObject_GetAttrString(pyfile, "pool");
124188
if (pyfile_pool == nullptr) {
125-
Py_DECREF(pyfile_name);
126-
return absl::InvalidArgumentError("FileDescriptor has no attribute 'pool'");
189+
ret = absl::InvalidArgumentError("FileDescriptor has no attribute 'pool'");
190+
goto err;
191+
}
192+
193+
pypool_state = PyObject_GetItem(pypool_map, pyfile_pool);
194+
if (pypool_state == nullptr) {
195+
if (PyErr_ExceptionMatches(PyExc_KeyError)) {
196+
// Ignore the KeyError
197+
PyErr_Clear();
198+
}
199+
PyErr_Print();
200+
pypool_state = PyDescriptorPoolState_New(pyfile_pool);
201+
if (pypool_state == nullptr) {
202+
ret = absl::InternalError("Fail to create PyDescriptorPoolState_Type");
203+
goto err;
204+
}
205+
if (PyObject_SetItem(pypool_map, pyfile_pool, pypool_state) < 0) {
206+
ret = absl::InternalError(
207+
"Fail to insert PyDescriptorPoolState_Type into pypool_map");
208+
goto err;
209+
}
127210
}
128-
// Check the file descriptor is from generated pool.
129-
bool is_from_generated_pool = database->pool() == pyfile_pool;
130-
Py_DECREF(pyfile_pool);
131-
const char* pyfile_name_char_ptr = PyUnicode_AsUTF8(pyfile_name);
211+
pool_state = reinterpret_cast<DescriptorPoolState*>(pypool_state);
212+
pool = pool_state->pool.get();
213+
pyfile_name_char_ptr = PyUnicode_AsUTF8(pyfile_name);
132214
if (pyfile_name_char_ptr == nullptr) {
133-
Py_DECREF(pyfile_name);
134-
return absl::InvalidArgumentError(
215+
ret = absl::InvalidArgumentError(
135216
"FileDescriptor 'name' PyUnicode_AsUTF8() failure.");
217+
goto err;
136218
}
137-
if (!is_from_generated_pool) {
138-
std::string error_msg = absl::StrCat(pyfile_name_char_ptr,
139-
" is not from generated pool");
140-
Py_DECREF(pyfile_name);
141-
return absl::InvalidArgumentError(error_msg);
142-
}
143-
const google::protobuf::FileDescriptor* file_descriptor =
144-
pool->FindFileByName(pyfile_name_char_ptr);
145-
Py_DECREF(pyfile_name);
219+
file_descriptor = pool->FindFileByName(pyfile_name_char_ptr);
146220
if (file_descriptor == nullptr) {
147-
// Already checked the file is from generated pool above, this
148-
// error should never be reached.
221+
// This error should never be reached.
149222
ABSL_DLOG(ERROR) << "MEANT TO BE UNREACHABLE.";
150223
std::string error_msg = absl::StrCat("Fail to find/build file ",
151224
pyfile_name_char_ptr);
152-
return absl::InternalError(error_msg);
225+
ret = absl::InternalError(error_msg);
226+
goto err;
153227
}
154228

155-
const google::protobuf::Descriptor* descriptor =
156-
pool->FindMessageTypeByName(descriptor_full_name);
157-
if (descriptor == nullptr) {
158-
return absl::InternalError("Fail to find descriptor by name.");
229+
ret = pool->FindMessageTypeByName(descriptor_full_name);
230+
if (ret.value() == nullptr) {
231+
ret = absl::InternalError("Fail to find descriptor by name.");
159232
}
160-
return descriptor;
233+
234+
err:
235+
Py_XDECREF(pyfile_name);
236+
Py_XDECREF(pyfile_pool);
237+
Py_XDECREF(pypool_state);
238+
return ret;
161239
}
162240

163241
google::protobuf::DynamicMessageFactory* GetFactory() {
@@ -199,8 +277,8 @@ absl::StatusOr<google::protobuf::Message*> CreateNewMessage(PyObject* py_msg) {
199277
}
200278
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
201279
Py_DECREF(pyfile);
202-
RETURN_IF_ERROR(d.status());
203280
Py_DECREF(fn);
281+
RETURN_IF_ERROR(d.status());
204282
return GetFactory()->GetPrototype(*d)->New();
205283
}
206284

0 commit comments

Comments
 (0)