
Commit 6535661: resolve conflict

1 parent 88cd27a

File tree: 4 files changed (+105, -2 lines)

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 1 deletion
@@ -1192,7 +1192,7 @@ USE_TRT_CONVERTER(scale);
 USE_TRT_CONVERTER(stack);
 USE_TRT_CONVERTER(clip);
 USE_TRT_CONVERTER(gather);
-
+USE_TRT_CONVERTER(affine_channel);
 USE_TRT_CONVERTER(multiclass_nms);
 
 USE_TRT_CONVERTER(nearest_interp);
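
The USE_TRT_CONVERTER(...) lines above are the declaration half of Paddle's converter registration: REGISTER_TRT_OP_CONVERTER at the bottom of each converter's .cc file defines a touch symbol, and referencing that symbol here keeps the converter's object file from being discarded by the linker in static builds. A minimal sketch of this common pattern; the names below are simplified stand-ins, not Paddle's actual macro bodies:

// Hypothetical, simplified sketch of extern-symbol registration.
// The real REGISTER_TRT_OP_CONVERTER additionally constructs the
// converter object and inserts it into a registry keyed by op type.
#define REGISTER_CONVERTER_SKETCH(op_type) \
  int TouchConverter_##op_type() { return 0; }

#define USE_CONVERTER_SKETCH(op_type)      \
  extern int TouchConverter_##op_type();   \
  static int use_converter_##op_type = TouchConverter_##op_type();

REGISTER_CONVERTER_SKETCH(affine_channel) in one translation unit plus USE_CONVERTER_SKETCH(affine_channel) in another forces the linker to pull in the defining object file, so the registry is populated before any lookup.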

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
   shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
   emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
   gather_op.cc
+  affine_channel_op.cc
   multiclass_nms_op.cc
   nearest_interp_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc (new file)

Lines changed: 94 additions & 0 deletions
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Affine Channel Op
 */
class AffineChannelOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "convert a fluid affine_channel op to tensorrt scale nd layer";

    framework::OpDesc op_desc(op, nullptr);
    std::string input_name = op_desc.Input("X").front();
    std::string scale_name = op_desc.Input("Scale").front();
    std::string bias_name = op_desc.Input("Bias").front();
    std::string output_name = op_desc.Output("Out").front();

    auto input_tensor = engine_->GetITensor(input_name);
    auto idim = input_tensor->getDimensions();

    // Scale and Bias are persistable weights; pull them out of the scope
    // as CPU data so they can be wrapped as TensorRT weights.
    auto* scale_v = scope.FindVar(scale_name);
    auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
    float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);

    auto* bias_v = scope.FindVar(bias_name);
    auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
    float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);

    auto data_layout = framework::StringToDataLayout(
        BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));

    PADDLE_ENFORCE_EQ(
        data_layout, framework::DataLayout::kNCHW,
        platform::errors::InvalidArgument(
            "TensorRT affine channel converter can only convert NCHW format. "
            "Other format should be run in fluid mode. Report a bug on github "
            "issue if you see this line."));

    // The TensorRT ScaleNd layer only supports spatial dims >= 2, so NHWC
    // is not available (its spatial dims would be 0). With an implicit
    // batch dim (static shape) the channel axis is 0; with dynamic shape
    // the batch dim is explicit, so the channel axis is 1.
    const int channel_axis = engine_->with_dynamic_shape();

    TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
                                         static_cast<void*>(scale_ptr),
                                         (size_t)idim.d[channel_axis]};
    TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
                                        static_cast<void*>(bias_ptr),
                                        (size_t)idim.d[channel_axis]};
    TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
                                         0};

    // kCHANNEL mode: out = scale[c] * in + bias[c], applied per channel.
    auto layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *input_tensor,
                                      nvinfer1::ScaleMode::kCHANNEL,
                                      bias_weights.get(), scale_weights.get(),
                                      power_weights.get(), channel_axis);

    RreplenishLayerAndOutput(layer, "affine_channel", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(affine_channel, AffineChannelOpConverter);
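
As background, affine_channel applies an independent affine transform to each channel, which is exactly what ScaleNd in kCHANNEL mode computes (the empty power weights leave the exponent at 1). Below is a self-contained reference of the op's semantics on a flat NCHW buffer; an illustrative sketch, not Paddle code:

#include <cstddef>
#include <vector>

// Reference semantics of affine_channel on an NCHW tensor:
//   out[n][c][h][w] = scale[c] * x[n][c][h][w] + bias[c]
// This mirrors what the converter asks TensorRT's ScaleNd layer to do.
std::vector<float> AffineChannelRef(const std::vector<float>& x,
                                    const std::vector<float>& scale,
                                    const std::vector<float>& bias,
                                    std::size_t n, std::size_t c,
                                    std::size_t h, std::size_t w) {
  std::vector<float> out(x.size());
  const std::size_t hw = h * w;
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t ci = 0; ci < c; ++ci) {
      const std::size_t base = (ni * c + ci) * hw;  // start of channel ci
      for (std::size_t i = 0; i < hw; ++i) {
        out[base + i] = scale[ci] * x[base + i] + bias[ci];
      }
    }
  }
  return out;
}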

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 9 additions & 1 deletion
@@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "flatten2",
       "flatten",
       "gather",
+      "affine_channel",
       "multiclass_nms",
       "nearest_interp",
   };
@@ -196,6 +197,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
     }
 
+    if (op_type == "affine_channel") {
+      if (!desc.HasAttr("data_layout")) return false;
+      auto data_layout = framework::StringToDataLayout(
+          BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
+      if (data_layout != framework::DataLayout::kNCHW) return false;
+    }
+
     if (op_type == "multiclass_nms") {
       if (with_dynamic_shape) return false;
       auto* block = desc.Block();
@@ -238,6 +246,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
     }
+
     if (op_type == "nearest_interp") {
       std::vector<std::string> attrs{"data_layout", "interp_method",
                                      "align_corners", "scale",
@@ -254,7 +263,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
           BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
       if (interp_method != "nearest") return false;
     }
-
     if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
   return false;
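Taken together, the op_teller.cc change means affine_channel is only handed to TensorRT when its data_layout attribute is present and equals NCHW; any other layout falls back to Fluid execution, matching the PADDLE_ENFORCE_EQ guard in the converter. A simplified sketch of this two-stage gating (per-op attribute checks plus an op allow-list); the names are illustrative, not the actual Paddle classes:

#include <string>
#include <unordered_set>

// Hypothetical, simplified version of the gating in OpTeller::Tell:
// an op is offloaded to TensorRT only if it passes its per-op checks
// and its type appears in the converter allow-list.
bool CanRunOnTrt(const std::string& op_type, bool has_data_layout,
                 const std::string& data_layout) {
  static const std::unordered_set<std::string> allow_list{
      "gather", "affine_channel", "multiclass_nms", "nearest_interp"};

  // Mirror of the new affine_channel rule: only NCHW maps onto
  // TensorRT's ScaleNd layer in kCHANNEL mode.
  if (op_type == "affine_channel" &&
      (!has_data_layout || data_layout != "NCHW")) {
    return false;
  }
  return allow_list.count(op_type) > 0;
}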