File tree Expand file tree Collapse file tree 3 files changed +18
-7
lines changed
Expand file tree Collapse file tree 3 files changed +18
-7
lines changed Original file line number Diff line number Diff line change @@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
6363 }
6464}
6565
66- py::object PyLayerApply (const platform::Place& place, const py::object & cls,
66+ py::object PyLayerApply (const platform::Place& place, const py::handle & cls,
6767 const py::args args, const py::kwargs kwargs) {
68+ py::gil_scoped_acquire guard;
6869 auto bk_function = cls.attr (" _backward_function" );
6970 auto context = bk_function ();
7071 auto forward = cls.attr (" forward" );
7172
7273 auto result_forward = forward (context, *args, **kwargs);
7374 std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
74- std::make_shared<operators::PyLayerContext>(context.release (). ptr ());
75+ std::make_shared<operators::PyLayerContext>(context.ptr ());
7576 // make inputs to varbase
7677 std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
7778 // process args,`input_vars` only collect `imperative::VarBase`
Original file line number Diff line number Diff line change @@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> {
157157 public:
158158 void Compute (const framework::ExecutionContext &ctx) const override {
159159 auto &op_ = ctx.GetOp ();
160- auto pylayer_op = dynamic_cast <const PyLayerOp *>(&op_);
161- if (pylayer_op) {
162- auto py_layer_context = pylayer_op->GetPyLayerContext ();
160+ auto const_pylayer_op = dynamic_cast <const PyLayerOp *>(&op_);
161+ if (const_pylayer_op) {
162+ auto pylayer_op = const_cast <PyLayerOp *>(const_pylayer_op);
163+
164+ // Release contex after executing the compute
165+ auto py_layer_context = pylayer_op->ReleasePyLayerContext ();
163166 py::object bk_ctx (py::handle (py_layer_context->GetMutableCtx ()), true );
164167 auto &input_vars = ctx.MultiInputVar (" X" );
165168 auto output_vars = ctx.MultiOutputVar (" Out" );
Original file line number Diff line number Diff line change @@ -34,6 +34,10 @@ class PyLayerContext {
3434 PyLayerContext () = delete ;
3535
3636 PyObject* GetMutableCtx () { return context_; }
37+ ~PyLayerContext () {
38+ py::gil_scoped_acquire guard;
39+ Py_XDECREF (context_);
40+ }
3741
3842 private:
3943 PyObject* context_;
@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel {
5862 void SetPyLayerContext (const std::shared_ptr<PyLayerContext>& py_context) {
5963 py_context_ = py_context;
6064 }
61- const std::shared_ptr<PyLayerContext>& GetPyLayerContext () const {
62- return py_context_;
65+ std::shared_ptr<PyLayerContext> ReleasePyLayerContext () {
66+ auto temp = py_context_;
67+ py_context_.reset ();
68+ VLOG (3 ) << " `py_context_` in the PyLayerOp is released." ;
69+ return temp;
6370 }
6471
6572 private:
You can’t perform that action at this time.
0 commit comments