
trt convert ut add dynamic_shape and int8, etc.#35061

Merged
jiweibo merged 5 commits into PaddlePaddle:develop from jiweibo:ut_convert
Aug 23, 2021

Conversation

@jiweibo
Contributor

@jiweibo jiweibo commented Aug 22, 2021

PR types

Others

PR changes

Others

Describe

Inference trt convert ut now supports int8, dynamic_shape, and a TensorRT engine number check.


1. Build the network

class TrtConvertConv2dTest(TrtLayerAutoScanTest):
    def setUp(self):
        self.ops_config = [{
            "op_type": "conv2d",
            "op_inputs": {
                "Input": ["input_data"],
                "Filter": ["conv2d_weight"]
            },
            "op_outputs": {
                "Output": ["conv_output_data"]
            },
            "op_attrs": {
                "data_format": ["NCHW"],
                "dilations": [[1, 1]],
                "padding_algorithm": ["EXPLICIT"],
                "groups": [1],
                "paddings": [[0, 3], [3, 1]],
                "strides": [[1, 1], [2, 2]],
            }
        }, {
            "op_type": "relu",
            "op_inputs": {
                "X": ["conv_output_data"]
            },
            "op_outputs": {
                "Out": ["relu_output_data"]
            },
            "op_attrs": {}
        }]
        self.batch_size_set = [1, 2, 4]
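Each entry in `op_attrs` is a list of candidate values, so the harness sweeps over their Cartesian product (2 paddings × 2 strides = 4 conv2d configs above). A minimal standard-library sketch of that enumeration (the function and variable names here are illustrative, not the harness's internals):

```python
import itertools

# Candidate values for each conv2d attribute, as in the test above.
attr_options = {
    "data_format": ["NCHW"],
    "dilations": [[1, 1]],
    "padding_algorithm": ["EXPLICIT"],
    "groups": [1],
    "paddings": [[0, 3], [3, 1]],
    "strides": [[1, 1], [2, 2]],
}

def enumerate_attr_configs(options):
    """Yield one attribute dict per combination of candidate values."""
    keys = list(options)
    for values in itertools.product(*(options[k] for k in keys)):
        yield dict(zip(keys, values))

configs = list(enumerate_attr_configs(attr_options))
# 2 paddings x 2 strides -> 4 attribute combinations to test.
```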

2. Set the inputs and weights

    def update_program_input_and_weight_with_attr(self, op_attr_list):
        weight = np.random.randn(24, 3, 3, 3).astype("float32")
        filter = TensorConfig(shape=[24, 3, 3, 3], data=weight)
        if op_attr_list[0]["data_format"] == "NCHW":
            input_data = TensorConfig(shape=[-1, 3, 64, 64])
        else:
            input_data = TensorConfig(shape=[-1, 64, 64, 3])
        self.program_weights = {"conv2d_weight": filter}
        self.program_inputs = {"input_data": input_data}
        self.program_outputs = ["relu_output_data"]

3. TRT FP32 test

    def test_check_fp32_output(self):
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        # the fused tensorrt engine num is 1, and the paddle op num is 2 (feed and fetch).
        self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-5)

4. TRT FP16 test

    def test_check_fp16_output(self):
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-2)

5. TRT dynamic_shape FP32 test

    def test_dynamic_shape_fp32_check_output(self):
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
        self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
        self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
        self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-5)

6. TRT dynamic_shape FP16 test

    def test_dynamic_shape_fp16_check_output(self):
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
        self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
        self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
        self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-2)
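For a dynamic-shape profile to be valid, every dimension must satisfy min ≤ opt ≤ max, and the real inputs must fall inside that range. A small self-check sketch over the shapes used above (the helper name is illustrative):

```python
def check_dynamic_shape(min_shape, opt_shape, max_shape):
    """Return True if every dim of every input satisfies min <= opt <= max."""
    for name in min_shape:
        for lo, mid, hi in zip(min_shape[name], opt_shape[name], max_shape[name]):
            if not (lo <= mid <= hi):
                return False
    return True

ok = check_dynamic_shape(
    {"input_data": [1, 3, 32, 32]},   # min_input_shape
    {"input_data": [1, 3, 64, 64]},   # opt_input_shape
    {"input_data": [4, 3, 64, 64]},   # max_input_shape
)
```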

7. If the operator supports quantization, add an int8 test

    def test_trt_int8_check_output(self):
        self.trt_param.precision = paddle_infer.PrecisionType.Int8
        self.run_test(
            trt_engine_num=1, paddle_op_num=2, quant=True, threshold=1e-1)
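`run_test` compares the TRT output against the Paddle baseline within `threshold`; int8 quantization is coarse, hence the loose 1e-1 value above. A plausible elementwise comparison (a sketch, not the harness's actual implementation):

```python
def outputs_close(expected, actual, threshold):
    """True if the max absolute elementwise difference is within threshold."""
    return max(abs(e - a) for e, a in zip(expected, actual)) <= threshold

# A looser threshold tolerates quantization error without masking real bugs.
```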

Notes

  • FP16 accuracy may be low, and int8 accuracy is even lower.
  • When the model is saved, a scale layer is appended at the end; this final op may or may not be fused into the TRT engine, so it needs case-by-case analysis.

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@jiweibo jiweibo requested a review from winter-wang August 22, 2021 10:23
@Shixiaowei02 Shixiaowei02 requested a review from T8T9 August 23, 2021 08:05
@jiweibo jiweibo merged commit 17188e8 into PaddlePaddle:develop Aug 23, 2021
@jiweibo jiweibo deleted the ut_convert branch August 23, 2021 11:04
@jiweibo jiweibo mentioned this pull request Aug 30, 2021