Skip to content

Commit c59c8e4

Browse files
authored
[inference]add hard_swish dynamic plugin (#35214)
1 parent d43f797 commit c59c8e4

File tree

4 files changed

+313
-12
lines changed

4 files changed

+313
-12
lines changed

paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,21 @@ class HardSwishOpConverter : public OpConverter {
6464
nvinfer1::ElementWiseOperation::kPROD);
6565
layer = eltwise_layer;
6666
} else {
67-
plugin::HardSwishPlugin* plugin =
68-
new plugin::HardSwishPlugin(threshold, scale, offset);
69-
layer = engine_->AddPlugin(&input, input_num, plugin);
67+
if (engine_->with_dynamic_shape()) {
68+
#if IS_TRT_VERSION_GE(6000)
69+
plugin::HardSwishPluginDynamic* plugin =
70+
new plugin::HardSwishPluginDynamic(threshold, scale, offset);
71+
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
72+
#else
73+
PADDLE_THROW(platform::errors::Fatal(
74+
"You are running the TRT Dynamic Shape mode, need to confirm that "
75+
"your TRT version is no less than 6.0"));
76+
#endif
77+
} else {
78+
plugin::HardSwishPlugin* plugin =
79+
new plugin::HardSwishPlugin(threshold, scale, offset);
80+
layer = engine_->AddPlugin(&input, input_num, plugin);
81+
}
7082
}
7183
auto output_name = op_desc.Output("Out")[0];
7284
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);

paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace tensorrt {
2222
namespace plugin {
2323

2424
// Reports the shape of the plugin output, which is the shape of its single
// input: the returned Dims is a copy of in_dims[0].
//
// index     - output index being queried; asserted to be below getNbOutputs().
// in_dims   - array of input shapes; only element 0 is read.
// nb_inputs - number of inputs; asserted to be exactly 1.
nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *in_dims, int nb_inputs) TRT_NOEXCEPT {
  assert(nb_inputs == 1);
  assert(index < this->getNbOutputs());
  // Returning by value copies the first input's dimensions unchanged.
  return in_dims[0];
}
@@ -42,22 +42,22 @@ __device__ T kMin(T a, T b) {
4242

4343
template <typename T, unsigned TPB>
4444
__global__ void hard_swish_kernel(float threshold, float scale, float offset,
45-
int n, const T* input, T* output) {
45+
int n, const T *input, T *output) {
4646
const int idx = blockIdx.x * TPB + threadIdx.x;
4747
if (idx < n) {
4848
const T in = input[idx];
4949
output[idx] = in / scale * kMin<T>(kMax<T>(in + offset, 0), threshold);
5050
}
5151
}
5252

53-
int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
53+
int HardSwishPlugin::enqueue(int batch_size, const void *const *inputs,
5454
#if IS_TRT_VERSION_LT(8000)
55-
void** outputs, void*, cudaStream_t stream) {
55+
void **outputs, void *, cudaStream_t stream) {
5656
#else
57-
void* const* outputs, void*,
57+
void *const *outputs, void *,
5858
cudaStream_t stream) TRT_NOEXCEPT {
5959
#endif
60-
const auto& input_dims = this->getInputDims(0);
60+
const auto &input_dims = this->getInputDims(0);
6161
int num = batch_size;
6262
for (int i = 0; i < input_dims.nbDims; i++) {
6363
num *= input_dims.d[i];
@@ -69,14 +69,79 @@ int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
6969
const int block_size = 256;
7070
const int grid_size = (num + block_size - 1) / block_size;
7171

72-
const float* input = static_cast<const float*>(inputs[0]);
73-
float* output = static_cast<float*>(outputs[0]);
72+
const float *input = static_cast<const float *>(inputs[0]);
73+
float *output = static_cast<float *>(outputs[0]);
7474
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
7575
threshold, scale, offset, num, input, output);
7676

7777
return cudaGetLastError() != cudaSuccess;
7878
}
7979

80+
#if IS_TRT_VERSION_GE(6000)
81+
82+
// Returns the symbolic output shape, which is the shape expression of the
// first input passed through unchanged. output_index, nb_inputs and
// expr_builder are not consulted.
nvinfer1::DimsExprs HardSwishPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs &input_shape = *inputs;
  return input_shape;
}
87+
88+
int HardSwishPluginDynamic::enqueue(
89+
const nvinfer1::PluginTensorDesc *input_desc,
90+
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
91+
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
92+
auto input_dims = input_desc[0].dims;
93+
int num = 1;
94+
for (int i = 0; i < input_dims.nbDims; i++) {
95+
num *= input_dims.d[i];
96+
}
97+
float threshold = threshold_;
98+
float scale = scale_;
99+
float offset = offset_;
100+
const int block_size = 256;
101+
const int grid_size = (num + block_size - 1) / block_size;
102+
const float *input = static_cast<const float *>(inputs[0]);
103+
float *output = static_cast<float *>(outputs[0]);
104+
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
105+
threshold, scale, offset, num, input, output);
106+
107+
return cudaGetLastError() != cudaSuccess;
108+
}
109+
110+
// Returns the data type of the plugin output, which is the type of the first
// input.
//
// index       - output index being queried; enforced to be 0.
// input_types - array of input data types; only element 0 is read.
// nb_inputs   - unused.
nvinfer1::DataType HardSwishPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  // `index` selects an output, so the message names the output of this
  // plugin rather than an input of the elementwise plugin.
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The HardSwish Plugin only has one output, so the "
                        "index value should be 0, but got %d.",
                        index));
  return input_types[0];
}
120+
121+
bool HardSwishPluginDynamic::supportsFormatCombination(
122+
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
123+
int nb_outputs) TRT_NOEXCEPT {
124+
PADDLE_ENFORCE_NOT_NULL(
125+
in_out, platform::errors::InvalidArgument(
126+
"The input of swish plugin shoule not be nullptr."));
127+
128+
PADDLE_ENFORCE_LT(
129+
pos, nb_inputs + nb_outputs,
130+
platform::errors::InvalidArgument("The pos(%d) should be less than the "
131+
"num(%d) of the input and the output.",
132+
pos, nb_inputs + nb_outputs));
133+
(in_out && pos < (nb_inputs + nb_outputs));
134+
135+
const nvinfer1::PluginTensorDesc &in = in_out[pos];
136+
if (pos == 0) {
137+
return (in.type == nvinfer1::DataType::kFLOAT) &&
138+
(in.format == nvinfer1::TensorFormat::kLINEAR);
139+
}
140+
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
141+
// output
142+
return in.type == prev.type && in.format == prev.format;
143+
}
144+
#endif
80145
} // namespace plugin
81146
} // namespace tensorrt
82147
} // namespace inference

paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,113 @@ class HardSwishPluginCreator : public TensorRTPluginCreator {
9494
};
9595
REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator);
9696

97+
#if IS_TRT_VERSION_GE(6000)
98+
class HardSwishPluginDynamic : public DynamicPluginTensorRT {
99+
public:
100+
HardSwishPluginDynamic(const float threshold, const float scale,
101+
const float offset)
102+
: threshold_(threshold), scale_(scale), offset_(offset) {}
103+
104+
// It was used for tensorrt deserialization.
105+
// It should not be called by users.
106+
HardSwishPluginDynamic(void const* serialData, size_t serialLength) {
107+
DeserializeValue(&serialData, &serialLength, &threshold_);
108+
DeserializeValue(&serialData, &serialLength, &scale_);
109+
DeserializeValue(&serialData, &serialLength, &offset_);
110+
}
111+
~HardSwishPluginDynamic() {}
112+
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
113+
return new HardSwishPluginDynamic(threshold_, scale_, offset_);
114+
}
115+
const char* getPluginType() const TRT_NOEXCEPT override {
116+
return "hard_swish_plugin_dynamic";
117+
}
118+
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
119+
int initialize() TRT_NOEXCEPT override { return 0; }
120+
nvinfer1::DimsExprs getOutputDimensions(
121+
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
122+
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
123+
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
124+
const nvinfer1::PluginTensorDesc* outputDesc,
125+
const void* const* inputs, void* const* outputs, void* workspace,
126+
cudaStream_t stream) TRT_NOEXCEPT override;
127+
128+
size_t getSerializationSize() const TRT_NOEXCEPT override {
129+
return SerializedSize(threshold_) + SerializedSize(scale_) +
130+
SerializedSize(offset_);
131+
}
132+
133+
// TRT will call this func to serialize the configuration of TRT
134+
// It should not be called by users.
135+
void serialize(void* buffer) const TRT_NOEXCEPT override {
136+
SerializeValue(&buffer, threshold_);
137+
SerializeValue(&buffer, scale_);
138+
SerializeValue(&buffer, offset_);
139+
}
140+
nvinfer1::DataType getOutputDataType(
141+
int index, const nvinfer1::DataType* inputTypes,
142+
int nbInputs) const TRT_NOEXCEPT override;
143+
bool supportsFormatCombination(int pos,
144+
const nvinfer1::PluginTensorDesc* inOut,
145+
int nbInputs,
146+
int nbOutputs) TRT_NOEXCEPT override;
147+
148+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
149+
int nbInputs,
150+
const nvinfer1::DynamicPluginTensorDesc* out,
151+
int nbOutputs) TRT_NOEXCEPT override {}
152+
void destroy() TRT_NOEXCEPT override { delete this; }
153+
154+
protected:
155+
float threshold_;
156+
float scale_;
157+
float offset_;
158+
};
159+
160+
class HardSwishPluginDynamicCreator : public nvinfer1::IPluginCreator {
161+
public:
162+
HardSwishPluginDynamicCreator() {}
163+
const char* getPluginName() const TRT_NOEXCEPT override {
164+
return "hardswish_plugin_dynamic";
165+
}
166+
167+
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
168+
169+
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
170+
return &field_collection_;
171+
}
172+
173+
nvinfer1::IPluginV2* createPlugin(const char* name,
174+
const nvinfer1::PluginFieldCollection* fc)
175+
TRT_NOEXCEPT override {
176+
return nullptr;
177+
}
178+
179+
nvinfer1::IPluginV2* deserializePlugin(
180+
const char* name, const void* serial_data,
181+
size_t serial_length) TRT_NOEXCEPT override {
182+
auto plugin = new HardSwishPluginDynamic(serial_data, serial_length);
183+
return plugin;
184+
}
185+
186+
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
187+
plugin_namespace_ = lib_namespace;
188+
}
189+
190+
const char* getPluginNamespace() const TRT_NOEXCEPT override {
191+
return plugin_namespace_.c_str();
192+
}
193+
194+
private:
195+
std::string plugin_namespace_;
196+
std::string plugin_name_;
197+
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
198+
std::vector<nvinfer1::PluginField> plugin_attributes_;
199+
};
200+
REGISTER_TRT_PLUGIN_V2(HardSwishPluginDynamicCreator);
201+
202+
#endif
203+
97204
} // namespace plugin
98205
} // namespace tensorrt
99206
} // namespace inference
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
16+
from program_config import TensorConfig, ProgramConfig
17+
import numpy as np
18+
import paddle.inference as paddle_infer
19+
from functools import partial
20+
from typing import Optional, List, Callable, Dict, Any, Set
21+
import unittest
22+
23+
24+
class TrtConvertHardSwishTest(TrtLayerAutoScanTest):
    """Auto-scan test for the hard_swish op under the TensorRT converter.

    Sweeps threshold/scale/offset combinations and yields predictor configs
    for static and dynamic shape, each in Float32 and Half precision.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return False when the first op's threshold or scale is <= 0."""
        first_op_attrs = program_config.ops[0].attrs
        non_positive = (first_op_attrs['threshold'] <= 0 or
                        first_op_attrs['scale'] <= 0)
        return not non_positive

    def sample_program_configs(self):
        """Yield one single-op hard_swish program per attribute combination."""

        def all_ones_input(attrs: List[Dict[str, Any]]):
            # Fixed all-ones float32 tensor; `attrs` is accepted but unused.
            return np.ones([1, 3, 64, 64]).astype(np.float32)

        thresholds = [6.0, 7.0, 100.0, 0.0, -1.0]
        scales = [5.0, 6.0, 7.0, -1.0, 0.0, 100.0]
        offsets = [3.0, 4.0, 5.0, -1.0, 0.0, 100.0]

        # Same iteration order as three nested loops: threshold outermost,
        # offset innermost.
        combos = ((t, s, o) for t in thresholds for s in scales
                  for o in offsets)

        for threshold, scale, offset in combos:
            dics = [{
                "threshold": threshold,
                "scale": scale,
                "offset": offset
            }]

            hard_swish_op = {
                "op_type": "hard_swish",
                "op_inputs": {
                    "X": ["input_data"]
                },
                "op_outputs": {
                    "Out": ["hard_swish_output_data"]
                },
                "op_attrs": dics[0]
            }
            ops = self.generate_op_config([hard_swish_op])

            input_config = TensorConfig(
                data_gen=partial(all_ones_input, dics))

            yield ProgramConfig(
                ops=ops,
                weights={},
                inputs={"input_data": input_config},
                outputs=["hard_swish_output_data"])

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):
        """Yield (inference config, expected TRT node counts, tolerance).

        Order: static Float32, static Half, dynamic Float32, dynamic Half.
        """

        def set_dynamic_shape(attrs):
            shape_cfg = self.dynamic_shape
            shape_cfg.min_input_shape = {"input_data": [1, 3, 32, 32]}
            shape_cfg.max_input_shape = {"input_data": [4, 3, 64, 64]}
            shape_cfg.opt_input_shape = {"input_data": [1, 3, 64, 64]}

        def clear_dynamic_shape():
            shape_cfg = self.dynamic_shape
            shape_cfg.min_input_shape = {}
            shape_cfg.max_input_shape = {}
            shape_cfg.opt_input_shape = {}

        def expected_trt_nodes(attrs, dynamic_shape):
            # The expectation is the same for every attrs / shape mode.
            return 1, 2

        attrs = [op.attrs for op in program_config.ops]

        def precision_plan():
            # (precision, tolerance) pairs, in the order they are yielded.
            return (
                (paddle_infer.PrecisionType.Float32, 1e-5),
                (paddle_infer.PrecisionType.Half, (1e-5, 1e-5)),
            )

        # for static_shape
        clear_dynamic_shape()
        for precision, tolerance in precision_plan():
            # The precision must be set before the config is created.
            self.trt_param.precision = precision
            yield self.create_inference_config(), expected_trt_nodes(
                attrs, False), tolerance

        # for dynamic_shape
        set_dynamic_shape(attrs)
        for precision, tolerance in precision_plan():
            self.trt_param.precision = precision
            yield self.create_inference_config(), expected_trt_nodes(
                attrs, True), tolerance

    def test(self):
        """Entry point picked up by unittest; runs the auto-scan."""
        self.run_test()
114+
115+
116+
if __name__ == "__main__":
117+
unittest.main()

0 commit comments

Comments
 (0)