Skip to content

Commit a6312b2

Browse files
xingjing1AnnaTrainingG
authored andcommitted
1 parent d215d2d commit a6312b2

File tree

3 files changed

+158
-10
lines changed

3 files changed

+158
-10
lines changed

paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,24 @@ class NearestInterpolateOpConverter : public OpConverter {
5959
float scale_w = 1.f;
6060

6161
std::vector<float> scales;
62-
63-
if (scale > 0.f && (out_h <= 0 && out_w <= 0)) {
62+
if (scale > 0.f) {
6463
scale_h = scale;
6564
scale_w = scale;
6665
} else {
6766
// axis are different in static/dynamic mode
6867
bool with_dynamic = engine_->with_dynamic_shape();
6968

70-
int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic;
71-
int w_axis =
72-
(data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic;
73-
74-
scale_h =
75-
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
76-
scale_w =
77-
static_cast<float>(out_w) / static_cast<float>(in_dim.d[w_axis]);
69+
if (!with_dynamic) {
70+
int h_axis =
71+
(data_layout == framework::DataLayout::kNCHW) + with_dynamic;
72+
int w_axis =
73+
(data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic;
74+
75+
scale_h =
76+
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
77+
scale_w =
78+
static_cast<float>(out_w) / static_cast<float>(in_dim.d[w_axis]);
79+
}
7880
}
7981

8082
if (engine_->with_dynamic_shape()) {

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
476476
return false;
477477
}
478478
}
479+
if ((scale <= 0.f) && with_dynamic_shape) {
480+
VLOG(3) << "dynamic shape not support scale not set.";
481+
return false;
482+
}
479483
}
480484
}
481485

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
16+
from program_config import TensorConfig, ProgramConfig
17+
import numpy as np
18+
import paddle.inference as paddle_infer
19+
from functools import partial
20+
from typing import Optional, List, Callable, Dict, Any, Set
21+
import unittest
22+
23+
24+
class TrtConvertNearestInterpTest(TrtLayerAutoScanTest):
25+
def is_program_valid(self, program_config: ProgramConfig) -> bool:
26+
inputs = program_config.inputs
27+
weights = program_config.weights
28+
attrs = [
29+
program_config.ops[i].attrs
30+
for i in range(len(program_config.ops))
31+
]
32+
33+
if attrs[0]['scale'] <= 0 and (attrs[0]['out_h'] <= 0 or
34+
attrs[0]['out_w'] <= 0):
35+
return False
36+
if (attrs[0]['out_h'] <= 0) ^ (attrs[0]['out_w'] <= 0):
37+
return False
38+
39+
return True
40+
41+
def sample_program_configs(self):
42+
def generate_input1(attrs: List[Dict[str, Any]]):
43+
return np.ones([1, 3, 64, 64]).astype(np.float32)
44+
45+
for data_layout in ["NCHW", "NHWC"]:
46+
for interp_method in ["nearest"]:
47+
for align_corners in [True, False]:
48+
for scale in [2.0, -1.0, 0.0]:
49+
for out_h in [32, 64, 128 - 32]:
50+
for out_w in [32, -32]:
51+
dics = [{
52+
"data_layout": data_layout,
53+
"interp_method": interp_method,
54+
"align_corners": align_corners,
55+
"scale": scale,
56+
"out_h": out_h,
57+
"out_w": out_w
58+
}]
59+
60+
ops_config = [{
61+
"op_type": "nearest_interp",
62+
"op_inputs": {
63+
"X": ["input_data"]
64+
},
65+
"op_outputs": {
66+
"Out": ["nearest_interp_output_data"]
67+
},
68+
"op_attrs": dics[0]
69+
}]
70+
ops = self.generate_op_config(ops_config)
71+
72+
program_config = ProgramConfig(
73+
ops=ops,
74+
weights={},
75+
inputs={
76+
"input_data": TensorConfig(
77+
data_gen=partial(generate_input1,
78+
dics))
79+
},
80+
outputs=["nearest_interp_output_data"])
81+
82+
yield program_config
83+
84+
def sample_predictor_configs(
85+
self, program_config) -> (paddle_infer.Config, List[int], float):
86+
def generate_dynamic_shape(attrs):
87+
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
88+
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
89+
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
90+
91+
def clear_dynamic_shape():
92+
self.dynamic_shape.min_input_shape = {}
93+
self.dynamic_shape.max_input_shape = {}
94+
self.dynamic_shape.opt_input_shape = {}
95+
96+
def generate_trt_nodes_num(attrs, dynamic_shape):
97+
return 1, 2
98+
99+
attrs = [
100+
program_config.ops[i].attrs
101+
for i in range(len(program_config.ops))
102+
]
103+
104+
# for static_shape
105+
clear_dynamic_shape()
106+
self.trt_param.precision = paddle_infer.PrecisionType.Float32
107+
yield self.create_inference_config(), generate_trt_nodes_num(
108+
attrs, False), 1e-5
109+
self.trt_param.precision = paddle_infer.PrecisionType.Half
110+
yield self.create_inference_config(), generate_trt_nodes_num(
111+
attrs, False), 1e-2
112+
113+
# for dynamic_shape
114+
generate_dynamic_shape(attrs)
115+
self.trt_param.precision = paddle_infer.PrecisionType.Float32
116+
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
117+
True), 1e-5
118+
self.trt_param.precision = paddle_infer.PrecisionType.Half
119+
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
120+
True), 1e-2
121+
122+
def add_skip_trt_case(self):
123+
def teller1(program_config, predictor_config):
124+
if program_config.ops[0].attrs[
125+
'scale'] <= 0 and self.dynamic_shape.min_input_shape:
126+
return True
127+
return False
128+
129+
self.add_skip_case(
130+
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
131+
"NOT Implemented: we need to add support scale <= 0 in dynamic shape in the future"
132+
)
133+
134+
pass
135+
136+
def test(self):
137+
self.add_skip_trt_case()
138+
self.run_test()
139+
140+
141+
if __name__ == "__main__":
142+
unittest.main()

0 commit comments

Comments
 (0)