diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index feaf7ccd1a2f68..57a1a86048b4db 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1364,6 +1364,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                 &use_strided_slice);
 
   // step2: Dealing with basic indexing
+  bool out_is_view = false;
   auto out = getTensorWithBasicIndexing(tensor,
                                         &slice_axes,
                                         &slice_starts,
@@ -1372,7 +1373,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                         &decrease_axis,
                                         &none_axes,
                                         &infer_flags,
-                                        &use_strided_slice);
+                                        &use_strided_slice,
+                                        &out_is_view);
 
   if (!has_advanced_index) {
     return ToPyObject(out);
@@ -1391,7 +1393,8 @@
                                &trans_back_dim,
                                &pos_of_new_dim,
                                &rank_of_new_dim,
-                               &trans_dim);
+                               &trans_dim,
+                               &out_is_view);
 
     if (transed_index.size() == 1 &&
         transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1686,6 +1689,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
   //   3. assign values to the sliced result by index_put OP;
   //   4. transpose back and assign the result to original tensor by set_value
   //   OP.
+  bool out_is_view = false;
   paddle::Tensor sub_tensor = getTensorWithBasicIndexing(tensor,
                                                          &slice_axes,
                                                          &slice_starts,
@@ -1694,7 +1698,8 @@
                                                          &decrease_axis,
                                                          &none_axes,
                                                          &infer_flags,
-                                                         &use_strided_slice);
+                                                         &use_strided_slice,
+                                                         &out_is_view);
 
     std::vector<paddle::Tensor> transed_index;
     std::vector<int> trans_back_dim, trans_dim;
@@ -1710,7 +1715,8 @@
                                 &trans_back_dim,
                                 &pos_of_new_dim,
                                 &rank_of_new_dim,
-                                &trans_dim);
+                                &trans_dim,
+                                &out_is_view);
 
     // Release gil and do tracing
     py::gil_scoped_release release;
@@ -1737,10 +1743,6 @@
       value_tensor = transpose_ad_func(value_tensor, trans_dim);
     }
 
-    // TODO(zoooo0820) 1.Using inplace version index_put
-    // 2.Remove following code after backward bug fixed.
-    transed_sub_tensor = assign_ad_func(transed_sub_tensor);
-
     const phi::distributed::ProcessMesh* mesh = nullptr;
     if (InputsContainDistTensor(
             &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
@@ -1749,19 +1751,22 @@
     transed_sub_tensor =
-        index_put_ad_func(transed_sub_tensor, transed_index, value_tensor);
-
-    paddle::Tensor transback_sub_tensor =
-        transpose_ad_func(transed_sub_tensor, trans_back_dim);
-
-    self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                  transback_sub_tensor,
-                                                  slice_starts,
-                                                  slice_ends,
-                                                  slice_strides,
-                                                  slice_axes,
-                                                  decrease_axis,
-                                                  none_axes);
+        index_put__ad_func(transed_sub_tensor, transed_index, value_tensor);
+
+    // TODO(zoooo0820) Remove following code after backward bug fixed.
+    if (out_is_view) {
+      paddle::Tensor transback_sub_tensor =
+          transpose_ad_func(transed_sub_tensor, trans_back_dim);
+
+      self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                    transback_sub_tensor,
+                                                    slice_starts,
+                                                    slice_ends,
+                                                    slice_strides,
+                                                    slice_axes,
+                                                    decrease_axis,
+                                                    none_axes);
+    }
 
     if (PyCheckTensor(value_obj)) {
       // pass the stop_gradient from value to tensor.
       // pass stop gradient should be done after CheckInplace in
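The eager `__setitem__` hunks above drop the defensive `assign_ad_func` copy, switch to the in-place `index_put__ad_func`, and only run the transpose-back + `set_value_with_tensor_` write-back when basic indexing actually created a view. A minimal eager-mode sketch of the two resulting paths (the index patterns are illustrative; expected values follow NumPy semantics):

```python
import paddle

x = paddle.zeros([3, 4, 5, 6])
# Pure advanced indexing: basic indexing returns the tensor itself, so
# out_is_view stays false and index_put_ writes straight into x's storage.
x[[0, 1], [1, 2], [1]] = 10.0

y = paddle.zeros([3, 4, 5, 6])
# Mixed basic + advanced indexing: the leading slice flips out_is_view,
# so the transpose-back + set_value_with_tensor_ epilogue still runs.
y[0:2, [1, 2], [1]] = 10.0

print(float(x[0, 1, 1, 0]), float(y[1, 2, 1, 0]))  # 10.0 10.0
```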
diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h
index 82bdcc80562c45..7811704f6753dd 100644
--- a/paddle/fluid/pybind/slice_utils.h
+++ b/paddle/fluid/pybind/slice_utils.h
@@ -348,11 +348,13 @@ static paddle::Tensor getTensorWithBasicIndexing(
     std::vector<int64_t>* decrease_axis,
     std::vector<int64_t>* none_axes,
     std::vector<int64_t>* infer_flags,
-    bool* use_strided_slice) {
+    bool* use_strided_slice,
+    bool* out_is_view) {
   paddle::Tensor out;
   if (slice_axes->empty()) {
     out = tensor;
   } else {
+    *out_is_view = true;
     if (!(*use_strided_slice)) {
       eager_gil_scoped_release guard;
       out = slice_ad_func(tensor,
@@ -373,6 +375,7 @@
     }
   }
   if (!none_axes->empty()) {
+    *out_is_view = true;
     eager_gil_scoped_release guard;
     // Deal with cases that decrease_axes is not empty
     // For example:
@@ -401,7 +404,8 @@ static paddle::Tensor dealWithAdvancedIndex(
     std::vector<int>* trans_back_dim,
     int* pos_of_new_dim,
     int* rank_of_new_dim,
-    std::vector<int>* trans_dim) {
+    std::vector<int>* trans_dim,
+    bool* out_is_view) {
   int p = 0;
   for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
     auto index_dim = (*advanced_index_dim)[i];
@@ -444,6 +448,7 @@
   if (original_dim_order == *trans_dim) {
     transed_tensor = tensor;
   } else {
+    *out_is_view = true;
     transed_tensor = transpose_ad_func(tensor, *trans_dim);
   }
 
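In `slice_utils.h`, `*out_is_view` flips exactly when a new tensor is derived from the input (the `slice`/`strided_slice` path, the `none_axes` unsqueeze path, or the transpose in `dealWithAdvancedIndex`) and stays false when the input is passed through unchanged. A rough Python rendering of that bookkeeping (the `basic_index` helper here is hypothetical, not Paddle API):

```python
import paddle

def basic_index(t, axes, starts, ends, none_axes):
    # The flag flips exactly when a derived tensor is produced,
    # mirroring getTensorWithBasicIndexing above.
    out, out_is_view = t, False
    if axes:  # slice / strided_slice path
        out_is_view = True
        out = paddle.slice(t, axes=axes, starts=starts, ends=ends)
    if none_axes:  # unsqueeze path for None indices
        out_is_view = True
        out = paddle.unsqueeze(out, axis=none_axes)
    return out, out_is_view

t = paddle.ones([2, 3, 4])
sub, is_view = basic_index(t, [0], [0], [1], [])
print(sub.shape, is_view)  # [1, 3, 4] True
same, is_view = basic_index(t, [], [], [], [])
print(same is t, is_view)  # True False
```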
diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py
index c4b20843864dfa..0f233d62bdc72d 100644
--- a/python/paddle/base/variable_index.py
+++ b/python/paddle/base/variable_index.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 
 import numpy as np
 
@@ -170,7 +169,9 @@ def _setitem_for_tensor_array(var, item, value):
     )
 
 
-def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
+def deal_advanced_index(
+    ori_tensor, indices, is_for_setitem, values, out_is_view=True
+):
     """
     Transpose origin Tensor and advanced indices to the front.
 
@@ -206,18 +207,24 @@
     for i in range(ori_tensor.ndim):
         if indices[i] is None:
             transed_dim.append(i)
-    transed_tensor = ori_tensor.transpose(transed_dim)
 
     trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []
-
-    transed_value_tensor = None
-    if is_for_setitem:
-        if values.ndim > 1 and pos_of_new_dim != 0:
-            # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
-            # transposed at 1st dim, the value tensor should be transposed too.
-            transed_value_tensor = values.transpose(transed_dim)
-        else:
+
+    transed_value_tensor = None
+    if transed_dim == list(range(ori_tensor.ndim)):
+        transed_tensor = ori_tensor
+        if is_for_setitem:
             transed_value_tensor = values
+    else:
+        out_is_view = True
+        transed_tensor = ori_tensor.transpose(transed_dim)
+        if is_for_setitem:
+            if values.ndim > 1 and pos_of_new_dim != 0:
+                # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
+                # transposed at 1st dim, the value tensor should be transposed too.
+                transed_value_tensor = values.transpose(transed_dim)
+            else:
+                transed_value_tensor = values
 
     return (
         transed_tensor,
@@ -226,6 +233,7 @@
         pos_of_new_dim,
         rank_of_new_dim,
         transed_value_tensor,
+        out_is_view,
     )
 
 
@@ -599,7 +607,7 @@ def _setitem_static(x, indices, values):
     ):
         values = paddle.assign(values).astype(x.dtype)
 
-    sub_tensor = get_tensor_with_basic_indexing(
+    sub_tensor, is_view = get_tensor_with_basic_indexing(
         x,
         axes,
         starts,
@@ -616,16 +624,21 @@
             _,
             _,
             values,
-        ) = deal_advanced_index(sub_tensor, advanced_index, True, values)
+            is_view,
+        ) = deal_advanced_index(
+            sub_tensor, advanced_index, True, values, is_view
+        )
 
         if values.dtype != transed_sub_tensor.dtype:
             values = values.astype(transed_sub_tensor.dtype)
 
-        if in_dynamic_or_pir_mode():
+        if paddle.in_dynamic_mode():
            # NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed.
            transed_sub_tensor = transed_sub_tensor.index_put_(
                adjusted_advanced_index, values
            )
+            if not is_view:
+                return transed_sub_tensor
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
@@ -694,12 +707,14 @@ def get_tensor_with_basic_indexing(
 ):
     from .dygraph.base import in_to_static_mode
 
+    out_is_view = False
     if in_to_static_mode() and hasattr(x, "is_view_var"):
         x.is_view_var = True
 
     if len(axes) == 0:
         out = x
     else:
+        out_is_view = True
         op_type = "strided_slice" if use_strided_slice else "slice"
         inputs = {'Input': [x]}
         attrs = {
@@ -748,7 +763,7 @@
             if paddle.utils._contain_var(end):
                 end = paddle.utils.get_int_tensor_list(end)
             if x.is_dense_tensor_array_type():
-                return paddle._pir_ops.slice_array_dense(x, st)
+                return paddle._pir_ops.slice_array_dense(x, st), False
             out = paddle._C_ops.slice(
                 x,
                 axes,
@@ -775,17 +790,9 @@
                 attrs=attrs,
             )
             out = slice_out_var
-        # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
-        # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
-        # otherwise the output shape will be not correct.
-        set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
-        if set_to_1d and len(decrease_axes) == len(x.shape):
-            warnings.warn(
-                "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
-            )
-            none_axes = none_axes[1:]
 
     if len(none_axes) > 0:
+        out_is_view = True
         # Deal with cases that decrease_axes is not empty
         # For example:
         # # x.shape: (2,3,4)
@@ -799,7 +806,7 @@
     if in_to_static_mode() and hasattr(out, "is_view_var"):
         out.is_view_var = True
 
-    return out
+    return out, out_is_view
 
 
 def _getitem_static(x, indices):
@@ -822,7 +829,7 @@
     ) = parse_index(x, indices)
 
     # step2: Dealing with basic indexing
-    out = get_tensor_with_basic_indexing(
+    out, _ = get_tensor_with_basic_indexing(
         x,
         axes,
         starts,
@@ -842,6 +849,7 @@
             pos_of_new_dim,
             rank_of_new_dim,
             _,
+            _,
         ) = deal_advanced_index(out, advanced_index, False, None)
 
         # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently
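On the Python side, `get_tensor_with_basic_indexing` now returns an `(out, out_is_view)` pair, `deal_advanced_index` threads the flag through, and the dynamic-mode branch of `_setitem_static` returns the `index_put_` result directly once no view is involved. The unchanged static-graph path can be exercised standalone, essentially what the new static test below does:

```python
import numpy as np
import paddle
from paddle.base.variable_index import _setitem_static

paddle.enable_static()
exe = paddle.static.Executor()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.zeros((3, 4, 5, 6), dtype='float32')
    # Static mode keeps the out-of-place index_put + set_value epilogue.
    y = _setitem_static(x, ([0, 1], [1, 2], [1]), 10.0)
    (res,) = exe.run(fetch_list=[y])

expect = np.zeros((3, 4, 5, 6), dtype='float32')
expect[[0, 1], [1, 2], [1]] = 10.0
np.testing.assert_allclose(res, expect)
```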
diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py
index 350176d1acb03a..34f3ca24eac15b 100644
--- a/test/indexing/test_setitem.py
+++ b/test/indexing/test_setitem.py
@@ -29,6 +29,21 @@ def setUp(self):
         self.ndtype = np.float64
         self.dtype = 'float64'
 
+    def test_advanced_index(self):
+        np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
+        if self.dtype == 'bfloat16':
+            np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            np_data = np_data + 1j * np_data
+
+        x = paddle.to_tensor(np_data, dtype=self.dtype)
+        np_data[[0, 1], [1, 2], [1]] = 10.0
+        x[[0, 1], [1, 2], [1]] = 10.0
+
+        if self.dtype == 'bfloat16':
+            x = paddle.cast(x, dtype='float32')
+        np.testing.assert_allclose(x.numpy(), np_data)
+
     def test_combined_index_1(self):
         np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
         if self.dtype == 'bfloat16':
@@ -426,6 +441,20 @@ def setUp(self):
         paddle.enable_static()
         self.exe = paddle.static.Executor()
 
+    @test_with_pir_api
+    def test_advanced_index(self):
+        # multi-int tensor
+        np_data = np.zeros((3, 4, 5, 6), dtype='float32')
+        np_data[[0, 1], [1, 2], [1]] = 10.0
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            x = paddle.zeros((3, 4, 5, 6), dtype='float32')
+            y = _setitem_static(x, ([0, 1], [1, 2], [1]), 10.0)
+            res = self.exe.run(fetch_list=[y])
+
+        np.testing.assert_allclose(res[0], np_data)
+
     @test_with_pir_api
     def test_combined_index_1(self):
         # int tensor + slice (without decreasing axes)
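As a final eager-mode sanity check (not part of this PR's test suite), the `stop_gradient` propagation kept in the context of the `__setitem__` hunk can be observed directly; the expected output assumes that propagation behaves as the retained comment describes:

```python
import paddle

x = paddle.zeros([3, 4, 5, 6])
v = paddle.full([6], 10.0)
v.stop_gradient = False

# "pass the stop_gradient from value to tensor": assigning a
# differentiable value should leave the target differentiable too.
x[[0, 1], [1, 2], [1]] = v
print(x.stop_gradient)  # expected: False
```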