Skip to content

Commit 64ee255

Browse files
authored
[Paddle-TRT] yolobox (#31755)
* yolobox converter and plugin * yolobox unittest * add dynamic shape restriction * fix git merge log
1 parent c4b60ef commit 64ee255

File tree

8 files changed

+689
-0
lines changed

8 files changed

+689
-0
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
11921192
USE_TRT_CONVERTER(stack);
11931193
USE_TRT_CONVERTER(clip);
11941194
USE_TRT_CONVERTER(gather);
1195+
USE_TRT_CONVERTER(yolo_box);
11951196
USE_TRT_CONVERTER(roi_align);
11961197
USE_TRT_CONVERTER(affine_channel);
11971198
USE_TRT_CONVERTER(multiclass_nms);

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
66
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
77
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
88
gather_op.cc
9+
yolo_box_op.cc
910
roi_align_op.cc
1011
affine_channel_op.cc
1112
multiclass_nms_op.cc
paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include <vector>
13+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
14+
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
15+
16+
// Forward declarations of the framework types that appear in the converter's
// call signature below.
namespace paddle {
namespace framework {

class Scope;

namespace proto {

class OpDesc;

}  // namespace proto

}  // namespace framework
}  // namespace paddle
24+
25+
namespace paddle {
26+
namespace inference {
27+
namespace tensorrt {
28+
29+
class YoloBoxOpConverter : public OpConverter {
30+
public:
31+
void operator()(const framework::proto::OpDesc& op,
32+
const framework::Scope& scope, bool test_mode) override {
33+
VLOG(3) << "convert a fluid yolo box op to tensorrt plugin";
34+
35+
framework::OpDesc op_desc(op, nullptr);
36+
std::string X = op_desc.Input("X").front();
37+
std::string img_size = op_desc.Input("ImgSize").front();
38+
39+
auto* X_tensor = engine_->GetITensor(X);
40+
auto* img_size_tensor = engine_->GetITensor(img_size);
41+
42+
int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num"));
43+
std::vector<int> anchors =
44+
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors"));
45+
46+
int downsample_ratio =
47+
BOOST_GET_CONST(int, op_desc.GetAttr("downsample_ratio"));
48+
float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh"));
49+
bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox"));
50+
float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y"));
51+
52+
int type_id = static_cast<int>(engine_->WithFp16());
53+
auto input_dim = X_tensor->getDimensions();
54+
auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
55+
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
56+
anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y,
57+
input_dim.d[1], input_dim.d[2]);
58+
59+
std::vector<nvinfer1::ITensor*> yolo_box_inputs;
60+
yolo_box_inputs.push_back(X_tensor);
61+
yolo_box_inputs.push_back(img_size_tensor);
62+
63+
auto* yolo_box_layer = engine_->network()->addPluginV2(
64+
yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin);
65+
66+
std::vector<std::string> output_names;
67+
output_names.push_back(op_desc.Output("Boxes").front());
68+
output_names.push_back(op_desc.Output("Scores").front());
69+
70+
RreplenishLayerAndOutput(yolo_box_layer, "yolo_box", output_names,
71+
test_mode);
72+
}
73+
};
74+
75+
} // namespace tensorrt
76+
} // namespace inference
77+
} // namespace paddle
78+
79+
// Registers YoloBoxOpConverter under the op type name `yolo_box`; the matching
// USE_TRT_CONVERTER(yolo_box) line is added to analysis_predictor.cc.
REGISTER_TRT_OP_CONVERTER(yolo_box, YoloBoxOpConverter);

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
111111
"flatten2",
112112
"flatten",
113113
"gather",
114+
"yolo_box",
114115
"roi_align",
115116
"affine_channel",
116117
"multiclass_nms",
@@ -198,6 +199,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
198199
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
199200
}
200201

202+
if (op_type == "yolo_box") {
203+
if (with_dynamic_shape) return false;
204+
bool has_attrs =
205+
(desc.HasAttr("class_num") && desc.HasAttr("anchors") &&
206+
desc.HasAttr("downsample_ratio") && desc.HasAttr("conf_thresh") &&
207+
desc.HasAttr("clip_bbox") && desc.HasAttr("scale_x_y"));
208+
return has_attrs;
209+
}
210+
201211
if (op_type == "affine_channel") {
202212
if (!desc.HasAttr("data_layout")) return false;
203213
auto data_layout = framework::StringToDataLayout(

paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ nv_library(tensorrt_plugin
55
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
66
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
77
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
8+
yolo_box_op_plugin.cu
89
roi_align_op_plugin.cu
910
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
1011

0 commit comments

Comments
 (0)