@@ -39,6 +39,43 @@ class TransposeComputeFloatImage
3939 public:
4040 using param_t = operators::TransposeParam;
4141
42+ std::vector<int > CalStrides (const DDim& dims) {
43+ int dsize = dims.size ();
44+ std::vector<int > strides (dsize, 1 );
45+ for (int i = dsize - 2 ; i >= 0 ; i--) {
46+ strides[i] = strides[i + 1 ] * dims[i + 1 ];
47+ }
48+ return strides;
49+ }
50+
// Decomposes a flat `offset` into one coordinate per entry of `strides`.
//
// Strides are consumed front to back: each coordinate is the quotient of the
// remaining offset by the stride, and the remainder is carried on to the next
// stride. Every stride must be non-zero (not checked here).
std::vector<int> CalIndex(const std::vector<int>& strides, int offset) {
  std::vector<int> coords;
  coords.reserve(strides.size());
  int remaining = offset;
  for (const int stride : strides) {
    coords.push_back(remaining / stride);
    remaining %= stride;
  }
  return coords;
}
60+
// Permutes a multi-dimensional index according to `axis`.
//
// For every i in [0, axis.size()), out[i] = in_index[axis[i]]. The returned
// vector has in_index.size() entries; entries the loop does not write stay 0.
//
// Preconditions (not checked here): axis.size() <= in_index.size(), and every
// axis[i] is a valid position in `in_index`.
std::vector<int> TransIndex(const std::vector<int>& in_index,
                            const std::vector<int>& axis) {
  std::vector<int> out_index(in_index.size(), 0);
  // The counter is size_t because axis.size() is unsigned; an int counter
  // makes the loop condition a signed/unsigned comparison.
  for (size_t i = 0; i < axis.size(); ++i) {
    out_index[i] = in_index[axis[i]];
  }
  return out_index;
}
69+
// Folds a multi-dimensional `index` back into a flat offset.
//
// Returns the sum of strides[i] * index[i] over every entry of `index`, and 0
// when `index` is empty.
//
// Precondition (not checked here): strides.size() >= index.size().
int CalOffset(const std::vector<int>& strides,
              const std::vector<int>& index) {
  int offset = 0;
  // The counter is size_t because index.size() is unsigned; an int counter
  // makes the loop condition a signed/unsigned comparison.
  for (size_t i = 0; i < index.size(); ++i) {
    offset += strides[i] * index[i];
  }
  return offset;
}
78+
4279 void PrepareForRun () override {
4380 transpose_param_ = param_.get_mutable <param_t >();
4481 axis_ = transpose_param_->axis ;
@@ -84,11 +121,57 @@ class TransposeComputeFloatImage
84121 LOG (FATAL) << " Unsupported axis permutation for current lite OpenCL "
85122 " kernel! " ;
86123 }
124+ } else if (axis_.size () > 4 ) {
125+ kernel_func_name_ = " transpose_general_buffer" ;
126+ kernel_path_ = " buffer/transpose_kernel.cl" ;
87127 } else {
88128 LOG (FATAL) << " Unsupported axis permutation for current lite OpenCL "
89129 " kernel! " ;
90130 }
91131
132+ if (kernel_func_name_ == " transpose_general_buffer" ) {
133+ build_options_ = " -DCL_DTYPE_float" ;
134+ // create kernels of im2buf and buf2im
135+ auto im2buf_kernels = KernelRegistry::Global ().Create (
136+ " layout" , TARGET (kOpenCL ), PRECISION (kAny ), DATALAYOUT (kNCHW ));
137+ auto buf2im_kernels =
138+ KernelRegistry::Global ().Create (" layout" ,
139+ TARGET (kOpenCL ),
140+ PRECISION (kAny ),
141+ DATALAYOUT (kImageDefault ));
142+
143+ im2buf_kernel_ = std::move (im2buf_kernels.front ());
144+ buf2im_kernel_ = std::move (buf2im_kernels.front ());
145+
146+ // calc output shape
147+ std::vector<int64_t > new_output_tensor_shape (x_tensor_dims_.size (), 0 );
148+ for (size_t i = 0 ; i < x_tensor_dims_.size (); i++) {
149+ new_output_tensor_shape[i] = x_tensor_dims_[axis_[i]];
150+ }
151+ output->Resize (new_output_tensor_shape);
152+ output_tensor_dims_ = output->dims ();
153+ // calc in/out index of transpose
154+ std::vector<int > x_tensor_strides = CalStrides (x_tensor_dims_);
155+ std::vector<int > output_tensor_strides = CalStrides (output_tensor_dims_);
156+ std::vector<int > output_tensor_idxs_vec (output->dims ().production ());
157+ for (size_t i = 0 ; i < x_tensor_dims_.production (); i++) {
158+ std::vector<int > x_tensor_index = CalIndex (x_tensor_strides, i);
159+ std::vector<int > out_tensor_index = TransIndex (x_tensor_index, axis_);
160+ output_tensor_idxs_vec[i] =
161+ CalOffset (output_tensor_strides, out_tensor_index);
162+ }
163+
164+ // copy output_tensor_idxs_vec data to gpu
165+ output_tensor_idxs_t_ = std::unique_ptr<Tensor>(new Tensor);
166+ output_tensor_idxs_t_->Resize (output_tensor_dims_);
167+ output_tensor_idxs_data_ =
168+ output_tensor_idxs_t_->mutable_data <int , cl::Buffer>(TARGET (kOpenCL ));
169+ TargetWrapperCL::MemcpySync (output_tensor_idxs_data_,
170+ output_tensor_idxs_vec.data (),
171+ output_tensor_idxs_t_->memory_size (),
172+ IoDirection::HtoD);
173+ }
174+
92175 if (output_tensor_dims_.size () == 4 ) {
93176 output_tensor_n_ = output_tensor_dims_[0 ];
94177 output_tensor_c_ = output_tensor_dims_[1 ];
@@ -126,14 +209,21 @@ class TransposeComputeFloatImage
126209#endif
127210
// Fills global_work_size_ with the 3-D NDRange used when the transpose kernel
// is enqueued.
// NOTE(review): this span looks like a diff view. The lines carrying a '-'
// marker appear to be the pre-change body, and the same statements reappear in
// the else branch further down. Confirm against the real file.
128211 void GetGlobalWorkSize () {
129- const std::vector<size_t >& ws =
130- DefaultGlobalWorkSize (output_tensor_dims_,
131- DDim (std::vector<DDim::value_type>{
132- static_cast <int64_t >(output_image_w_),
133- static_cast <int64_t >(output_image_h_)}));
134- global_work_size_ = cl::NDRange{static_cast <cl::size_type>(ws[0 ]),
135- static_cast <cl::size_type>(ws[1 ]),
136- static_cast <cl::size_type>(ws[2 ])};
// Buffer-based general transpose: the range is taken directly from the output
// tensor extents, in the order (h, w, c).
212+ if (kernel_func_name_ == " transpose_general_buffer" ) {
213+ global_work_size_ =
214+ cl::NDRange{static_cast <cl::size_type>(output_tensor_h_),
215+ static_cast <cl::size_type>(output_tensor_w_),
216+ static_cast <cl::size_type>(output_tensor_c_)};
217+ } else {
// Every other kernel: the range comes from DefaultGlobalWorkSize, which is
// given the output tensor dims and a two-element DDim built from
// (output_image_w_, output_image_h_). Its first three entries become the
// NDRange.
218+ const std::vector<size_t >& ws =
219+ DefaultGlobalWorkSize (output_tensor_dims_,
220+ DDim (std::vector<DDim::value_type>{
221+ static_cast <int64_t >(output_image_w_),
222+ static_cast <int64_t >(output_image_h_)}));
223+ global_work_size_ = cl::NDRange{static_cast <cl::size_type>(ws[0 ]),
224+ static_cast <cl::size_type>(ws[1 ]),
225+ static_cast <cl::size_type>(ws[2 ])};
226+ }
137227 }
138228
139229 void Run () override {
@@ -144,30 +234,92 @@ class TransposeComputeFloatImage
144234 auto & context = ctx_->As <OpenCLContext>();
145235 auto kernel = kernel_;
146236 cl_int status;
147- status = kernel.setArg (0 , *x_image);
148- CL_CHECK_FATAL (status);
149- status = kernel.setArg (1 , *output_image);
150- CL_CHECK_FATAL (status);
151- status = kernel.setArg (2 , output_tensor_c_);
152- CL_CHECK_FATAL (status);
153- status = kernel.setArg (3 , output_tensor_h_);
154- CL_CHECK_FATAL (status);
155- status = kernel.setArg (4 , output_tensor_w_);
156- CL_CHECK_FATAL (status);
157- status = kernel.setArg (5 , x_tensor_w_);
158- CL_CHECK_FATAL (status);
159- status = kernel.setArg (6 , x_tensor_h_);
160- CL_CHECK_FATAL (status);
161-
162- GetGlobalWorkSize ();
163- status = EnqueueNDRangeKernel (context,
164- kernel,
165- cl::NullRange,
166- global_work_size_,
167- cl::NullRange,
168- nullptr ,
169- event_);
170- CL_CHECK_FATAL (status);
237+ if (kernel_func_name_ == " transpose_general_buffer" ) {
238+ // do image layout transform: image to buffer
239+ // create and set param, context to kernel im2buf
240+ operators::LayoutParam im2buf_param;
241+ std::shared_ptr<lite::Tensor> im2buf_out_t (new lite::Tensor);
242+ im2buf_out_t ->Resize (x_tensor_dims_);
243+ auto im2buf_out_t_buffer_p =
244+ im2buf_out_t ->mutable_data <float , cl::Buffer>(TARGET (kOpenCL ));
245+ im2buf_param.x = transpose_param_->x ;
246+ im2buf_param.y = im2buf_out_t .get ();
247+ auto s = im2buf_kernel_->op_type ();
248+ im2buf_kernel_->SetParam (im2buf_param);
249+
250+ std::unique_ptr<KernelContext> im2buf_ctx (new KernelContext);
251+ context.CopySharedTo (&(im2buf_ctx->As <OpenCLContext>()));
252+ im2buf_kernel_->SetContext (std::move (im2buf_ctx));
253+ im2buf_kernel_->Launch ();
254+
255+ // create and set param, context to kernel buf2im
256+ std::shared_ptr<lite::Tensor> buf2im_in_t (new lite::Tensor);
257+ buf2im_in_t ->Resize (transpose_param_->output ->dims ());
258+ auto buf2im_in_t_buffer_p =
259+ buf2im_in_t ->mutable_data <float , cl::Buffer>(TARGET (kOpenCL ));
260+ operators::LayoutParam buf2im_param;
261+ buf2im_param.x = buf2im_in_t .get ();
262+ buf2im_param.y = transpose_param_->output ;
263+ buf2im_kernel_->SetParam (buf2im_param);
264+
265+ std::unique_ptr<KernelContext> buf2im_ctx (new KernelContext);
266+ context.CopySharedTo (&(buf2im_ctx->As <OpenCLContext>()));
267+ buf2im_kernel_->SetContext (std::move (buf2im_ctx));
268+
269+ // set kernel args
270+ status = kernel.setArg (0 , *im2buf_out_t_buffer_p);
271+ CL_CHECK_FATAL (status);
272+ status = kernel.setArg (1 , *buf2im_in_t_buffer_p);
273+ CL_CHECK_FATAL (status);
274+ status = kernel.setArg (2 , *output_tensor_idxs_data_);
275+ CL_CHECK_FATAL (status);
276+ status = kernel.setArg (3 , output_tensor_c_);
277+ CL_CHECK_FATAL (status);
278+ status = kernel.setArg (4 , output_tensor_h_);
279+ CL_CHECK_FATAL (status);
280+ status = kernel.setArg (5 , output_tensor_w_);
281+ CL_CHECK_FATAL (status);
282+ status = kernel.setArg (6 , output_tensor_h_ * output_tensor_w_);
283+ CL_CHECK_FATAL (status);
284+
285+ GetGlobalWorkSize ();
286+ auto & context = ctx_->As <OpenCLContext>();
287+ status = EnqueueNDRangeKernel (context,
288+ kernel,
289+ cl::NullRange,
290+ global_work_size_,
291+ cl::NullRange,
292+ nullptr ,
293+ event_);
294+ CL_CHECK_FATAL (status);
295+ // run kernel: buffer->image
296+ buf2im_kernel_->Launch ();
297+ } else {
298+ status = kernel.setArg (0 , *x_image);
299+ CL_CHECK_FATAL (status);
300+ status = kernel.setArg (1 , *output_image);
301+ CL_CHECK_FATAL (status);
302+ status = kernel.setArg (2 , output_tensor_c_);
303+ CL_CHECK_FATAL (status);
304+ status = kernel.setArg (3 , output_tensor_h_);
305+ CL_CHECK_FATAL (status);
306+ status = kernel.setArg (4 , output_tensor_w_);
307+ CL_CHECK_FATAL (status);
308+ status = kernel.setArg (5 , x_tensor_w_);
309+ CL_CHECK_FATAL (status);
310+ status = kernel.setArg (6 , x_tensor_h_);
311+ CL_CHECK_FATAL (status);
312+
313+ GetGlobalWorkSize ();
314+ status = EnqueueNDRangeKernel (context,
315+ kernel,
316+ cl::NullRange,
317+ global_work_size_,
318+ cl::NullRange,
319+ nullptr ,
320+ event_);
321+ CL_CHECK_FATAL (status);
322+ }
171323 }
172324
173325 private:
@@ -194,6 +346,10 @@ class TransposeComputeFloatImage
194346
195347 cl::NDRange global_work_size_;
196348 cl::Kernel kernel_;
349+
350+ // transpose_general_buffer
351+ std::unique_ptr<KernelBase> im2buf_kernel_;
352+ std::unique_ptr<KernelBase> buf2im_kernel_;
197353};
198354
199355} // namespace opencl
0 commit comments