diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 0e31db0fbf9a94..2468ae05ee1e53 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -2991,7 +2991,7 @@ void SliceArrayDenseOp::VerifySig() { { auto input_size = num_operands(); IR_ENFORCE(input_size == 2u, - "The size %d of inputs must be equal to 1.", + "The size %d of inputs must be equal to 2.", input_size); IR_ENFORCE((*this) ->operand_source(0) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9bee9357e1ed13..611b5239dccdf8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3806,6 +3806,7 @@ void SliceArrayDenseInferMeta(const MetaTensor& input, } // out->set_dims(input.dims()); out->set_dtype(input.dtype()); + out->set_dims(input.dims()); } void SliceRawInferMeta(const MetaTensor& input, diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index 46cfc80ebe7cd4..2873704a97abe1 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -172,7 +172,7 @@ def for_iter_var_list(x): # 1. prepare data, ref test_list.py x = paddle.to_tensor(x) iter_num = paddle.tensor.fill_constant(shape=[1], value=5, dtype="int32") - a = [] + a = paddle.tensor.create_array("int32") for i in range(iter_num): a.append(x + i) # 2. iter list[var] @@ -187,7 +187,7 @@ def for_enumerate_var_list(x): # 1. prepare data, ref test_list.py x = paddle.to_tensor(x) iter_num = paddle.tensor.fill_constant(shape=[1], value=5, dtype="int32") - a = [] + a = paddle.tensor.create_array("int32") for i in range(iter_num): a.append(x + i) # 2. iter list[var] @@ -389,7 +389,7 @@ def transformed_error(self, etype): class TestForInRangeConfig(TestTransform): def set_input(self): - self.input = np.array([5]) + self.input = np.array([5]).astype("int32") def set_test_func(self): self.dygraph_func = for_in_range @@ -489,6 +489,7 @@ class TestForIterVarList(TestForInRangeConfig): def set_test_func(self): self.dygraph_func = for_iter_var_list + @test_legacy_and_pt_and_pir def test_transformed_result_compare(self): self.set_test_func() self.transformed_result_compare() @@ -498,6 +499,7 @@ class TestForEnumerateVarList(TestForInRangeConfig): def set_test_func(self): self.dygraph_func = for_enumerate_var_list + @test_legacy_and_pt_and_pir def test_transformed_result_compare(self): self.set_test_func() self.transformed_result_compare()