@@ -28,6 +28,88 @@ namespace lite {
2828namespace kernels {
2929namespace x86 {
3030
31+ static void print_tensor_array (const std::vector<lite::Tensor>* XTensorList) {
32+ std::cout << " *********tensor array********" << std::endl;
33+ for (int i = 0 ; i < XTensorList->size (); i++) {
34+ auto tensor = XTensorList->at (i);
35+ std::cout << " Tensor " << i << " len: " << tensor.data_size () << std::endl;
36+ for (int j = 0 ; j < tensor.data_size () && j < 10 ; j++) {
37+ std::cout << tensor.mutable_data <float >()[j] << " " ;
38+ }
39+ std::cout << std::endl;
40+ }
41+ }
42+
43+ static void print_tensor (lite::Tensor* Tensor) {
44+ std::cout << " *********tensor********" << std::endl;
45+ float * data = Tensor->mutable_data <float >();
46+ std::cout << " Tensor len: " << Tensor->data_size () << std::endl;
47+ for (int j = 0 ; j < Tensor->data_size () && j < 10 ; j++) {
48+ std::cout << data[j] << " " ;
49+ }
50+ std::cout << std::endl;
51+ }
52+
53+ void DealTensorArray (const std::vector<lite::Tensor>* XTensorList,
54+ std::vector<lite::Tensor>* OutTensorList,
55+ lite::Tensor* Out,
56+ const std::vector<int >& starts,
57+ const std::vector<int >& ends,
58+ bool out_is_array) {
59+ auto in_array = XTensorList;
60+ // If the input is LoDTensorArray, the rank of input is 1.
61+ int64_t in_size = in_array->size ();
62+ int64_t start = starts[0 ] < 0 ? (starts[0 ] + in_size) : starts[0 ];
63+ int64_t end = ends[0 ] < 0 ? (ends[0 ] + in_size) : ends[0 ];
64+
65+ start = std::max (start, static_cast <int64_t >(0 ));
66+ end = std::max (end, static_cast <int64_t >(0 ));
67+ end = std::min (end, in_size);
68+
69+ CHECK_GT (end, start) << " end should greater than start" ;
70+ int64_t out_size = end - start;
71+
72+ std::cout << " starts: " << std::endl;
73+ for (int i = 0 ; i < starts.size (); i++) {
74+ std::cout << starts[i] << " " ;
75+ }
76+ std::cout << std::endl;
77+ std::cout << " ends: " << std::endl;
78+ for (int i = 0 ; i < ends.size (); i++) {
79+ std::cout << ends[i] << " " ;
80+ }
81+ std::cout << std::endl;
82+
83+ if (out_is_array) {
84+ auto out_array = OutTensorList;
85+ out_array->resize (out_size);
86+ for (int i = 0 ; i < out_size; ++i) {
87+ auto * out_tensor = &out_array->at (i);
88+ auto in_tensor = in_array->at (i + start);
89+ out_tensor->set_lod (in_tensor.lod ());
90+ if (in_tensor.memory_size () > 0 ) {
91+ out_tensor->CopyDataFrom (in_tensor);
92+ } else {
93+ VLOG (4 ) << " WARNING: The input tensor 'x_tensor' holds no memory, so "
94+ " nothing has been written to output array["
95+ << i << " ]." ;
96+ }
97+ }
98+ } else {
99+ auto out_tensor = Out;
100+ auto in_tensor = in_array->at (start);
101+ out_tensor->CopyDataFrom (in_tensor);
102+ }
103+ std::cout << " input array:" << std::endl;
104+ print_tensor_array (XTensorList);
105+ if (out_is_array) {
106+ std::cout << " out array:" << std::endl;
107+ print_tensor_array (OutTensorList);
108+ } else {
109+ print_tensor (Out);
110+ }
111+ }
112+
31113inline std::vector<int > GetIntDataFromTensorList (
32114 const std::vector<lite::Tensor*>& list_tensor) {
33115 std::vector<int > vec_data;
@@ -219,6 +301,8 @@ void slice_compute(const lite::Tensor* in,
219301template <class T >
220302void slice_compute_ (const lite::Tensor* Input,
221303 lite::Tensor* Out,
304+ const std::vector<lite::Tensor>* XTensorList,
305+ std::vector<lite::Tensor>* OutTensorList,
222306 std::vector<int > axes,
223307 std::vector<int > starts,
224308 std::vector<int > ends,
@@ -228,6 +312,38 @@ void slice_compute_(const lite::Tensor* Input,
228312 std::vector<lite::Tensor*> StartsTensorList,
229313 std::vector<lite::Tensor*> EndsTensorList,
230314 std::vector<int > infer_flags) {
315+ if (Input == nullptr && XTensorList != nullptr ) {
316+ bool need_infer = false ;
317+ if (StartsTensor || EndsTensor) {
318+ need_infer = true ;
319+ }
320+ if (StartsTensorList.size () > 0 || EndsTensorList.size () > 0 ) {
321+ need_infer = true ;
322+ }
323+ if (need_infer) {
324+ if (StartsTensor) {
325+ starts = GetIntDataFromTensor (StartsTensor);
326+ } else if (StartsTensorList.size () > 0 ) {
327+ starts = GetIntDataFromTensorList (StartsTensorList);
328+ }
329+ CHECK_EQ (starts.size (), axes.size ())
330+ << " The size of starts must be equal to the size of axes." ;
331+ if (EndsTensor) {
332+ ends = GetIntDataFromTensor (EndsTensor);
333+ } else if (EndsTensorList.size () > 0 ) {
334+ ends = GetIntDataFromTensorList (EndsTensorList);
335+ }
336+ CHECK_EQ (ends.size (), axes.size ())
337+ << " The size of starts must be equal to the size of axes." ;
338+ }
339+ DealTensorArray (XTensorList,
340+ OutTensorList,
341+ Out,
342+ starts,
343+ ends,
344+ (Out == nullptr && OutTensorList != nullptr ));
345+ return ;
346+ }
231347 int rank = Input->dims ().size ();
232348 switch (rank) {
233349 case 1 :
@@ -320,6 +436,8 @@ class SliceCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
320436 auto & param = *param_.get_mutable <param_t >();
321437 slice_compute_<T>(param.X ,
322438 param.Out ,
439+ param.XTensorList ,
440+ param.OutTensorList ,
323441 param.axes ,
324442 param.starts ,
325443 param.ends ,
0 commit comments