47 changes: 26 additions & 21 deletions paddle/fluid/pybind/eager_method.cc
@@ -1364,6 +1364,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&use_strided_slice);

// step2: Dealing with basic indexing
bool out_is_view = false;
auto out = getTensorWithBasicIndexing(tensor,
&slice_axes,
&slice_starts,
@@ -1372,7 +1373,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&decrease_axis,
&none_axes,
&infer_flags,
&use_strided_slice);
&use_strided_slice,
&out_is_view);

if (!has_advanced_index) {
return ToPyObject(out);
@@ -1391,7 +1393,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim,
&trans_dim);
&trans_dim,
&out_is_view);

if (transed_index.size() == 1 &&
transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1686,6 +1689,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
// 3. assign values to the sliced result by index_put OP;
// 4. transpose back and assign the result to original tensor by set_value
// OP.
bool out_is_view = false;
paddle::Tensor sub_tensor = getTensorWithBasicIndexing(tensor,
&slice_axes,
&slice_starts,
@@ -1694,7 +1698,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&decrease_axis,
&none_axes,
&infer_flags,
&use_strided_slice);
&use_strided_slice,
&out_is_view);

std::vector<paddle::Tensor> transed_index;
std::vector<int> trans_back_dim, trans_dim;
@@ -1710,7 +1715,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim,
&trans_dim);
&trans_dim,
&out_is_view);

// Release gil and do tracing
py::gil_scoped_release release;
@@ -1737,10 +1743,6 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
value_tensor = transpose_ad_func(value_tensor, trans_dim);
}

// TODO(zoooo0820) 1.Using inplace version index_put
// 2.Remove following code after backward bug fixed.
transed_sub_tensor = assign_ad_func(transed_sub_tensor);

const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(
&mesh, self->tensor, transed_sub_tensor, value_tensor)) {
@@ -1749,19 +1751,22 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
}

transed_sub_tensor =
index_put_ad_func(transed_sub_tensor, transed_index, value_tensor);

paddle::Tensor transback_sub_tensor =
transpose_ad_func(transed_sub_tensor, trans_back_dim);

self->tensor = set_value_with_tensor__ad_func(self->tensor,
transback_sub_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
index_put__ad_func(transed_sub_tensor, transed_index, value_tensor);

// TODO(zoooo0820): Remove the following code once the backward bug is fixed.
if (out_is_view) {
paddle::Tensor transback_sub_tensor =
transpose_ad_func(transed_sub_tensor, trans_back_dim);

self->tensor = set_value_with_tensor__ad_func(self->tensor,
transback_sub_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
}
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
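Note on the hunk above: the advanced-index `__setitem__` path now writes through the in-place `index_put__ad_func` first, and only runs the transpose-back plus `set_value_with_tensor__ad_func` write-back when `out_is_view` was set by the basic-indexing/transpose steps. A minimal Python sketch of that control flow (the function name and arguments are illustrative stand-ins for the C++ helpers, not Paddle API):

```python
import paddle

# Hedged sketch of the decision implemented above, assuming `sub_tensor`
# came out of the basic-indexing/transpose steps.
def setitem_advanced(sub_tensor, transed_index, value,
                     trans_back_dim, out_is_view):
    # In-place write; when basic indexing handed back the original
    # tensor (out_is_view == False), this alone finishes the job.
    sub_tensor = sub_tensor.index_put_(transed_index, value)
    if not out_is_view:
        return sub_tensor
    # View case: undo the transpose so set_value_with_tensor_ can copy
    # the result back into the sliced region of the original tensor.
    return sub_tensor.transpose(trans_back_dim)
```

For a pure advanced index such as `x[[0, 1], [1, 2], [1]] = 10.0` (no basic slice, identity transpose), `out_is_view` stays false and the in-place `index_put_` is all that runs.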
9 changes: 7 additions & 2 deletions paddle/fluid/pybind/slice_utils.h
@@ -348,11 +348,13 @@ static paddle::Tensor getTensorWithBasicIndexing(
std::vector<int64_t>* decrease_axis,
std::vector<int64_t>* none_axes,
std::vector<int64_t>* infer_flags,
bool* use_strided_slice) {
bool* use_strided_slice,
bool* out_is_view) {
paddle::Tensor out;
if (slice_axes->empty()) {
out = tensor;
} else {
*out_is_view = true;
if (!(*use_strided_slice)) {
eager_gil_scoped_release guard;
out = slice_ad_func(tensor,
@@ -373,6 +375,7 @@
}
}
if (!none_axes->empty()) {
*out_is_view = true;
eager_gil_scoped_release guard;
// Deal with cases that decrease_axes is not empty
// For example:
@@ -401,7 +404,8 @@ static paddle::Tensor dealWithAdvancedIndex(
std::vector<int>* trans_back_dim,
int* pos_of_new_dim,
int* rank_of_new_dim,
std::vector<int>* trans_dim) {
std::vector<int>* trans_dim,
bool* out_is_view) {
int p = 0;
for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
auto index_dim = (*advanced_index_dim)[i];
@@ -444,6 +448,7 @@
if (original_dim_order == *trans_dim) {
transed_tensor = tensor;
} else {
*out_is_view = true;
transed_tensor = transpose_ad_func(tensor, *trans_dim);
}

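The `out_is_view` flag threaded through the two helpers above flips to true exactly when the returned tensor is derived from the input rather than being the input itself: a non-empty `slice_axes` (the slice/strided_slice path), a non-empty `none_axes` (the unsqueeze path), or a non-identity `trans_dim` (a real transpose). A small Python model of that predicate, assuming the arguments mirror the C++ parameters:

```python
# Sketch only: models when the helpers above set *out_is_view = true.
def out_is_view(slice_axes, none_axes, trans_dim, ndim):
    derived = bool(slice_axes)            # slice / strided_slice ran
    derived = derived or bool(none_axes)  # unsqueeze for None axes ran
    # dealWithAdvancedIndex additionally flags a non-identity transpose
    return derived or trans_dim != list(range(ndim))

assert not out_is_view([], [], [0, 1, 2, 3], 4)  # x[[0, 1], [1, 2], [1]]
assert out_is_view([0], [], [0, 1, 2, 3], 4)     # any basic slice involved
```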
58 changes: 33 additions & 25 deletions python/paddle/base/variable_index.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np

@@ -170,7 +169,9 @@ def _setitem_for_tensor_array(var, item, value):
)


def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
def deal_advanced_index(
ori_tensor, indices, is_for_setitem, values, out_is_view=True
):
"""
Transpose the original Tensor and advanced indices to the front.

@@ -206,18 +207,24 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
for i in range(ori_tensor.ndim):
if indices[i] is None:
transed_dim.append(i)
transed_tensor = ori_tensor.transpose(transed_dim)

trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []

transed_value_tensor = None
if is_for_setitem:
if values.ndim > 1 and pos_of_new_dim != 0:
# If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
# transposed at 1st dim, the value tensor should be transposed too.
transed_value_tensor = values.transpose(transed_dim)
else:

if transed_dim == list(range(ori_tensor.ndim)):
transed_tensor = ori_tensor
if is_for_setitem:
transed_value_tensor = values
else:
out_is_view = True
transed_tensor = ori_tensor.transpose(transed_dim)
if is_for_setitem:
if values.ndim > 1 and pos_of_new_dim != 0:
# If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
# transposed at 1st dim, the value tensor should be transposed too.
transed_value_tensor = values.transpose(transed_dim)
else:
transed_value_tensor = values

return (
transed_tensor,
@@ -226,6 +233,7 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
pos_of_new_dim,
rank_of_new_dim,
transed_value_tensor,
out_is_view,
)


@@ -599,7 +607,7 @@ def _setitem_static(x, indices, values):
):
values = paddle.assign(values).astype(x.dtype)

sub_tensor = get_tensor_with_basic_indexing(
sub_tensor, is_view = get_tensor_with_basic_indexing(
x,
axes,
starts,
@@ -616,16 +624,21 @@
_,
_,
values,
) = deal_advanced_index(sub_tensor, advanced_index, True, values)
is_view,
) = deal_advanced_index(
sub_tensor, advanced_index, True, values, is_view
)

if values.dtype != transed_sub_tensor.dtype:
values = values.astype(transed_sub_tensor.dtype)

if in_dynamic_or_pir_mode():
if paddle.in_dynamic_mode():
# NOTE(zoooo0820): directly return the result instead of another set_value, once the backward bug is fixed.
transed_sub_tensor = transed_sub_tensor.index_put_(
adjusted_advanced_index, values
)
if not is_view:
return transed_sub_tensor
else:
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
@@ -694,12 +707,14 @@ def get_tensor_with_basic_indexing(
):
from .dygraph.base import in_to_static_mode

out_is_view = False
if in_to_static_mode() and hasattr(x, "is_view_var"):
x.is_view_var = True

if len(axes) == 0:
out = x
else:
out_is_view = True
op_type = "strided_slice" if use_strided_slice else "slice"
inputs = {'Input': [x]}
attrs = {
@@ -748,7 +763,7 @@
if paddle.utils._contain_var(end):
end = paddle.utils.get_int_tensor_list(end)
if x.is_dense_tensor_array_type():
return paddle._pir_ops.slice_array_dense(x, st)
return paddle._pir_ops.slice_array_dense(x, st), False
out = paddle._C_ops.slice(
x,
axes,
@@ -775,17 +790,9 @@
attrs=attrs,
)
out = slice_out_var
# NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
# otherwise the output shape will be not correct.
set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
if set_to_1d and len(decrease_axes) == len(x.shape):
warnings.warn(
"Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
)
none_axes = none_axes[1:]

if len(none_axes) > 0:
out_is_view = True
# Deal with cases that decrease_axes is not empty
# For example:
# # x.shape: (2,3,4)
Expand All @@ -799,7 +806,7 @@ def get_tensor_with_basic_indexing(

if in_to_static_mode() and hasattr(out, "is_view_var"):
out.is_view_var = True
return out
return out, out_is_view


def _getitem_static(x, indices):
Expand All @@ -822,7 +829,7 @@ def _getitem_static(x, indices):
) = parse_index(x, indices)

# step2: Dealing with basic indexing
out = get_tensor_with_basic_indexing(
out, _ = get_tensor_with_basic_indexing(
x,
axes,
starts,
@@ -842,6 +849,7 @@ def _getitem_static(x, indices):
pos_of_new_dim,
rank_of_new_dim,
_,
_,
) = deal_advanced_index(out, advanced_index, False, None)

# TODO(zooooo0820): Replace gather_nd with another advanced OP to handle mixed indexes more efficiently
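With `get_tensor_with_basic_indexing` now returning an `(out, out_is_view)` pair and `deal_advanced_index` threading the flag through, the dynamic-mode branch of `_setitem_static` can stop right after the in-place `index_put_` whenever the sub-tensor is the original tensor. A hedged eager-mode example of the resulting behavior, mirroring the new tests below:

```python
import numpy as np
import paddle

# Pure advanced indexing: no basic slice and an identity transpose, so
# the assignment is expected to go through the in-place index_put_ path.
x = paddle.zeros([3, 4, 5, 6], dtype='float32')
x[[0, 1], [1, 2], [1]] = 10.0

expected = np.zeros((3, 4, 5, 6), dtype='float32')
expected[[0, 1], [1, 2], [1]] = 10.0
np.testing.assert_allclose(x.numpy(), expected)
```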
29 changes: 29 additions & 0 deletions test/indexing/test_setitem.py
@@ -29,6 +29,21 @@ def setUp(self):
self.ndtype = np.float64
self.dtype = 'float64'

def test_advanced_index(self):
np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

x = paddle.to_tensor(np_data, dtype=self.dtype)
np_data[[0, 1], [1, 2], [1]] = 10.0
x[[0, 1], [1, 2], [1]] = 10.0

if self.dtype == 'bfloat16':
x = paddle.cast(x, dtype='float32')
np.testing.assert_allclose(x.numpy(), np_data)

def test_combined_index_1(self):
np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
if self.dtype == 'bfloat16':
@@ -426,6 +441,20 @@ def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()

@test_with_pir_api
def test_advanced_index(self):
# multi-int tensor
np_data = np.zeros((3, 4, 5, 6), dtype='float32')
np_data[[0, 1], [1, 2], [1]] = 10.0
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.zeros((3, 4, 5, 6), dtype='float32')
y = _setitem_static(x, ([0, 1], [1, 2], [1]), 10.0)
res = self.exe.run(fetch_list=[y])

np.testing.assert_allclose(res[0], np_data)

@test_with_pir_api
def test_combined_index_1(self):
# int tensor + slice (without decreasing axes)