Skip to content

Commit 6385f5e

Browse files
authored
[Paddle-TRT] Add gather_nd and reduce_sum trt op. (#33324) (#33365)
1 parent 28a18af commit 6385f5e

File tree

12 files changed

+933
-26
lines changed

12 files changed

+933
-26
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,8 @@ USE_TRT_CONVERTER(roi_align);
12341234
USE_TRT_CONVERTER(affine_channel);
12351235
USE_TRT_CONVERTER(multiclass_nms);
12361236
USE_TRT_CONVERTER(nearest_interp);
1237+
USE_TRT_CONVERTER(reduce_sum);
1238+
USE_TRT_CONVERTER(gather_nd);
12371239
USE_TRT_CONVERTER(reshape);
12381240
#endif
12391241

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ nv_library(tensorrt_converter
1212
affine_channel_op.cc
1313
multiclass_nms_op.cc
1414
nearest_interp_op.cc
15+
reduce_op.cc
16+
gather_nd_op.cc
1517
reshape_op.cc
1618
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
1719

paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,19 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
4040
auto word_emb_name = op_desc.Input("WordEmbedding").front();
4141
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
4242
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
43-
std::vector<std::string> id_names = {word_id_name, pos_id_name,
44-
sent_id_name};
45-
std::vector<std::string> emb_names = {word_emb_name, pos_emb_name,
46-
sent_emb_name};
43+
44+
std::vector<std::string> id_names;
45+
std::vector<std::string> emb_names;
46+
47+
if (engine_->use_oss()) {
48+
id_names =
49+
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
50+
emb_names =
51+
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
52+
} else {
53+
id_names = op_desc.Input("Ids");
54+
emb_names = op_desc.Input("Embs");
55+
}
4756

4857
int input_num = id_names.size();
4958

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
16+
#include "paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace tensorrt {
21+
22+
class GatherNdOpConverter : public OpConverter {
23+
public:
24+
void operator()(const framework::proto::OpDesc& op,
25+
const framework::Scope& scope, bool test_mode) override {
26+
VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin";
27+
framework::OpDesc op_desc(op, nullptr);
28+
29+
// Declare inputs
30+
std::vector<nvinfer1::ITensor*> inputs;
31+
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
32+
auto* index = engine_->GetITensor(op_desc.Input("Index")[0]);
33+
inputs.emplace_back(input);
34+
inputs.emplace_back(index);
35+
36+
nvinfer1::ILayer* layer = nullptr;
37+
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
38+
plugin::GatherNdPluginDynamic* plugin =
39+
new plugin::GatherNdPluginDynamic(with_fp16);
40+
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
41+
42+
std::string layer_name = "gather_nd (Output: ";
43+
auto output_name = op_desc.Output("Out")[0];
44+
layer->getOutput(0)->setName(output_name.c_str());
45+
engine_->SetITensor(output_name, layer->getOutput(0));
46+
layer_name += output_name;
47+
if (test_mode) {
48+
engine_->DeclareOutput(output_name);
49+
}
50+
layer->setName((layer_name + ")").c_str());
51+
}
52+
};
53+
54+
} // namespace tensorrt
55+
} // namespace inference
56+
} // namespace paddle
57+
58+
REGISTER_TRT_OP_CONVERTER(gather_nd, GatherNdOpConverter);
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <NvInfer.h>
16+
#include <sys/types.h>
17+
18+
#include <cstddef>
19+
#include <cstdint>
20+
#include <vector>
21+
22+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
23+
24+
namespace paddle {
25+
namespace framework {
26+
class Scope;
27+
28+
namespace proto {
29+
class OpDesc;
30+
} // namespace proto
31+
} // namespace framework
32+
} // namespace paddle
33+
34+
namespace paddle {
35+
namespace inference {
36+
namespace tensorrt {
37+
38+
class ReduceSumOpConverter : public OpConverter {
39+
public:
40+
void operator()(const framework::proto::OpDesc& op,
41+
const framework::Scope& scope, bool test_mode) override {
42+
VLOG(4) << "convert a paddle reduce_sum op to tensorrt reduce layer";
43+
framework::OpDesc op_desc(op, nullptr);
44+
45+
auto* x = engine_->GetITensor(op_desc.Input("X").front());
46+
nvinfer1::Dims input_shape = x->getDimensions();
47+
int input_dims = input_shape.nbDims;
48+
49+
bool keep_dim = BOOST_GET_CONST(bool, op_desc.GetAttr("keep_dim"));
50+
std::vector<int32_t> dim =
51+
BOOST_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("dim"));
52+
bool reduce_all = BOOST_GET_CONST(bool, op_desc.GetAttr("reduce_all"));
53+
54+
// Now we only support dynamic_shape mode.
55+
nvinfer1::IReduceLayer* layer = nullptr;
56+
if (reduce_all) {
57+
uint32_t reduce_dim = 0;
58+
for (int i = 0; i < input_dims; ++i) {
59+
reduce_dim |= 1 << i;
60+
}
61+
layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *x,
62+
nvinfer1::ReduceOperation::kSUM, reduce_dim,
63+
keep_dim);
64+
} else {
65+
auto CvtToBitMask = [&](const std::vector<int32_t>& dims) -> uint32_t {
66+
uint32_t res = 0;
67+
for (auto x : dims) {
68+
if (x < 0) {
69+
res |= 1 << (x + input_dims);
70+
} else {
71+
res |= 1 << x;
72+
}
73+
}
74+
return res;
75+
};
76+
layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *x,
77+
nvinfer1::ReduceOperation::kSUM,
78+
CvtToBitMask(dim), keep_dim);
79+
}
80+
81+
auto output_name = op_desc.Output("Out")[0];
82+
RreplenishLayerAndOutput(layer, "reduce_sum", {output_name}, test_mode);
83+
}
84+
};
85+
86+
} // namespace tensorrt
87+
} // namespace inference
88+
} // namespace paddle
89+
90+
REGISTER_TRT_OP_CONVERTER(reduce_sum, ReduceSumOpConverter);

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/inference/tensorrt/op_teller.h"
16+
1617
#include "paddle/fluid/framework/block_desc.h"
1718
#include "paddle/fluid/framework/data_layout.h"
1819

@@ -122,11 +123,13 @@ struct SimpleOpTypeSetTeller : public Teller {
122123
"flatten2",
123124
"flatten",
124125
"gather",
126+
"gather_nd",
125127
"yolo_box",
126128
"roi_align",
127129
"affine_channel",
128130
"nearest_interp",
129131
"anchor_generator",
132+
"reduce_sum",
130133
};
131134
};
132135

@@ -324,6 +327,30 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
324327
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
325328
}
326329

330+
if (op_type == "gather_nd") {
331+
auto* block = desc.Block();
332+
auto x_var_name = desc.Input("X")[0];
333+
auto index_var_name = desc.Input("Index")[0];
334+
auto* x_var_desc = block->FindVar(x_var_name);
335+
auto* index_var_desc = block->FindVar(index_var_name);
336+
337+
// The index input must be int32 datatype.
338+
if (index_var_desc->GetDataType() !=
339+
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
340+
VLOG(3) << "gather_nd op Index input data type must be int32";
341+
return false;
342+
}
343+
344+
const auto index_shape = index_var_desc->GetShape();
345+
const auto x_shape = x_var_desc->GetShape();
346+
if (x_shape.size() != index_shape.size()) {
347+
VLOG(3) << "gather_nd op Index input dims size [" << index_shape.size()
348+
<< " ] not equal to x dims size [" << x_shape.size() << "]";
349+
return false;
350+
}
351+
if (!with_dynamic_shape) return false;
352+
}
353+
327354
if (op_type == "yolo_box") {
328355
if (with_dynamic_shape) return false;
329356
bool has_attrs =
@@ -658,6 +685,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
658685
}
659686
}
660687

688+
if (op_type == "reduce_sum") {
689+
if (!with_dynamic_shape) {
690+
VLOG(3) << "the reduce_sum does not support static shape yet";
691+
return false;
692+
}
693+
694+
if (!(desc.HasAttr("keep_dim") && desc.HasAttr("dim") &&
695+
desc.HasAttr("reduce_all"))) {
696+
VLOG(3) << "the reduce_sum does not have attr (keep_dim or dim or "
697+
"reduce_all)";
698+
return false;
699+
}
700+
}
701+
661702
if (op_type == "reshape" || op_type == "reshape2") {
662703
if (!desc.HasAttr("shape") || with_dynamic_shape) {
663704
return false;

paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ nv_library(tensorrt_plugin
88
anchor_generator_op_plugin.cu
99
yolo_box_op_plugin.cu
1010
roi_align_op_plugin.cu
11+
gather_nd_op_plugin.cu
1112
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
1213

1314
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS

0 commit comments

Comments
 (0)