Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2991,7 +2991,7 @@ void SliceArrayDenseOp::VerifySig() {
{
auto input_size = num_operands();
IR_ENFORCE(input_size == 2u,
"The size %d of inputs must be equal to 1.",
"The size %d of inputs must be equal to 2.",
input_size);
IR_ENFORCE((*this)
->operand_source(0)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3806,6 +3806,7 @@ void SliceArrayDenseInferMeta(const MetaTensor& input,
}
// out->set_dims(input.dims());
out->set_dtype(input.dtype());
out->set_dims(input.dims());
}

void SliceRawInferMeta(const MetaTensor& input,
Expand Down
8 changes: 5 additions & 3 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def for_iter_var_list(x):
# 1. prepare data, ref test_list.py
x = paddle.to_tensor(x)
iter_num = paddle.tensor.fill_constant(shape=[1], value=5, dtype="int32")
a = []
a = paddle.tensor.create_array("int32")
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
Expand All @@ -187,7 +187,7 @@ def for_enumerate_var_list(x):
# 1. prepare data, ref test_list.py
x = paddle.to_tensor(x)
iter_num = paddle.tensor.fill_constant(shape=[1], value=5, dtype="int32")
a = []
a = paddle.tensor.create_array("int32")
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
Expand Down Expand Up @@ -389,7 +389,7 @@ def transformed_error(self, etype):

class TestForInRangeConfig(TestTransform):
def set_input(self):
self.input = np.array([5])
self.input = np.array([5]).astype("int32")

def set_test_func(self):
self.dygraph_func = for_in_range
Expand Down Expand Up @@ -489,6 +489,7 @@ class TestForIterVarList(TestForInRangeConfig):
def set_test_func(self):
self.dygraph_func = for_iter_var_list

@test_legacy_and_pt_and_pir
def test_transformed_result_compare(self):
self.set_test_func()
self.transformed_result_compare()
Expand All @@ -498,6 +499,7 @@ class TestForEnumerateVarList(TestForInRangeConfig):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_list

@test_legacy_and_pt_and_pir
def test_transformed_result_compare(self):
self.set_test_func()
self.transformed_result_compare()
Expand Down