|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +you may not use this file except in compliance with the License. |
| 4 | +You may obtain a copy of the License at |
| 5 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +Unless required by applicable law or agreed to in writing, software |
| 7 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +See the License for the specific language governing permissions and |
| 10 | +limitations under the License. */ |
| 11 | + |
| 12 | +#include <vector> |
| 13 | +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" |
| 14 | +#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h" |
| 15 | + |
| 16 | +namespace paddle { |
| 17 | +namespace framework { |
| 18 | +class Scope; |
| 19 | +namespace proto { |
| 20 | +class OpDesc; |
| 21 | +} // namespace proto |
| 22 | +} // namespace framework |
| 23 | +} // namespace paddle |
| 24 | + |
| 25 | +namespace paddle { |
| 26 | +namespace inference { |
| 27 | +namespace tensorrt { |
| 28 | + |
| 29 | +class YoloBoxOpConverter : public OpConverter { |
| 30 | + public: |
| 31 | + void operator()(const framework::proto::OpDesc& op, |
| 32 | + const framework::Scope& scope, bool test_mode) override { |
| 33 | + VLOG(3) << "convert a fluid yolo box op to tensorrt plugin"; |
| 34 | + |
| 35 | + framework::OpDesc op_desc(op, nullptr); |
| 36 | + std::string X = op_desc.Input("X").front(); |
| 37 | + std::string img_size = op_desc.Input("ImgSize").front(); |
| 38 | + |
| 39 | + auto* X_tensor = engine_->GetITensor(X); |
| 40 | + auto* img_size_tensor = engine_->GetITensor(img_size); |
| 41 | + |
| 42 | + int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num")); |
| 43 | + std::vector<int> anchors = |
| 44 | + BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors")); |
| 45 | + |
| 46 | + int downsample_ratio = |
| 47 | + BOOST_GET_CONST(int, op_desc.GetAttr("downsample_ratio")); |
| 48 | + float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh")); |
| 49 | + bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox")); |
| 50 | + float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y")); |
| 51 | + |
| 52 | + int type_id = static_cast<int>(engine_->WithFp16()); |
| 53 | + auto input_dim = X_tensor->getDimensions(); |
| 54 | + auto* yolo_box_plugin = new plugin::YoloBoxPlugin( |
| 55 | + type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, |
| 56 | + anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, |
| 57 | + input_dim.d[1], input_dim.d[2]); |
| 58 | + |
| 59 | + std::vector<nvinfer1::ITensor*> yolo_box_inputs; |
| 60 | + yolo_box_inputs.push_back(X_tensor); |
| 61 | + yolo_box_inputs.push_back(img_size_tensor); |
| 62 | + |
| 63 | + auto* yolo_box_layer = engine_->network()->addPluginV2( |
| 64 | + yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin); |
| 65 | + |
| 66 | + std::vector<std::string> output_names; |
| 67 | + output_names.push_back(op_desc.Output("Boxes").front()); |
| 68 | + output_names.push_back(op_desc.Output("Scores").front()); |
| 69 | + |
| 70 | + RreplenishLayerAndOutput(yolo_box_layer, "yolo_box", output_names, |
| 71 | + test_mode); |
| 72 | + } |
| 73 | +}; |
| 74 | + |
| 75 | +} // namespace tensorrt |
| 76 | +} // namespace inference |
| 77 | +} // namespace paddle |
| 78 | + |
| 79 | +REGISTER_TRT_OP_CONVERTER(yolo_box, YoloBoxOpConverter); |
0 commit comments