Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__setitem__",
.def("__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
auto self_tensor =
Expand Down
64 changes: 47 additions & 17 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .. import framework
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
Expand Down Expand Up @@ -543,23 +543,41 @@ def __array__(self, dtype=None):
array = array.astype(dtype)
return array

def contain_tensor(item):
    """Return True if any component of the index ``item`` is a Variable.

    ``item`` may be a single index or a tuple of indices; for slice
    indices, the start/stop/step components are inspected as well.
    """
    indices = item if isinstance(item, tuple) else [item]
    for idx in indices:
        if isinstance(idx, slice):
            bounds = (idx.start, idx.stop, idx.step)
            if any(isinstance(b, Variable) for b in bounds):
                return True
        elif isinstance(idx, Variable):
            return True
    return False

def __getitem__(self, item):
def contain_tensor(item):
if not isinstance(item, tuple):
item = [item]

for slice_item in item:
if isinstance(slice_item, slice):
if isinstance(slice_item.start, Variable) \
or isinstance(slice_item.stop, Variable) \
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
return True
return False
def is_list_tuple(index, contain_type):
    """Return True if ``index`` is a (possibly nested) list/tuple whose
    leaf elements are all exactly of type ``contain_type``.

    A non-list/tuple ``index`` returns False. An empty list/tuple
    returns True (vacuously, every leaf matches).
    """

    def _is_list_tuple(item):
        # A valid node is either a nested list/tuple or a leaf whose
        # exact type matches (subclasses such as bool-for-int are
        # deliberately rejected by the ``type(...) ==`` comparison).
        if not (isinstance(item, (list, tuple)) or
                type(item) == contain_type):
            return False
        if isinstance(item, (tuple, list)):
            for s in item:
                if not _is_list_tuple(s):
                    return False
        return True

    if not isinstance(index, (tuple, list)):
        return False
    for s in index:
        if not _is_list_tuple(s):
            return False
    return True

if contain_tensor(item) or is_list_tuple(item, int):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item)
Expand All @@ -568,6 +586,17 @@ def contain_tensor(item):
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item)

def __setitem__(self, item, value):
    """Assign ``value`` to the elements of this Tensor selected by ``item``."""
    if not contain_tensor(item):
        # Fast path: the index has no tensor components, so it can be
        # parsed directly by the C++ func __setitem_varbase__ for speed.
        return self.__setitem_varbase__(item, value)
    # Slow path: an index containing tensors cannot be parsed on the
    # C++ side, so fall back to the Python _setitem_impl_.
    return _setitem_impl_(self, item, value)

for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value),
Expand All @@ -577,7 +606,8 @@ def contain_tensor(item):
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__), ("item", item)):
("__getitem__", __getitem__), ("item", item),
("__setitem__", __setitem__)):
setattr(core.VarBase, method_name, method)

# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
Expand Down
Loading