Skip to content

Commit fa879ec

Browse files
[Metal] fix Conv2d_transpose MPS (#8439)
* fix * fix_conv2d_transpose
1 parent b3c19d1 commit fa879ec

2 files changed

Lines changed: 142 additions & 0 deletions

File tree

lite/kernels/metal/image_op/conv2d_transpose_image_compute.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
#include "lite/backends/metal/metal_context.h"
2929
#include "lite/backends/metal/metal_debug.h"
30+
#include "lite/backends/metal/mps_conv_datasource.h"
31+
#include "lite/kernels/metal/image_op/metal_params.h"
3032

3133
namespace paddle {
3234
namespace lite {
@@ -39,6 +41,7 @@ class Conv2dTransposeImageCompute
3941

4042
public:
4143
void PrepareForRun() override;
44+
void ReInitWhenNeeded() override;
4245
void Run() override;
4346
void SaveOutput() override {
4447
MetalDebug::SaveOutput(
@@ -48,6 +51,7 @@ class Conv2dTransposeImageCompute
4851

4952
private:
5053
bool use_mps_{false};
54+
void* mps_conv_trans_op_{nullptr};
5155
void* mps_conv_op_{nullptr};
5256
void* mps_input_image_{nullptr};
5357
void* mps_output_image_{nullptr};

lite/kernels/metal/image_op/conv2d_transpose_image_compute.mm

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,38 @@
3333
init_for_run();
3434
}
3535

36+
// Re-creates shape-dependent resources when the input dims differ from
// last_input_dims_.
//
// On a shape change the kernel calls release_memory() followed by
// init_memory(); when the MPS path is active it also drops the retained
// MPSImage wrappers and builds new ones around the current textures of
// input_buffer_ and output_buffer_.
//
// NOTE(review): last_input_dims_ is not assigned in this function; presumably
// init_memory() refreshes it -- confirm, otherwise this re-runs on every call.
void Conv2dTransposeImageCompute::ReInitWhenNeeded() {
    const auto& param = this->Param<param_t>();
    auto input_dims = param.x->dims();

    if (last_input_dims_ != input_dims) {
        release_memory();
        init_memory();

        if (use_mps_) {
            if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
                // The wrappers were stored with __bridge_retained, so each one
                // must be balanced with CFRelease before being replaced.
                if (mps_input_image_) {
                    CFRelease(mps_input_image_);
                    mps_input_image_ = nullptr;
                }
                if (mps_output_image_) {
                    CFRelease(mps_output_image_);
                    mps_output_image_ = nullptr;
                }

                // Clamp to at least 4 feature channels, exactly as
                // setup_with_mps() does for both the images and the
                // convolution descriptor, so the rebuilt images keep matching
                // the kernel that was created there.
                auto input_c = MAX(4, static_cast<int>(input_buffer_->tensor_dim_[1]));
                auto output_c = MAX(4, static_cast<int>(output_buffer_->tensor_dim_[1]));

                // MPS input and output
                mps_input_image_ =
                    (__bridge_retained void*)[[MPSImage alloc] initWithTexture:input_buffer_->image()
                                                               featureChannels:input_c];
                mps_output_image_ =
                    (__bridge_retained void*)[[MPSImage alloc] initWithTexture:output_buffer_->image()
                                                               featureChannels:output_c];
            }
        }
    }
}
67+
3668
// attention!!! filter: CNHW2NCHW
3769
void Conv2dTransposeImageCompute::init_attention() {
3870
const auto& param = this->Param<param_t>();
@@ -73,6 +105,13 @@
73105
const auto& param = this->Param<param_t>();
74106

75107
function_name_ = KernelFunctionName(param);
108+
bool should_use_mps = false;
109+
if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
110+
if (metal_context_->use_mps()) {
111+
should_use_mps = true;
112+
}
113+
}
114+
use_mps_ = should_use_mps;
76115
if (use_mps_) {
77116
setup_with_mps();
78117
} else {
@@ -427,9 +466,108 @@
427466
#pragma mark - MPS
428467

429468
// Encodes the MPS transposed convolution (when one has been created) into a
// fresh command buffer obtained from the Metal backend, then commits that
// command buffer. The command buffer is committed even when nothing was
// encoded.
void Conv2dTransposeImageCompute::run_with_mps() {
    auto context_impl = (__bridge MetalContextImp*)metal_context_->backend();
    auto command_buffer = [context_impl commandBuffer];

    if (mps_conv_trans_op_) {
        if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
            // __bridge casts: ownership stays with the void* members.
            auto conv_transpose = (__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_;
            auto source_image = (__bridge MPSImage*)mps_input_image_;
            auto destination_image = (__bridge MPSImage*)mps_output_image_;

            [conv_transpose encodeToCommandBuffer:command_buffer
                                      sourceImage:source_image
                                 destinationImage:destination_image];
        }
    }

    [context_impl commit:command_buffer];
}
431481

432482
// Builds the MPS objects used by run_with_mps():
//   1. re-orders the float filter weights into a scratch buffer, flipping the
//      kernel along both spatial axes,
//   2. copies the re-ordered weights into filter_buffer_ (HALF precision) and
//      hands them to an MPSConvDataSource together with the optional bias,
//   3. creates the MPSCNNConvolutionTranspose kernel and the MPSImage wrappers
//      around the input/output textures.
//
// The kernel and the two images are stored as __bridge_retained void*
// members, so they stay alive until they are released with CFRelease.
//
// NOTE(review): Apple's documentation lists MPSCNNConvolutionTranspose as
// iOS 11.0+; confirm whether the iOS 10.0 availability floor used below (and
// where use_mps_ is decided) should be raised.
void Conv2dTransposeImageCompute::setup_with_mps() {
    const auto& param = this->Param<param_t>();
    auto backend = (__bridge MetalContextImp*)metal_context_->backend();
    auto padding_top = (*param.paddings)[0];
    auto padding_left = (*param.paddings)[2];

    // Kernel offsets derived from the filter width/height and the left/top
    // padding.
    int offsetX =
        static_cast<int>(param.filter->dims()[3] / 2 - param.filter->dims()[3] + 1 + padding_left);
    int offsetY =
        static_cast<int>(param.filter->dims()[2] / 2 - param.filter->dims()[2] + 1 + padding_top);

    auto rawdata = param.filter->data<float>();
    auto dims = filter_metal_dims_;
    auto tensorDim = DDimLite({dims[0], dims[1], dims[2], dims[3]});
    auto count = tensorDim.production();

    // Zero-initialised scratch buffer that receives the re-ordered weights.
    // It is only needed until the weights are copied into filter_buffer_ and
    // is released at the end of this function.
    void* convertedPointer = TargetWrapperMetal::Malloc(count * sizeof(float));
    TargetWrapperMetal::MemsetSync(convertedPointer, 0, count * sizeof(float));
    auto weightsPointer = (float*)rawdata;
    auto transposed = (float*)convertedPointer;

    int length_nhw = dims[0] * dims[2] * dims[3];
    int length_chw = dims[1] * dims[2] * dims[3];
    int length_hw = dims[2] * dims[3];

    for (int n = 0; n < dims[0]; n++) {
        for (int c = 0; c < dims[1]; c++) {
            for (int h = 0; h < dims[2]; h++) {
                for (int w = 0; w < dims[3]; w++) {
                    // Source element: strides (length_hw, length_nhw, dims[3], 1)
                    // for (n, c, h, w).
                    int tIndex = h * dims[3] + w + length_nhw * c + length_hw * n;
                    // Destination element: channel innermost, with h and w
                    // mirrored (kernel flipped along both spatial axes).
                    int index = length_chw * n + (dims[2] - 1 - h) * dims[3] * dims[1] +
                                (dims[3] - 1 - w) * dims[1] + c;
                    transposed[index] = weightsPointer[tIndex];
                }
            }
        }
    }

    // mps-Convolution
    if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
        output_buffer_->use_mps_ = true;
        const_cast<MetalImage*>(input_buffer_)->use_mps_ = true;
        auto filter_h = static_cast<int>(param.filter->dims()[2]);
        auto filter_w = static_cast<int>(param.filter->dims()[3]);
        // Feature channel counts are clamped to at least 4.
        auto input_c = MAX(4, static_cast<int>(input_buffer_->tensor_dim_[1]));
        auto output_c = MAX(4, static_cast<int>(output_buffer_->tensor_dim_[1]));

        MPSCNNConvolutionDescriptor* description =
            [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:filter_w
                                                                    kernelHeight:filter_h
                                                            inputFeatureChannels:input_c
                                                           outputFeatureChannels:output_c];

        // NOTE(review): element [0] is mapped to X and element [1] to Y for both
        // strides and dilations; confirm this matches the (h, w) ordering of the
        // op parameters when they are not square.
        description.strideInPixelsX = param.strides[0];
        description.strideInPixelsY = param.strides[1];
        description.dilationRateX = (*param.dilations)[0];
        description.dilationRateY = (*param.dilations)[1];
        description.groups = 1;

        MPSConvDataSource* data_source = [[MPSConvDataSource alloc] init];
        data_source.descriptor = description;
        filter_buffer_ = std::make_shared<MetalBuffer>(
            metal_context_, filter_metal_dims_, METAL_PRECISION_TYPE::HALF);
        filter_buffer_->convert_to_nhwc_ = false;
        filter_buffer_->CopyFromNCHW<float>(transposed);
        // The data source reads its weights from filter_buffer_, not from the
        // scratch buffer.
        data_source.weights = filter_buffer_->rawdata();
        if (param.bias && canMPSAddByChannel()) {
            if (bias_buffer_->src_tensor_) {
                lite::Tensor* y = (lite::Tensor*)(bias_buffer_->src_tensor_);
                auto bias = y->data<float>();
                data_source.biasTerms = const_cast<float*>(bias);
            }
        }

        MPSCNNConvolutionTranspose* conv_transpose =
            [[MPSCNNConvolutionTranspose alloc] initWithDevice:backend.device weights:data_source];
        conv_transpose.offset = MPSOffset{.x = 0, .y = 0, .z = 0};
        conv_transpose.edgeMode = MPSImageEdgeModeZero;
        conv_transpose.kernelOffsetX = offsetX;
        conv_transpose.kernelOffsetY = offsetY;
        // Hand ownership to the void* member; balanced by a later CFRelease.
        mps_conv_trans_op_ = (__bridge_retained void*)conv_transpose;

        // MPS input and output
        mps_input_image_ =
            (__bridge_retained void*)[[MPSImage alloc] initWithTexture:input_buffer_->image()
                                                       featureChannels:input_c];
        mps_output_image_ =
            (__bridge_retained void*)[[MPSImage alloc] initWithTexture:output_buffer_->image()
                                                       featureChannels:output_c];
    }

    // Release the scratch buffer on every path; previously it was never freed
    // and leaked count * sizeof(float) bytes per setup.
    TargetWrapperMetal::Free(convertedPointer);
}
434572

435573
#pragma mark - internal

0 commit comments

Comments
 (0)