From b2f3198a08d4661a3bed50af62e45301576417cd Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 25 Aug 2021 10:25:43 +0800 Subject: [PATCH] fix potential tensor leak in tensor.__setitem__ (#35013) * fix index tensor leak in __setitem__ * fix another usage of PyTuple_Pack * refine code * refine code * handle None index * add Py_DecRef * revert ut * refine code * merge develop * use RAII * follow comments --- paddle/fluid/pybind/imperative.cc | 70 +++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 6c4213979a46be..cbf585804e63ef 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/imperative/all_reduce.h" #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" @@ -426,7 +427,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, // types. // Ellipsis and None are not supported yet. // wrap to tuple + + // NOTE(zhiqiu): PyTuple_Pack increases refcount. PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index; + DEFINE_PADDLE_SCOPE_GUARD([index, _index]() { + if (!PyTuple_Check(_index)) { + Py_DECREF(index); + VLOG(4) << "Call Py_DECREF"; + } + }); PADDLE_ENFORCE_EQ( tensor->IsInitialized(), true, platform::errors::InvalidArgument("tensor has not been initialized")); @@ -505,8 +514,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, platform::errors::InvalidArgument( "Too many indices (%d) for tensor of dimension %d.", valid_indexs, rank)); - - if (!PyTuple_Check(_index)) Py_DecRef(index); } template @@ -766,11 +773,21 @@ void BindImperative(py::module *m_ptr) { .def("__setitem__", [](std::shared_ptr &self, py::handle _index, py::object &value_obj) { + VLOG(4) << "Call __setitem__"; + auto self_tensor = self->MutableVar()->GetMutable(); + // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New + // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251 PyObject *index_ptr = !PyTuple_Check(_index.ptr()) ? PyTuple_Pack(1, _index.ptr()) : _index.ptr(); + DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() { + if (!PyTuple_Check(_index.ptr())) { + Py_DECREF(index_ptr); + VLOG(4) << "Call Py_DECREF"; + } + }); // 1. Check argumnets // 1.1 Check whether value obj is a tensor. bool value_is_tensor = true; @@ -781,6 +798,18 @@ void BindImperative(py::module *m_ptr) { value_is_tensor = false; } + auto is_tensor = [](py::handle var) { + if (!var.ptr() || var.ptr() == Py_None) { + return false; + } + try { + py::cast>(var); + return true; + } catch (py::cast_error &) { + return false; + } + }; + // 1.2 Check whether _index can be parsed. const int size = PyTuple_GET_SIZE(index_ptr); for (int dim = 0; dim < size; ++dim) { @@ -797,12 +826,13 @@ void BindImperative(py::module *m_ptr) { // TODO(liym27): Try not to call TensorToPyArray because it always // copys data to cpu place, which reduces performance. if (parse_index && value_is_tensor) { + VLOG(4) << "index is integer/slice/ellipsis and value is tensor"; std::vector axes, starts, ends, steps, decrease_axes, none_axes, infer_flags; ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, &steps, &decrease_axes, &none_axes, - &infer_flags); - + &infer_flags, &list_select_idxs, + &list_select_flag); framework::AttributeMap attrs = { {"axes", axes}, {"starts", starts}, @@ -834,20 +864,43 @@ void BindImperative(py::module *m_ptr) { } } else { auto self_numpy = TensorToPyArray(*self_tensor); + VLOG(4) << "parse_index is false"; if (value_is_tensor) { + VLOG(4) << "value is tensor"; auto value = value_obj.cast>(); auto value_tensor = value->MutableVar()->GetMutable(); auto value_numpy = TensorToPyArray(*value_tensor); - - self_numpy[_index] = value_numpy; + if (is_tensor(_index)) { + VLOG(4) << "index is tensor"; + auto index_var = + py::cast>(_index); + auto index_tensor = index_var->MutableVar() + ->GetMutable(); + auto index_numpy = TensorToPyArray(*index_tensor); + self_numpy[index_numpy] = value_numpy; + } else { + VLOG(4) << "index is not tensor"; + self_numpy[_index] = value_numpy; + } SetTensorFromPyArray(self_tensor, self_numpy, self_tensor->place(), true); } else { - auto value_numpy = value_obj; - self_numpy[_index] = value_numpy; + VLOG(4) << "value is not tensor"; + if (is_tensor(_index)) { + VLOG(4) << "index is tensor"; + auto index_var = + py::cast>(_index); + auto index_tensor = index_var->MutableVar() + ->GetMutable(); + auto index_numpy = TensorToPyArray(*index_tensor); + self_numpy[index_numpy] = value_obj; + } else { + VLOG(4) << "index is not tensor"; + self_numpy[_index] = value_obj; + } SetTensorFromPyArray(self_tensor, self_numpy, self_tensor->place(), true); } @@ -859,6 +912,7 @@ void BindImperative(py::module *m_ptr) { }) .def("_getitem_index_not_tensor", [](std::shared_ptr &self, py::handle _index) { + VLOG(4) << "Call _getitem_index_not_tensor"; std::vector slice_axes, slice_starts, slice_ends, slice_strides, decrease_axis, none_axes, infer_flags; auto tensor =