diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py index 9ec2f83fa5ba0a..ddb96c37db780c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py @@ -18,6 +18,7 @@ import paddle.inference as paddle_infer from functools import partial from typing import Optional, List, Callable, Dict, Any, Set +import unittest class TrtConvertPool2dTest(TrtLayerAutoScanTest): @@ -32,6 +33,10 @@ def is_paddings_valid(self, program_config: ProgramConfig) -> bool: for index in range(len(ksize)): if ksize[index] <= paddings[index]: return False + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: + if program_config.ops[0].attrs['pooling_type'] == 'avg': + return False return True def is_program_valid(self, program_config: ProgramConfig) -> bool: @@ -46,16 +51,16 @@ def generate_input1(attrs: List[Dict[str, Any]]): def generate_weight1(attrs: List[Dict[str, Any]]): return np.random.random([24, 3, 3, 3]).astype(np.float32) - for strides in [[1, 1], [2, 2], [1, 2]]: + for strides in [[1, 1], [1, 2], [2, 2]]: for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]: for pooling_type in ['max', 'avg']: for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']: for ksize in [[2, 3], [3, 3]]: for data_format in ['NCHW']: for global_pooling in [True, False]: - for exclusive in [True, False]: + for exclusive in [False, True]: for adaptive in [True, False]: - for ceil_mode in [True, False]: + for ceil_mode in [False, True]: dics = [{ "pooling_type": @@ -157,6 +162,29 @@ def teller2(program_config, predictor_config): teller2, SkipReasons.TRT_NOT_IMPLEMENTED, "It is not support that global_pooling is true for trt now.") + def teller3(program_config, predictor_config): + if self.dynamic_shape.min_input_shape == {} and program_config.ops[ + 0].attrs['ceil_mode'] == True: + return True + return False + + self.add_skip_case( + teller3, SkipReasons.TRT_NOT_IMPLEMENTED, + "It is not support that ceil_mode is true in static mode for trt now." + ) + + def teller4(program_config, predictor_config): + if self.dynamic_shape.min_input_shape != {} and ( + program_config.ops[0].attrs['strides'] == [1, 2] or + program_config.ops[0].attrs['strides'] == [2, 2]): + return True + return False + + self.add_skip_case( + teller4, SkipReasons.TRT_NOT_IMPLEMENTED, + "It is not support that strides is not equal [1, 1] in dynamic mode for trt now." + ) + def test(self): self.add_skip_trt_case() self.run_test()