Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,21 @@ class HardSwishOpConverter : public OpConverter {
nvinfer1::ElementWiseOperation::kPROD);
layer = eltwise_layer;
} else {
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::HardSwishPluginDynamic* plugin =
new plugin::HardSwishPluginDynamic(threshold, scale, offset);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
Expand Down
83 changes: 74 additions & 9 deletions paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ namespace tensorrt {
namespace plugin {

nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) TRT_NOEXCEPT {
int index, const nvinfer1::Dims *in_dims, int nb_inputs) TRT_NOEXCEPT {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang format 自动改的,下同

assert(nb_inputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const& input_dims = in_dims[0];
nvinfer1::Dims const &input_dims = in_dims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
Expand All @@ -42,22 +42,22 @@ __device__ T kMin(T a, T b) {

template <typename T, unsigned TPB>
__global__ void hard_swish_kernel(float threshold, float scale, float offset,
int n, const T* input, T* output) {
int n, const T *input, T *output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < n) {
const T in = input[idx];
output[idx] = in / scale * kMin<T>(kMax<T>(in + offset, 0), threshold);
}
}

int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
int HardSwishPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
void** outputs, void*, cudaStream_t stream) {
void **outputs, void *, cudaStream_t stream) {
#else
void* const* outputs, void*,
void *const *outputs, void *,
cudaStream_t stream) TRT_NOEXCEPT {
#endif
const auto& input_dims = this->getInputDims(0);
const auto &input_dims = this->getInputDims(0);
int num = batch_size;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
Expand All @@ -69,14 +69,79 @@ int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;

const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
threshold, scale, offset, num, input, output);

return cudaGetLastError() != cudaSuccess;
}

#if IS_TRT_VERSION_GE(6000)

nvinfer1::DimsExprs HardSwishPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
return inputs[0];
}

int HardSwishPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
int num = 1;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
float threshold = threshold_;
float scale = scale_;
float offset = offset_;
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
threshold, scale, offset, num, input, output);

return cudaGetLastError() != cudaSuccess;
}

nvinfer1::DataType HardSwishPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index, 0,
platform::errors::InvalidArgument(
"The Elementwise Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}

bool HardSwishPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));

PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
(in_out && pos < (nb_inputs + nb_outputs));

const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
Expand Down
107 changes: 107 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,113 @@ class HardSwishPluginCreator : public TensorRTPluginCreator {
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator);

#if IS_TRT_VERSION_GE(6000)
class HardSwishPluginDynamic : public DynamicPluginTensorRT {
public:
HardSwishPluginDynamic(const float threshold, const float scale,
const float offset)
: threshold_(threshold), scale_(scale), offset_(offset) {}

// It was used for tensorrt deserialization.
// It should not be called by users.
HardSwishPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &threshold_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &offset_);
}
~HardSwishPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new HardSwishPluginDynamic(threshold_, scale_, offset_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "hard_swish_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;

size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(threshold_) + SerializedSize(scale_) +
SerializedSize(offset_);
}

// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, offset_);
}
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;

void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override {}
void destroy() TRT_NOEXCEPT override { delete this; }

protected:
float threshold_;
float scale_;
float offset_;
};

class HardSwishPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
HardSwishPluginDynamicCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "hardswish_plugin_dynamic";
}

const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}

nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override {
return nullptr;
}

nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
auto plugin = new HardSwishPluginDynamic(serial_data, serial_length);
return plugin;
}

void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}

const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}

private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginDynamicCreator);

#endif

} // namespace plugin
} // namespace tensorrt
} // namespace inference
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest


class TrtConvertHardSwishTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]

if attrs[0]['threshold'] <= 0 or attrs[0]['scale'] <= 0:
return False

return True

def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 3, 64, 64]).astype(np.float32)

for threshold in [6.0, 7.0, 100.0, 0.0, -1.0]:
for scale in [5.0, 6.0, 7.0, -1.0, 0.0, 100.0]:
for offset in [3.0, 4.0, 5.0, -1.0, 0.0, 100.0]:
dics = [{
"threshold": threshold,
"scale": scale,
"offset": offset
}]

ops_config = [{
"op_type": "hard_swish",
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["hard_swish_output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)

program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(
generate_input1, dics))
},
outputs=["hard_swish_output_data"])

yield program_config

def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2

attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]

# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), (1e-5, 1e-5)

# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5)

def test(self):
self.run_test()


if __name__ == "__main__":
unittest.main()