Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__setitem__",
.def("__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
auto self_tensor =
Expand Down
50 changes: 47 additions & 3 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .. import framework
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
Expand Down Expand Up @@ -559,7 +559,25 @@ def contain_tensor(item):
return True
return False

if contain_tensor(item):
def is_list_tuple(index, contain_type):
def _is_list_tuple(item):
if not (isinstance(item, (list, tuple)) or
type(item) == contain_type):
return False
if isinstance(item, (tuple, list)):
for s in item:
if not _is_list_tuple(s):
return False
return True

if not isinstance(index, (tuple, list)):
return False
for s in index:
if not _is_list_tuple(s):
return False
return True

if contain_tensor(item) or is_list_tuple(item, int):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item)
Expand All @@ -568,6 +586,31 @@ def contain_tensor(item):
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item)

def __setitem__(self, item, value):
def contain_tensor(item):
if not isinstance(item, tuple):
item = [item]

for slice_item in item:
if isinstance(slice_item, slice):
if isinstance(slice_item.start, Variable) \
or isinstance(slice_item.stop, Variable) \
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
return True
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否可以和__getitem__共用一份contain_tensor代码 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.


if contain_tensor(item):
# 1. Call _getitem_impl_ when item contains tensor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get -> set

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _setitem_impl_(self, item, value)

else:
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self.__setitem_varbase__(item, value)

for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value),
Expand All @@ -577,7 +620,8 @@ def contain_tensor(item):
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__), ("item", item)):
("__getitem__", __getitem__), ("item", item),
("__setitem__", __setitem__)):
setattr(core.VarBase, method_name, method)

# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
Expand Down
Loading