diff --git a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu
old mode 100755
new mode 100644
index 560dc561fd5..78f9d227b4c
--- a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu
+++ b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu
@@ -17,8 +17,8 @@
 #include "adaptive_pool2d_kernel.h"
 
 namespace fastdeploy {
-
-__global__ void CudaCastKernel(const float* in, float* out, int edge,
+template <typename T1, typename T2>
+__global__ void CudaCastKernel(const T1* in, T2* out, int edge,
                                int out_bc_offset, int in_bc_offset, int ih,
                                int iw, int oh, int ow, bool is_avg) {
   int position = blockDim.x * blockIdx.x + threadIdx.x;
@@ -32,29 +32,37 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge,
   int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
   int wstart = floorf(static_cast<float>(w * iw) / ow);
   int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
+  float ele_val = 0.0;
   if (is_avg) {
-    out[position] = 0.0;
+    ele_val = 0.0;
   } else {
-    out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
+    ele_val =
+        static_cast<float>(in[offset * in_bc_offset + hstart * iw + wstart]);
   }
   for (int h = hstart; h < hend; ++h) {
     for (int w = wstart; w < wend; ++w) {
       int input_idx = h * iw + w;
       if (is_avg) {
-        out[position] = out[position] + in[offset * in_bc_offset + input_idx];
+        ele_val =
+            ele_val + static_cast<float>(in[offset * in_bc_offset + input_idx]);
       } else {
-        out[position] =
-            max(out[position], in[offset * in_bc_offset + input_idx]);
+        ele_val =
+            (ele_val >
+             static_cast<float>(in[offset * in_bc_offset + input_idx]))
+                ? ele_val
+                : static_cast<float>(in[offset * in_bc_offset + input_idx]);
       }
     }
   }
-  out[position] = out[position] / ((hend - hstart) * (wend - wstart));
+  out[position] = static_cast<T2>(
+      ele_val / static_cast<float>(((hend - hstart) * (wend - wstart))));
 }
 
 void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
-                      const std::vector<int64_t>& output_dims, float* output,
-                      const float* input, void* compute_stream,
-                      const std::string& pooling_type) {
+                      const std::vector<int64_t>& output_dims, void* output,
+                      const void* input, void* compute_stream,
+                      const std::string& pooling_type, const std::string& dtype,
+                      const std::string& out_dtype) {
   auto casted_compute_stream = reinterpret_cast<cudaStream_t>(compute_stream);
   int out_bc_offset = output_dims[2] * output_dims[3];
   int in_bc_offset = input_dims[2] * input_dims[3];
@@ -65,9 +73,27 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
   bool is_avg = pooling_type == "avg";
   int threads = 256;
   int blocks = ceil(jobs / static_cast<float>(threads));
-  CudaCastKernel<<<blocks, threads, 0, casted_compute_stream>>>(
-      input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
-      int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
+  if (dtype == "float") {
+    CudaCastKernel<float, float><<<blocks, threads, 0, casted_compute_stream>>>(
+        static_cast<const float*>(input), static_cast<float*>(output), jobs,
+        out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]),
+        int(output_dims[2]), int(output_dims[3]), is_avg);
+  } else if (dtype == "half") {
+    if (out_dtype == "half") {
+      CudaCastKernel<half, half><<<blocks, threads, 0, casted_compute_stream>>>(
+          static_cast<const half*>(input), static_cast<half*>(output), jobs,
+          out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]),
+          int(output_dims[2]), int(output_dims[3]), is_avg);
+    }
+    if (out_dtype == "float") {
+      CudaCastKernel<half, float>
+          <<<blocks, threads, 0, casted_compute_stream>>>(
+              static_cast<const half*>(input), static_cast<float*>(output),
+              jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
+              int(input_dims[3]), int(output_dims[2]), int(output_dims[3]),
+              is_avg);
+    }
+  }
 }
 }  // namespace fastdeploy
 #endif
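Note on the kernel change above: `CudaCastKernel` now accumulates in a local `float ele_val` regardless of `T1`/`T2`, so fp16 inputs are reduced in fp32 and only cast back to `T2` on the final store. For anyone who wants to exercise the new entry point outside of TensorRT, a minimal host-side smoke test could look like the sketch below. It is not part of the patch; the include path, the `nvcc` build setup, and the tiny 1x1x4x4 example are assumptions of mine.

```cpp
// Minimal sketch (not part of the patch): fp16 input, fp32 output through the
// new CudaCastKernel<half, float> path of CudaAdaptivePool.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <iostream>
#include <vector>

#include "adaptive_pool2d_kernel.h"  // assumed relative include path

int main() {
  // 1x1x4x4 NCHW input, adaptively average-pooled to 1x1x2x2.
  std::vector<int64_t> in_dims = {1, 1, 4, 4};
  std::vector<int64_t> out_dims = {1, 1, 2, 2};

  std::vector<half> h_in(16);
  for (int i = 0; i < 16; ++i) {
    h_in[i] = __float2half(static_cast<float>(i));
  }
  std::vector<float> h_out(4, 0.0f);

  half* d_in = nullptr;
  float* d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(half));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(half),
             cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // dtype="half", out_dtype="float" selects the half-in / float-out kernel.
  fastdeploy::CudaAdaptivePool(in_dims, out_dims, d_out, d_in, stream, "avg",
                               "half", "float");
  cudaStreamSynchronize(stream);

  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  // Each output element averages one 2x2 window of the 4x4 input.
  for (float v : h_out) {
    std::cout << v << " ";
  }
  std::cout << std::endl;

  cudaFree(d_in);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return 0;
}
```

Built with `nvcc` against `adaptive_pool2d_kernel.cu`, this should print `2.5 4.5 10.5 12.5` (the averages of the four 2x2 corners of a 0..15 ramp).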
diff --git a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h
index dc29c07dc0f..ddb7cb81555 100755
--- a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h
+++ b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h
@@ -15,6 +15,7 @@
 
 #pragma once
 
+#include <cuda_fp16.h>
 #include <cuda_runtime_api.h>
 #include <iostream>
 #include <vector>
@@ -25,8 +26,10 @@
 namespace fastdeploy {
 
 void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
-                      const std::vector<int64_t>& output_dims, float* output,
-                      const float* input, void* compute_stream,
-                      const std::string& pooling_type);
+                      const std::vector<int64_t>& output_dims, void* output,
+                      const void* input, void* compute_stream,
+                      const std::string& pooling_type,
+                      const std::string& dtype = "float",
+                      const std::string& out_dtype = "float");
 
 }  // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc b/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc
index ae7cef7f41a..a9779447505 100644
--- a/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc
+++ b/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc
@@ -63,11 +63,6 @@ int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                             const nvinfer1::PluginTensorDesc* outputDesc,
                             const void* const* inputs, void* const* outputs,
                             void* workspace, cudaStream_t stream) noexcept {
-  if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
-    return -1;
-  }
-  auto const* data = static_cast<const float*>(inputs[0]);
-  auto* result = static_cast<float*>(outputs[0]);
   int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
              outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
   std::vector<int64_t> input_size, output_size;
@@ -75,8 +70,18 @@ int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
     input_size.push_back(inputDesc[0].dims.d[i]);
     output_size.push_back(outputDesc[0].dims.d[i]);
   }
-  CudaAdaptivePool(input_size, output_size, result, data, stream,
-                   pooling_type_);
+  if (inputDesc[0].type == nvinfer1::DataType::kHALF) {
+    if (outputDesc[0].type == nvinfer1::DataType::kHALF) {
+      CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
+                       pooling_type_, "half", "half");
+    } else if (outputDesc[0].type == nvinfer1::DataType::kFLOAT) {
+      CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
+                       pooling_type_, "half", "float");
+    }
+  } else if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) {
+    CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
+                     pooling_type_, "float", "float");
+  }
   return cudaPeekAtLastError();
 }
 
@@ -106,7 +111,12 @@ nvinfer1::DataType AdaptivePool2d::getOutputDataType(
 bool AdaptivePool2d::supportsFormatCombination(
     int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
     int nbOutputs) noexcept {
-  return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
+  if ((inOut[pos].format == nvinfer1::PluginFormat::kLINEAR) &&
+      (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+       inOut[pos].type == nvinfer1::DataType::kHALF)) {
+    return true;
+  }
+  return false;
 }
 
 int AdaptivePool2d::initialize() noexcept { return 0; }
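A small readability observation on the `enqueue()` dispatch above: the three branches repeat the same `CudaAdaptivePool` call and differ only in the dtype strings. A helper along the following lines (hypothetical, not part of the patch) would keep the TensorRT-type-to-string mapping in one place:

```cpp
#include <string>

#include "NvInfer.h"

// Hypothetical helper, not in the patch: maps a TensorRT tensor type to the
// dtype string understood by CudaAdaptivePool. Types the plugin does not
// support (already rejected by supportsFormatCombination()) map to an empty
// string so the caller can bail out early.
static std::string ToDtypeString(nvinfer1::DataType dtype) {
  switch (dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "float";
    case nvinfer1::DataType::kHALF:
      return "half";
    default:
      return "";
  }
}
```

With such a helper, `enqueue()` could make a single call passing `ToDtypeString(inputDesc[0].type)` and `ToDtypeString(outputDesc[0].type)` and return early if either string is empty; with only two supported types, the explicit branches in the patch are equally reasonable.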