Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(nearest_interp);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
roi_align_op.cc
affine_channel_op.cc
multiclass_nms_op.cc
nearest_interp_op.cc
Expand Down
86 changes: 86 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/roi_align_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
* Roi Align Op
*/
class RoiAlignOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid roi align op to tensorrt plugin";

framework::OpDesc op_desc(op, nullptr);
std::string input_name = op_desc.Input("X").front();
std::string rois_name = op_desc.Input("ROIs").front();
std::string output_name = op_desc.Output("Out").front();

const auto pooled_height =
BOOST_GET_CONST(int, op_desc.GetAttr("pooled_height"));
const auto pooled_width =
BOOST_GET_CONST(int, op_desc.GetAttr("pooled_width"));
const auto spatial_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale"));
const auto sampling_ratio =
BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio"));

const auto input_tensor = engine_->GetITensor(input_name);
const auto rois_tensor = engine_->GetITensor(rois_name);

const nvinfer1::DataType data_type_ = engine_->WithFp16()
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT;

std::vector<nvinfer1::ITensor*> inputs{input_tensor, rois_tensor};
nvinfer1::ILayer* layer = nullptr;

PADDLE_ENFORCE_EQ(
engine_->with_dynamic_shape(), true,
platform::errors::InvalidArgument(
"TRT roi align plugin only accept the dynamic shape, because that "
"the roi_align will change the batch size."));

auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic(
data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio);
auto roi_align_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *roi_align_plugin);
layer = roi_align_layer;

std::vector<std::string> output_names{output_name};
RreplenishLayerAndOutput(layer, "roi_align", output_names, test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(roi_align, RoiAlignOpConverter);
24 changes: 24 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"flatten2",
"flatten",
"gather",
"roi_align",
"affine_channel",
"multiclass_nms",
"nearest_interp",
Expand Down Expand Up @@ -263,6 +264,29 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
if (interp_method != "nearest") return false;
}

if (op_type == "roi_align") {
if (!with_dynamic_shape) return false;

std::vector<std::string> attrs{"pooled_height", "pooled_width",
"spatial_scale", "sampling_ratio"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}

const auto pooled_height =
BOOST_GET_CONST(int, desc.GetAttr("pooled_height"));
if (pooled_height <= 0) return false;

const auto pooled_width =
BOOST_GET_CONST(int, desc.GetAttr("pooled_width"));
if (pooled_width <= 0) return false;

const auto spatial_scale =
BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
if (spatial_scale <= 0.f) return false;
}

if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
return false;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ nv_library(tensorrt_plugin
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
roi_align_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)

nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
Expand Down
Loading