Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest


class TrtConvertPool2dTest(TrtLayerAutoScanTest):
Expand All @@ -32,6 +33,10 @@ def is_paddings_valid(self, program_config: ProgramConfig) -> bool:
for index in range(len(ksize)):
if ksize[index] <= paddings[index]:
return False
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
if program_config.ops[0].attrs['pooling_type'] == 'avg':
return False
return True

def is_program_valid(self, program_config: ProgramConfig) -> bool:
Expand All @@ -46,16 +51,16 @@ def generate_input1(attrs: List[Dict[str, Any]]):
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([24, 3, 3, 3]).astype(np.float32)

for strides in [[1, 1], [2, 2], [1, 2]]:
for strides in [[1, 1], [1, 2], [2, 2]]:
for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]:
for pooling_type in ['max', 'avg']:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']:
for ksize in [[2, 3], [3, 3]]:
for data_format in ['NCHW']:
for global_pooling in [True, False]:
for exclusive in [True, False]:
for exclusive in [False, True]:
for adaptive in [True, False]:
for ceil_mode in [True, False]:
for ceil_mode in [False, True]:

dics = [{
"pooling_type":
Expand Down Expand Up @@ -157,6 +162,29 @@ def teller2(program_config, predictor_config):
teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that global_pooling is true for trt now.")

def teller3(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {} and program_config.ops[
0].attrs['ceil_mode'] == True:
return True
return False

self.add_skip_case(
teller3, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that ceil_mode is true in static mode for trt now."
)

def teller4(program_config, predictor_config):
if self.dynamic_shape.min_input_shape != {} and (
program_config.ops[0].attrs['strides'] == [1, 2] or
program_config.ops[0].attrs['strides'] == [2, 2]):
return True
return False

self.add_skip_case(
teller4, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that strides is not equal [1, 1] in dynamic mode for trt now."
)

def test(self):
self.add_skip_trt_case()
self.run_test()
Expand Down