Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 61 additions & 50 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,67 +2034,78 @@ def stack(x, axis=0, name=None):
"""
axis = 0 if axis is None else axis

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.stack(x, axis)
else:
if not isinstance(x, list) and not isinstance(x, tuple):
# NOTE:(zhiqiu) Only support Variable as input if the Variable is a LOD_TENSOR_ARRAY create by create_array, array_write, array_read, etc.
# In that case, Variable is array of tensors indeed.
if (
isinstance(x, Variable)
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
x = [x]
else:
raise TypeError(
"The type of '{}' in {} must be {}, but received {}".format(
'x',
'stack',
'list[Tensor], tuple[Tensor] or TensorArray',
type(x),
)
)

helper = LayerHelper('stack', **locals())
if not isinstance(x, list) and not isinstance(x, tuple):
# NOTE:(zhiqiu) Only support Variable as input if the Variable is a LOD_TENSOR_ARRAY create by create_array, array_write, array_read, etc.
# In that case, Variable is array of tensors indeed.
if (
isinstance(x, Variable)
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
) or (isinstance(x, paddle.pir.Value) and x.is_tensorarray()):
x = [x]
else:
raise TypeError(
"The type of '{}' in {} must be {}, but received {}".format(
'x',
'stack',
'list[Tensor], tuple[Tensor] or TensorArray',
type(x),
)
)

out = helper.create_variable_for_type_inference(x[0].dtype)
if x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if in_pir_mode():
if x[0].is_tensorarray():
assert len(x) == 1, (
"If the elements of 'x' in stack are Variable(LoDTensorArray), "
"number of the elements must be 1, but received %s." % len(x)
)
out_index = helper.create_variable_for_type_inference(dtype="int32")
out, _ = _C_ops.array_to_tensor(x, axis, True)
return out
else:
return _C_ops.stack(x, axis)

for i in x:
check_variable_and_dtype(
i,
'x',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'stack',
)
helper = LayerHelper('stack', **locals())

helper.append_op(
type='tensor_array_to_tensor',
inputs={'X': x[0]},
outputs={'Out': [out], 'OutIndex': [out_index]},
attrs={'axis': axis, 'use_stack': True},
)
else:
helper.append_op(
type='stack',
inputs={'X': x},
outputs={'Y': out},
attrs={'axis': axis},
out = helper.create_variable_for_type_inference(x[0].dtype)
if x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
assert len(x) == 1, (
"If the elements of 'x' in stack are Variable(LoDTensorArray), "
"number of the elements must be 1, but received %s." % len(x)
)
out_index = helper.create_variable_for_type_inference(dtype="int32")

for i in x:
check_variable_and_dtype(
i,
'x',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'stack',
)

return out
helper.append_op(
type='tensor_array_to_tensor',
inputs={'X': x[0]},
outputs={'Out': [out], 'OutIndex': [out_index]},
attrs={'axis': axis, 'use_stack': True},
)
else:
helper.append_op(
type='stack',
inputs={'X': x},
outputs={'Y': out},
attrs={'axis': axis},
)

return out


def hstack(x, name=None):
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_stack_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def setUp(self):
else base.CPUPlace()
)

@test_with_pir_api
def test_case(self):
self.program = paddle.static.Program()
with paddle.static.program_guard(self.program):
Expand Down Expand Up @@ -258,6 +259,7 @@ def setUp(self):
else base.CPUPlace()
)

@test_with_pir_api
def test_case(self):
self.program = paddle.static.Program()
with paddle.static.program_guard(self.program):
Expand Down