Skip to content

Commit 5e153bf

Browse files
authored
add test (#35568)
1 parent ccf5b80 commit 5e153bf

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
16+
from program_config import TensorConfig, ProgramConfig
17+
import numpy as np
18+
import paddle.inference as paddle_infer
19+
from functools import partial
20+
from typing import Optional, List, Callable, Dict, Any, Set
21+
import unittest
22+
23+
24+
class TrtConvertTileTest(TrtLayerAutoScanTest):
25+
def is_program_valid(self, program_config: ProgramConfig) -> bool:
26+
inputs = program_config.inputs
27+
attrs = [
28+
program_config.ops[i].attrs
29+
for i in range(len(program_config.ops))
30+
]
31+
for x in attrs[0]['repeat_times']:
32+
if x <= 0:
33+
return False
34+
35+
return True
36+
37+
def sample_program_configs(self):
38+
def generate_input1(attrs: List[Dict[str, Any]]):
39+
return np.ones([1, 2, 3, 4]).astype(np.float32)
40+
41+
for repeat_times in [[100], [1, 2], [0, 3], [1, 2, 100]]:
42+
dics = [{"repeat_times": repeat_times}]
43+
44+
ops_config = [{
45+
"op_type": "tile",
46+
"op_inputs": {
47+
"X": ["input_data"]
48+
},
49+
"op_outputs": {
50+
"Out": ["tile_output_data"]
51+
},
52+
"op_attrs": dics[0]
53+
}]
54+
ops = self.generate_op_config(ops_config)
55+
56+
program_config = ProgramConfig(
57+
ops=ops,
58+
weights={},
59+
inputs={
60+
"input_data": TensorConfig(data_gen=partial(generate_input1,
61+
dics))
62+
},
63+
outputs=["tile_output_data"])
64+
65+
yield program_config
66+
67+
def sample_predictor_configs(
68+
self, program_config) -> (paddle_infer.Config, List[int], float):
69+
def generate_dynamic_shape(attrs):
70+
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
71+
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
72+
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
73+
74+
def clear_dynamic_shape():
75+
self.dynamic_shape.min_input_shape = {}
76+
self.dynamic_shape.max_input_shape = {}
77+
self.dynamic_shape.opt_input_shape = {}
78+
79+
def generate_trt_nodes_num(attrs, dynamic_shape):
80+
if dynamic_shape == True:
81+
return 0, 3
82+
else:
83+
return 1, 2
84+
85+
attrs = [
86+
program_config.ops[i].attrs
87+
for i in range(len(program_config.ops))
88+
]
89+
90+
# for static_shape
91+
clear_dynamic_shape()
92+
self.trt_param.precision = paddle_infer.PrecisionType.Float32
93+
yield self.create_inference_config(), generate_trt_nodes_num(
94+
attrs, False), 1e-5
95+
self.trt_param.precision = paddle_infer.PrecisionType.Half
96+
yield self.create_inference_config(), generate_trt_nodes_num(
97+
attrs, False), 1e-4
98+
99+
# for dynamic_shape
100+
generate_dynamic_shape(attrs)
101+
self.trt_param.precision = paddle_infer.PrecisionType.Float32
102+
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
103+
True), 1e-5
104+
self.trt_param.precision = paddle_infer.PrecisionType.Half
105+
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
106+
True), 1e-4
107+
108+
def test(self):
109+
self.run_test()
110+
111+
112+
if __name__ == "__main__":
113+
unittest.main()

0 commit comments

Comments
 (0)