Skip to content

Commit 69c12b1

Browse files
authored
[OpenCL]Conv2d support group>1 (#8499)
* [OpenCL]conv2d supprt group>1 test=develop * conv mul_group add tunning local_work_size test=develop
1 parent ccb74e5 commit 69c12b1

File tree

7 files changed

+839
-27
lines changed

7 files changed

+839
-27
lines changed

lite/backends/opencl/cl_image_converter.cc

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(void *image,
590590
float *tensor,
591591
const DDim &image_dim,
592592
const DDim &tensor_dim) {}
593+
593594
DDim CLImageConverterNBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
594595
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
595596
size_t N, C, H, W;
@@ -604,6 +605,20 @@ DDim CLImageConverterNBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
604605
static_cast<DDim::value_type>(height)}));
605606
}
606607

608+
DDim CLImageConverterNBlockGroup::InitImageDimInfoWith(const DDim &tensor_dim) {
609+
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
610+
size_t N, C, H, W;
611+
N = tensor_dim[0];
612+
C = tensor_dim[1];
613+
H = tensor_dim[2];
614+
W = tensor_dim[3];
615+
size_t width = ((C + 3) / 4) * 4;
616+
size_t height = ((N / groups + 3) / 4 * groups) * H * W;
617+
return DDim(
618+
std::vector<DDim::value_type>({static_cast<DDim::value_type>(width),
619+
static_cast<DDim::value_type>(height)}));
620+
}
621+
607622
void CLImageConverterNBlock::NCHWToImage(float *nchw,
608623
void *image,
609624
const DDim &tensor_dim) {
@@ -648,6 +663,57 @@ void CLImageConverterNBlock::NCHWToImage(float *nchw,
648663
}
649664
}
650665

666+
void CLImageConverterNBlockGroup::ImageToNCHW(void *image,
667+
float *tensor,
668+
const DDim &image_dim,
669+
const DDim &tensor_dim) {}
670+
671+
void CLImageConverterNBlockGroup::NCHWToImage(float *nchw,
672+
void *image,
673+
const DDim &tensor_dim) {
674+
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
675+
size_t N, C, H, W;
676+
N = tensor_dim[0];
677+
C = tensor_dim[1];
678+
H = tensor_dim[2];
679+
W = tensor_dim[3];
680+
681+
DDim in_image_dim = InitImageDimInfoWith(tensor_dim);
682+
683+
VLOG(3) << " tensor dim: " << tensor_dim;
684+
VLOG(3) << " image dim: " << in_image_dim;
685+
686+
size_t height = in_image_dim[1];
687+
size_t n_block = height / (W * H);
688+
size_t c_block4 = ((in_image_dim[0] + 3) / 4) * 4;
689+
690+
float *image_fp32 = static_cast<float *>(image);
691+
half_t *image_fp16 = static_cast<half_t *>(image);
692+
693+
float *p = nchw;
694+
size_t i0 = 0;
695+
int i = 0;
696+
for (size_t n = 0; n < n_block * 4; n++) {
697+
for (size_t c = 0; c < c_block4; c++) {
698+
for (size_t h = 0; h < H; h++) {
699+
for (size_t w = 0; w < W; w++) {
700+
size_t img_idx =
701+
(((n / 4) * W * H + h * W + w) * c_block4 + c) * 4 + n % 4;
702+
size_t remain = n % ((N / groups + 3) / 4 * 4);
703+
if (remain < (N / groups) && c < C) {
704+
fp16_support_ ? image_fp16[img_idx] = Float2Half(*p)
705+
: image_fp32[img_idx] = *p;
706+
p++;
707+
} else {
708+
fp16_support_ ? image_fp16[img_idx] = Float2Half(0.f)
709+
: image_fp32[img_idx] = 0.f;
710+
}
711+
}
712+
}
713+
}
714+
}
715+
}
716+
651717
void CLImageConverterNBlock::ImageToNCHW(void *image,
652718
float *tensor,
653719
const DDim &image_dim,

lite/backends/opencl/cl_image_converter.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,17 @@ class CLImageConverterNBlock : public CLImageConverterBase {
143143
const DDim &tensor_dim) override;
144144
};
145145

146+
class CLImageConverterNBlockGroup : public CLImageConverterBase {
147+
public:
148+
DDim InitImageDimInfoWith(const DDim &tensor_dim) override;
149+
void NCHWToImage(float *tensor, void *image, const DDim &tensor_dim) override;
150+
void ImageToNCHW(void *image,
151+
float *tensor,
152+
const DDim &image_dim,
153+
const DDim &tensor_dim) override;
154+
int groups{-1};
155+
};
156+
146157
class CLImageConverterN2Block : public CLImageConverterBase {
147158
public:
148159
DDim InitImageDimInfoWith(const DDim &tensor_dim) override;

0 commit comments

Comments
 (0)