Skip to content

Commit bfced39

Browse files
authored
[Paddle-TRT] nearest_interp op (#31626)
* nearest_interp op converter w/ dynamic/static * fix data_layout include * add trt nearest unit_test * add nearest_interp NHWC test * update trt nearest interp nhwc testcase * remove asterisk for python2 compatibility * add empty line to prevent conflict * nearest_interp op converter w/ dynamic/static * fix data_layout include * add trt nearest unit_test * add nearest_interp NHWC test * update trt nearest interp nhwc testcase * remove asterisk for python2 compatibility * add empty line to prevent conflict * change the priority of out_h, out_w
1 parent 7ccf6b6 commit bfced39

File tree

5 files changed

+332
-0
lines changed

5 files changed

+332
-0
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,8 @@ USE_TRT_CONVERTER(scale);
11921192
USE_TRT_CONVERTER(stack);
11931193
USE_TRT_CONVERTER(clip);
11941194
USE_TRT_CONVERTER(gather);
1195+
1196+
USE_TRT_CONVERTER(nearest_interp);
11951197
#endif
11961198

11971199
namespace paddle_infer {

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ nv_library(tensorrt_converter
66
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
77
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
88
gather_op.cc
9+
10+
nearest_interp_op.cc
911
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
1012

1113
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/framework/data_layout.h"
13+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
14+
15+
namespace paddle {
16+
namespace framework {
17+
class Scope;
18+
namespace proto {
19+
class OpDesc;
20+
} // namespace proto
21+
} // namespace framework
22+
} // namespace paddle
23+
24+
namespace paddle {
25+
namespace inference {
26+
namespace tensorrt {
27+
28+
class NearestInterpolateOpConverter : public OpConverter {
29+
public:
30+
void operator()(const framework::proto::OpDesc& op,
31+
const framework::Scope& scope, bool test_mode) override {
32+
VLOG(3) << "convert a fluid nearest_interp op";
33+
34+
framework::OpDesc op_desc(op, nullptr);
35+
36+
std::string input_name = op_desc.Input("X").front();
37+
std::string output_name = op_desc.Output("Out").front();
38+
39+
auto input = engine_->GetITensor(input_name);
40+
41+
auto data_layout = framework::StringToDataLayout(
42+
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
43+
auto interp_method =
44+
BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method"));
45+
bool align_corners =
46+
BOOST_GET_CONST(bool, op_desc.GetAttr("align_corners"));
47+
48+
auto input_names = op_desc.Input("X");
49+
auto scale = BOOST_GET_CONST(float, op_desc.GetAttr("scale"));
50+
auto out_h = BOOST_GET_CONST(int, op_desc.GetAttr("out_h"));
51+
auto out_w = BOOST_GET_CONST(int, op_desc.GetAttr("out_w"));
52+
53+
auto layer = TRT_ENGINE_ADD_LAYER(engine_, Resize, *input);
54+
layer->setAlignCorners(align_corners);
55+
56+
auto in_dim = input->getDimensions();
57+
58+
float scale_h = 1.f;
59+
float scale_w = 1.f;
60+
61+
std::vector<float> scales;
62+
63+
if (scale > 0.f && (out_h <= 0 && out_w <= 0)) {
64+
scale_h = scale;
65+
scale_w = scale;
66+
} else {
67+
// axis are different in static/dynamic mode
68+
PADDLE_ENFORCE_GT(
69+
out_h, 0, platform::errors::InvalidArgument(
70+
"out_h must be greater than 0 if scale is not set."));
71+
PADDLE_ENFORCE_GT(
72+
out_w, 0, platform::errors::InvalidArgument(
73+
"out_w must be greater than 0 if scale is not set."));
74+
75+
bool with_dynamic = engine_->with_dynamic_shape();
76+
77+
int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic;
78+
int w_axis =
79+
(data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic;
80+
81+
scale_h =
82+
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
83+
scale_w =
84+
static_cast<float>(out_w) / static_cast<float>(in_dim.d[w_axis]);
85+
}
86+
87+
if (engine_->with_dynamic_shape()) {
88+
scales.push_back(1.f);
89+
}
90+
91+
if (data_layout == framework::DataLayout::kNCHW) {
92+
scales.push_back(1.f);
93+
scales.push_back(scale_h);
94+
scales.push_back(scale_w);
95+
} else if (data_layout == framework::DataLayout::kNHWC) {
96+
// NHWC
97+
scales.push_back(scale_h);
98+
scales.push_back(scale_w);
99+
scales.push_back(1.f);
100+
} else {
101+
PADDLE_THROW(platform::errors::InvalidArgument(
102+
"Data layout must be NCHW or NHWC."));
103+
}
104+
layer->setScales(scales.data(), scales.size());
105+
106+
RreplenishLayerAndOutput(layer, "nearest_interp", {output_name}, test_mode);
107+
}
108+
};
109+
110+
} // namespace tensorrt
111+
} // namespace inference
112+
} // namespace paddle
113+
114+
REGISTER_TRT_OP_CONVERTER(nearest_interp, NearestInterpolateOpConverter);

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/inference/tensorrt/op_teller.h"
1616
#include "paddle/fluid/framework/block_desc.h"
17+
#include "paddle/fluid/framework/data_layout.h"
1718

1819
namespace paddle {
1920
namespace framework {
@@ -110,6 +111,8 @@ struct SimpleOpTypeSetTeller : public Teller {
110111
"flatten2",
111112
"flatten",
112113
"gather",
114+
115+
"nearest_interp",
113116
};
114117
};
115118

@@ -187,10 +190,29 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
187190
if (axis != 1) return false;
188191
}
189192
}
193+
190194
if (op_type == "gather") {
191195
// current not support axis from input, use default 0
192196
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
193197
}
198+
199+
if (op_type == "nearest_interp") {
  // Reject the op unless every attribute listed here is present on it.
  const std::vector<std::string> attrs{"data_layout", "interp_method",
                                       "align_corners", "scale",
                                       "out_h", "out_w"};
  // Iterate by const reference so no std::string is copied per attribute.
  for (const auto& attr : attrs) {
    if (!desc.HasAttr(attr)) return false;
  }
  // Only the NCHW and NHWC layouts are accepted.
  auto data_layout = framework::StringToDataLayout(
      BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
  if (data_layout != framework::DataLayout::kNCHW &&
      data_layout != framework::DataLayout::kNHWC)
    return false;
  // Only the "nearest" interpolation method is accepted.
  auto interp_method =
      BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
  if (interp_method != "nearest") return false;
}
215+
194216
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
195217
}
196218
return false;
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from inference_pass_test import InferencePassTest
20+
import paddle.fluid as fluid
21+
import paddle.fluid.core as core
22+
from paddle.fluid.core import PassVersionChecker
23+
from paddle.fluid.core import AnalysisConfig
24+
25+
26+
class TRTNearestInterpTest(InferencePassTest):
    """Base TensorRT inference test for fluid.layers.resize_nearest.

    Builds the program data -> resize_nearest -> batch_norm and feeds it a
    random float32 batch. Subclasses override set_params() to vary the
    layout, the align_corners flag and the scale / out_shape selection.
    """

    def setUp(self):
        self.set_params()

        in_h = self.origin_shape[0]
        in_w = self.origin_shape[1]

        def make_shape(batch):
            # Input shape for the configured layout with the given batch dim.
            if self.data_layout == 'NCHW':
                return [batch, self.channels, in_h, in_w]
            return [batch, in_h, in_w, self.channels]

        with fluid.program_guard(self.main_program, self.startup_program):
            # The batch dimension is left variable (-1) in the program.
            data = fluid.data(
                name='data', shape=make_shape(-1), dtype='float32')
            resize_out = self.append_nearest_interp(data)
            out = fluid.layers.batch_norm(resize_out, is_test=True)

        feed_shape = make_shape(self.bs)
        self.feeds = {
            'data': np.random.random(feed_shape).astype('float32'),
        }
        self.enable_trt = True
        self.trt_parameters = TRTNearestInterpTest.TensorRTParam(
            1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out]

    def set_params(self):
        # Default configuration: NCHW, align_corners on, positive scale.
        self.data_layout = 'NCHW'
        self.align_corners = True
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = 1

    def append_nearest_interp(self, data):
        """Append resize_nearest to `data` and return its output.

        A positive self.scale selects the scale argument; otherwise
        self.resize_shape is passed as out_shape.
        """
        resize_kwargs = {
            'align_corners': self.align_corners,
            'data_format': self.data_layout,
        }
        if self.scale > 0.:
            resize_kwargs['scale'] = self.scale
        else:
            resize_kwargs['out_shape'] = self.resize_shape
        return fluid.layers.resize_nearest(data, **resize_kwargs)

    def test_check_output(self):
        # Nothing is checked when Paddle was built without CUDA.
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_option(True, flatten=True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
90+
91+
92+
class TRTNearestInterpTest1(TRTNearestInterpTest):
    """NCHW, align_corners=True; scale is -1 so resize_shape (64, 64) is used."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = True
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = -1
101+
102+
103+
class TRTNearestInterpTest2(TRTNearestInterpTest):
    """NCHW, align_corners=False; positive scale (2.0) selects the scale path."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = 2.
112+
113+
114+
class TRTNearestInterpTest3(TRTNearestInterpTest):
    """NCHW, align_corners=False; scale is -1 so resize_shape (64, 64) is used."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = -1
123+
124+
125+
class TRTNearestInterpTest4(TRTNearestInterpTest):
    """NCHW, align_corners=False; scale is -1, non-integer ratio target (47, 48)."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.scale = -1
134+
135+
136+
class TRTNearestInterpTest5(TRTNearestInterpTest):
    """NHWC, align_corners=True; scale is -1 so resize_shape (64, 64) is used."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = True
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = -1
145+
146+
147+
class TRTNearestInterpTest6(TRTNearestInterpTest):
    """NHWC, align_corners=False; positive scale (2.0) selects the scale path."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = 2.
156+
157+
158+
class TRTNearestInterpTest7(TRTNearestInterpTest):
    """NHWC, align_corners=False; scale is -1 so resize_shape (64, 64) is used."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.scale = -1
167+
168+
169+
class TRTNearestInterpTest8(TRTNearestInterpTest):
    """NHWC, align_corners=False; scale is -1, non-integer ratio target (47, 48)."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.scale = -1
178+
179+
180+
class TRTNearestInterpTest9(TRTNearestInterpTest):
    """NHWC, align_corners=False; scale is -1, non-integer ratio target (47, 48).

    NOTE(review): these parameters are identical to TRTNearestInterpTest8 --
    confirm whether a different configuration was intended here.
    """

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        self.channels = 3
        self.bs = 4
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.scale = -1
189+
190+
191+
# Run every test case in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)