Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/manual_static_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,17 +379,17 @@ static PyObject *static_api_slice_array(PyObject *self,
starts_tmp, phi::DataType::INT64, phi::CPUPlace());
}

PyObject *ends_obj = PyTuple_GET_ITEM(args, 1);
PyObject *ends_obj = PyTuple_GET_ITEM(args, 2);
pir::Value ends;
if (PyObject_CheckIRValue(ends_obj)) {
ends = CastPyArg2Value(ends_obj, "slice_array", 1);
ends = CastPyArg2Value(ends_obj, "slice_array", 2);
} else if (PyObject_CheckIRVectorOfValue(ends_obj)) {
std::vector<pir::Value> ends_tmp =
CastPyArg2VectorOfValue(ends_obj, "slice_array", 1);
CastPyArg2VectorOfValue(ends_obj, "slice_array", 2);
ends = paddle::dialect::stack(ends_tmp, /*axis*/ 0);
} else {
std::vector<int64_t> ends_tmp =
CastPyArg2Longs(ends_obj, "slice_array", 1);
CastPyArg2Longs(ends_obj, "slice_array", 2);
ends = paddle::dialect::full_int_array(
ends_tmp, phi::DataType::INT64, phi::CPUPlace());
}
Expand Down
20 changes: 13 additions & 7 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _setitem_for_tensor_array(var, item, value):
assert (
not paddle.in_dynamic_mode()
), "setitem for tensor_array must be called in static graph mode."
if isinstance(item, (Variable, int)):
if isinstance(item, (Variable, paddle.pir.Value, int)):
from paddle.jit.dy2static.convert_operators import to_static_variable
from paddle.tensor import array_write

Expand Down Expand Up @@ -248,17 +248,21 @@ def slice_is_same_to_original(start, end, step):
return start == 0 and end == MAX_INTEGER and step == 1


def parse_index(x, indices):
def is_tensor_array_type(value):
from .framework import in_pir_mode

if in_pir_mode():
is_tensor_array = x.is_dense_tensor_array_type()
return value.is_dense_tensor_array_type()
else:
is_tensor_array = (
hasattr(x, "desc")
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
return (
hasattr(value, "desc")
and value.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)


def parse_index(x, indices):
is_tensor_array = is_tensor_array_type(x)

advanced_index = (
[] if is_tensor_array else [None] * 2 * len(x.shape)
) # content is (dim, index)
Expand Down Expand Up @@ -448,7 +452,9 @@ def _setitem_static(x, indices, values):
from . import in_dynamic_or_pir_mode
from .framework import Variable, default_main_program, in_pir_mode

if x.type == paddle.base.core.VarDesc.VarType.LOD_TENSOR_ARRAY:
is_tensor_array = is_tensor_array_type(x)

if is_tensor_array:
return _setitem_for_tensor_array(x, indices, values)

# step1: parsing the index and recording them
Expand Down
9 changes: 1 addition & 8 deletions test/dygraph_to_static/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TestSliceInIf(TestSliceBase):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_if

@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
static_res = self.run_static_mode()
Expand All @@ -179,14 +180,6 @@ def init_input(self):
def init_dygraph_func(self):
self.dygraph_func = test_set_value

# TODO(pir-control-flow): Delete this code after supporting control flow
@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


class TestSetValueWithLayerAndSave(Dy2StTestBase):
def setUp(self):
Expand Down