@@ -590,6 +590,7 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(void *image,
590590 float *tensor,
591591 const DDim &image_dim,
592592 const DDim &tensor_dim) {}
593+
593594DDim CLImageConverterNBlock::InitImageDimInfoWith (const DDim &tensor_dim) {
594595 CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
595596 size_t N, C, H, W;
@@ -604,6 +605,20 @@ DDim CLImageConverterNBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
604605 static_cast <DDim::value_type>(height)}));
605606}
606607
608+ DDim CLImageConverterNBlockGroup::InitImageDimInfoWith (const DDim &tensor_dim) {
609+ CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
610+ size_t N, C, H, W;
611+ N = tensor_dim[0 ];
612+ C = tensor_dim[1 ];
613+ H = tensor_dim[2 ];
614+ W = tensor_dim[3 ];
615+ size_t width = ((C + 3 ) / 4 ) * 4 ;
616+ size_t height = ((N / groups + 3 ) / 4 * groups) * H * W;
617+ return DDim (
618+ std::vector<DDim::value_type>({static_cast <DDim::value_type>(width),
619+ static_cast <DDim::value_type>(height)}));
620+ }
621+
607622void CLImageConverterNBlock::NCHWToImage (float *nchw,
608623 void *image,
609624 const DDim &tensor_dim) {
@@ -648,6 +663,57 @@ void CLImageConverterNBlock::NCHWToImage(float *nchw,
648663 }
649664}
650665
666+ void CLImageConverterNBlockGroup::ImageToNCHW (void *image,
667+ float *tensor,
668+ const DDim &image_dim,
669+ const DDim &tensor_dim) {}
670+
671+ void CLImageConverterNBlockGroup::NCHWToImage (float *nchw,
672+ void *image,
673+ const DDim &tensor_dim) {
674+ CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
675+ size_t N, C, H, W;
676+ N = tensor_dim[0 ];
677+ C = tensor_dim[1 ];
678+ H = tensor_dim[2 ];
679+ W = tensor_dim[3 ];
680+
681+ DDim in_image_dim = InitImageDimInfoWith (tensor_dim);
682+
683+ VLOG (3 ) << " tensor dim: " << tensor_dim;
684+ VLOG (3 ) << " image dim: " << in_image_dim;
685+
686+ size_t height = in_image_dim[1 ];
687+ size_t n_block = height / (W * H);
688+ size_t c_block4 = ((in_image_dim[0 ] + 3 ) / 4 ) * 4 ;
689+
690+ float *image_fp32 = static_cast <float *>(image);
691+ half_t *image_fp16 = static_cast <half_t *>(image);
692+
693+ float *p = nchw;
694+ size_t i0 = 0 ;
695+ int i = 0 ;
696+ for (size_t n = 0 ; n < n_block * 4 ; n++) {
697+ for (size_t c = 0 ; c < c_block4; c++) {
698+ for (size_t h = 0 ; h < H; h++) {
699+ for (size_t w = 0 ; w < W; w++) {
700+ size_t img_idx =
701+ (((n / 4 ) * W * H + h * W + w) * c_block4 + c) * 4 + n % 4 ;
702+ size_t remain = n % ((N / groups + 3 ) / 4 * 4 );
703+ if (remain < (N / groups) && c < C) {
704+ fp16_support_ ? image_fp16[img_idx] = Float2Half (*p)
705+ : image_fp32[img_idx] = *p;
706+ p++;
707+ } else {
708+ fp16_support_ ? image_fp16[img_idx] = Float2Half (0 .f )
709+ : image_fp32[img_idx] = 0 .f ;
710+ }
711+ }
712+ }
713+ }
714+ }
715+ }
716+
651717void CLImageConverterNBlock::ImageToNCHW (void *image,
652718 float *tensor,
653719 const DDim &image_dim,
0 commit comments