Skip to content

Commit 2ee0032

Browse files
xiaoxiaohehe001 and AnnaTrainingG
authored and committed
[Paddle Inference]Add BN op TRT converter unittest (PaddlePaddle#35527)
* add_bn_ * add_bn_teller * add_bn_teller * add_bn_teller * add_bn_teller
1 parent 99fdb20 commit 2ee0032

File tree

2 files changed

+226
-1
lines changed

2 files changed

+226
-1
lines changed

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
513513
return false;
514514
}
515515
}
516-
516+
auto batch_norm_inputs = desc.Inputs();
517+
if (batch_norm_inputs.find("MomentumTensor") != batch_norm_inputs.end()) {
518+
if (desc.Input("MomentumTensor").size() >= 1) {
519+
return false;
520+
}
521+
}
517522
if (desc.Output("Y").size() != 1) {
518523
VLOG(3) << "Invalid output Y's size of batch_norm TRT "
519524
"converter. Expected 1, received "
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set

import numpy as np

import paddle.inference as paddle_infer

from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
21+
22+
23+
class TrtConvertBatchNormTest(TrtLayerAutoScanTest):
    """Auto-scan unit test for the Paddle -> TensorRT batch_norm converter.

    Sweeps input rank (2/3/4-D), batch size, epsilon, data layout and
    momentum, with and without the optional ``MomentumTensor`` input, and
    compares Paddle-TRT inference against the reference run under both
    static- and dynamic-shape TRT, at FP32 and FP16 precision.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        # Every combination produced by sample_program_configs is a legal
        # batch_norm program, so nothing is filtered here.
        return True

    def sample_program_configs(self):
        # ---- lazy data generators (invoked via functools.partial) ----
        def generate_input1(attrs: List[Dict[str, Any]], batch):
            # Feature input; shape follows the sampled rank and layout.
            # Channel count is fixed at 3 to match the per-channel tensors
            # below. NOTE(review): returns None implicitly for any other
            # dims/layout combination — the sampling loops never produce one.
            if self.dims == 4:
                if attrs[0]['data_layout'] == "NCHW":
                    return np.ones([batch, 3, 24, 24]).astype(np.float32)
                elif attrs[0]['data_layout'] == "NHWC":
                    return np.ones([batch, 24, 24, 3]).astype(np.float32)
            elif self.dims == 3:
                return np.ones([batch, 3, 24]).astype(np.float32)
            elif self.dims == 2:
                return np.ones([batch, 3]).astype(np.float32)

        def generate_bias(attrs: List[Dict[str, Any]], batch):
            # Per-channel bias (3 channels).
            return np.full((3), 0.9).astype("float32")

        def generate_mean(attrs: List[Dict[str, Any]], batch):
            # Per-channel running mean.
            return np.full((3), 0.9).astype("float32")

        def generate_scale(attrs: List[Dict[str, Any]], batch):
            # Per-channel scale (gamma).
            return np.full((3), 1.1).astype("float32")

        def generate_variance(attrs: List[Dict[str, Any]], batch):
            # Per-channel running variance.
            return np.full((3), 1.2).astype("float32")

        def generate_MomentumTensor(attrs: List[Dict[str, Any]], batch):
            # Optional momentum override tensor; its presence makes the op
            # unconvertible (see add_skip_trt_case / the op_teller change
            # this test accompanies).
            return np.full((3), 0.9).astype("float32")

        # Cartesian sweep over the sampled dimensions. num_input selects
        # between the with- and without-MomentumTensor variants below.
        for dims in [2, 3, 4]:
            for num_input in [0, 1]:
                for batch in [1, 2, 4]:
                    for epsilon in [1e-6, 1e-5, 1e-4]:
                        for data_layout in ["NCHW"]:
                            for momentum in [0.9, 0.8]:
                                self.num_input = num_input
                                self.dims = dims
                                # dics[0]: batch_norm attributes; dics[1] is an
                                # unused placeholder slot.
                                dics = [{
                                    "epsilon": epsilon,
                                    "data_layout": data_layout,
                                    "momentum": momentum,
                                    "is_test": True,
                                    "trainable_statistics": False
                                }, {}]
                                # Input maps: [0] includes MomentumTensor,
                                # [1] omits it.
                                dics_intput = [{
                                    "X": ["batch_norm_input"],
                                    "Bias": ["Bias"],
                                    "Mean": ["Mean"],
                                    "Scale": ["Scale"],
                                    "Variance": ["Variance"],
                                    "MomentumTensor": ["MomentumTensor"]
                                }, {
                                    "X": ["batch_norm_input"],
                                    "Bias": ["Bias"],
                                    "Mean": ["Mean"],
                                    "Scale": ["Scale"],
                                    "Variance": ["Variance"]
                                }]
                                # Weight maps mirroring dics_intput: [0] has 5
                                # weights (incl. MomentumTensor), [1] has 4.
                                dics_intputs = [{
                                    "Bias": TensorConfig(data_gen=partial(
                                        generate_bias, dics, batch)),
                                    "Mean": TensorConfig(data_gen=partial(
                                        generate_mean, dics, batch)),
                                    "Scale": TensorConfig(data_gen=partial(
                                        generate_scale, dics, batch)),
                                    "Variance": TensorConfig(data_gen=partial(
                                        generate_variance, dics, batch)),
                                    "MomentumTensor":
                                    TensorConfig(data_gen=partial(
                                        generate_MomentumTensor, dics, batch)),
                                }, {
                                    "Bias": TensorConfig(data_gen=partial(
                                        generate_bias, dics, batch)),
                                    "Mean": TensorConfig(data_gen=partial(
                                        generate_mean, dics, batch)),
                                    "Scale": TensorConfig(data_gen=partial(
                                        generate_scale, dics, batch)),
                                    "Variance": TensorConfig(data_gen=partial(
                                        generate_variance, dics, batch))
                                }]
                                # MeanOut/VarianceOut reuse the "Mean"/
                                # "Variance" names — batch_norm writes its
                                # running statistics back in place.
                                ops_config = [{
                                    "op_type": "batch_norm",
                                    "op_inputs": dics_intput[num_input],
                                    "op_outputs": {
                                        "Y": ["batch_norm_out"],
                                        "MeanOut": ["Mean"],
                                        "VarianceOut": ["Variance"],
                                        "SavedMean": ["SavedMean"],
                                        "SavedVariance": ["SavedVariance"]
                                    },
                                    "op_attrs": dics[0]
                                }]
                                ops = self.generate_op_config(ops_config)
                                program_config = ProgramConfig(
                                    ops=ops,
                                    weights=dics_intputs[num_input],
                                    inputs={
                                        "batch_norm_input": TensorConfig(
                                            data_gen=partial(generate_input1,
                                                             dics, batch))
                                    },
                                    outputs=["batch_norm_out"])

                                yield program_config

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):
        """Yield (config, expected op counts, tolerance) for each TRT mode."""

        def generate_dynamic_shape(attrs):
            # min/max/opt shapes must match the rank and layout of the
            # sampled input from generate_input1.
            if self.dims == 4:
                if attrs[0]['data_layout'] == "NCHW":
                    self.dynamic_shape.min_input_shape = {
                        "batch_norm_input": [1, 3, 24, 24]
                    }
                    self.dynamic_shape.max_input_shape = {
                        "batch_norm_input": [4, 3, 48, 48]
                    }
                    self.dynamic_shape.opt_input_shape = {
                        "batch_norm_input": [1, 3, 24, 48]
                    }
                elif attrs[0]['data_layout'] == "NHWC":
                    self.dynamic_shape.min_input_shape = {
                        "batch_norm_input": [1, 24, 24, 3]
                    }
                    self.dynamic_shape.max_input_shape = {
                        "batch_norm_input": [4, 48, 48, 3]
                    }
                    self.dynamic_shape.opt_input_shape = {
                        "batch_norm_input": [1, 24, 48, 3]
                    }
            elif self.dims == 3:
                self.dynamic_shape.min_input_shape = {
                    "batch_norm_input": [1, 3, 24]
                }
                self.dynamic_shape.max_input_shape = {
                    "batch_norm_input": [4, 3, 48]
                }
                self.dynamic_shape.opt_input_shape = {
                    "batch_norm_input": [1, 3, 48]
                }
            elif self.dims == 2:
                self.dynamic_shape.min_input_shape = {
                    "batch_norm_input": [1, 3]
                }
                self.dynamic_shape.max_input_shape = {
                    "batch_norm_input": [4, 3]
                }
                self.dynamic_shape.opt_input_shape = {
                    "batch_norm_input": [1, 3]
                }

        def clear_dynamic_shape():
            # Empty shape dicts put the harness back into static-shape mode.
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            # Constant for every sampled case — presumably (num TRT engine
            # ops, num ops left outside the engine); verify against the
            # TrtLayerAutoScanTest contract.
            return 1, 2

        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]
        # for static_shape
        clear_dynamic_shape()
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False), 1e-5

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                     True), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                     True), 1e-5

    def add_skip_trt_case(self):
        def teller1(program_config, predictor_config):
            # 5 weights == the MomentumTensor variant (dics_intputs[0]);
            # the op_teller rejects batch_norm with a MomentumTensor input,
            # so those programs must be skipped rather than compared.
            if len(program_config.weights) == 5:
                return True
            return False

        self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
                           "INPUT MomentumTensor NOT SUPPORT")

    def test(self):
        # Register the skip rules before driving the auto-scan harness.
        self.add_skip_trt_case()
        self.run_test()
217+
218+
219+
if __name__ == "__main__":
    # Fix: `unittest` is used here but never imported at the top of the
    # file, so running this script directly raised NameError. Importing
    # locally keeps the fix self-contained.
    import unittest

    unittest.main()

0 commit comments

Comments
 (0)