@@ -92,6 +92,41 @@ void TestAPISizeAndShape() {
9292 CHECK (t1.shape () == tensor_shape);
9393}
9494
95+ void TestAPISlice () {
96+ std::vector<int64_t > tensor_shape_origin1 = {5 , 5 };
97+ std::vector<int64_t > tensor_shape_sub1 = {3 , 5 };
98+ std::vector<int64_t > tensor_shape_origin2 = {5 , 5 , 5 };
99+ std::vector<int64_t > tensor_shape_sub2 = {1 , 5 , 5 };
100+ #ifdef PADDLE_WITH_CUDA
101+ auto t1 = paddle::Tensor (paddle::PlaceType::kGPU , tensor_shape_origin1);
102+ t1.mutable_data <float >();
103+ CHECK (t1.slice (0 , 5 ).shape () == tensor_shape_origin1);
104+ CHECK (t1.slice (0 , 3 ).shape () == tensor_shape_sub1);
105+ auto t2 = paddle::Tensor (paddle::PlaceType::kGPU , tensor_shape_origin2);
106+ t2.mutable_data <float >();
107+ CHECK (t2.slice (4 , 5 ).shape () == tensor_shape_sub2);
108+ #endif
109+ auto t3 = paddle::Tensor (paddle::PlaceType::kCPU , tensor_shape_origin1);
110+ t3.mutable_data <float >();
111+ CHECK (t3.slice (0 , 5 ).shape () == tensor_shape_origin1);
112+ CHECK (t3.slice (0 , 3 ).shape () == tensor_shape_sub1);
113+ auto t4 = paddle::Tensor (paddle::PlaceType::kCPU , tensor_shape_origin2);
114+ t4.mutable_data <float >();
115+ CHECK (t4.slice (4 , 5 ).shape () == tensor_shape_sub2);
116+
117+ // Test writing function for sliced tensor
118+ auto t = InitCPUTensorForTest<float >();
119+ auto t_sliced = t.slice (0 , 1 );
120+ auto * t_sliced_data_ptr = t_sliced.mutable_data <float >();
121+ for (int64_t i = 0 ; i < t_sliced.size (); i++) {
122+ t_sliced_data_ptr[i] += static_cast <float >(5 );
123+ }
124+ auto * t_data_ptr = t.mutable_data <float >();
125+ for (int64_t i = 0 ; i < t_sliced.size (); i++) {
126+ CHECK_EQ (t_data_ptr[i], static_cast <float >(10 ));
127+ }
128+ }
129+
95130template <typename T>
96131paddle::DataType TestDtype () {
97132 std::vector<int64_t > tensor_shape = {5 , 5 };
@@ -261,6 +296,8 @@ TEST(CustomTensor, copyTest) {
261296 TestAPISizeAndShape ();
262297 VLOG (2 ) << " TestPlace" ;
263298 TestAPIPlace ();
299+ VLOG (2 ) << " TestSlice" ;
300+ TestAPISlice ();
264301 VLOG (2 ) << " TestCast" ;
265302 GroupTestCast ();
266303 VLOG (2 ) << " TestDtypeConvert" ;
0 commit comments