Skip to content

Commit 1d46bc3

Browse files
committed
fix trt6
1 parent 0674fc3 commit 1d46bc3

File tree

4 files changed

+112
-33
lines changed

4 files changed

+112
-33
lines changed

paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,16 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
253253
for (size_t i = 0; i < axes_.size(); i++) {
254254
int start = starts_[i];
255255
int end = ends_[i];
256+
#if IS_TRT_VERSION_GE(7000)
256257
ret.d[axes_[i]] = expr_builder.operation(
257258
nvinfer1::DimensionOperation::kSUB,
258259
*expr_builder.operation(nvinfer1::DimensionOperation::kMIN,
259260
*expr_builder.constant(ends_[i]),
260261
*in_dims.d[axes_[i]]),
261262
*expr_builder.constant(start));
263+
#else
264+
ret.d[axes_[i]] = expr_builder.constant(end - start);
265+
#endif
262266
}
263267
return ret;
264268
}

python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ if(WITH_GPU AND TENSORRT_FOUND)
1919
endforeach()
2020

2121
foreach(target ${TEST_TRT_IR_PASSES})
22-
py_test_modules(${target} MODULES ${target})
22+
if(${target} STREQUAL "test_trt_slice_dynamic_plugin")
23+
if(${TENSORRT_MAJOR_VERSION} VERSION_GREATER "6")
24+
py_test_modules(${target} MODULES ${target})
25+
endif()
26+
else()
27+
py_test_modules(${target} MODULES ${target})
28+
endif()
2329
endforeach()
2430

2531
foreach(target ${TEST_TRT_CONVERTER})
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from inference_pass_test import InferencePassTest
20+
import paddle.fluid as fluid
21+
import paddle.fluid.core as core
22+
from paddle.fluid.core import AnalysisConfig
23+
24+
25+
#normal starts && ends
26+
class SlicePluginTRTDynamicTest(SlicePluginTRTTest):
27+
def setUpSliceParams(self):
28+
self.params_axes = [1, 3]
29+
self.params_starts = [0, 1]
30+
self.params_ends = [2, 3]
31+
32+
def setUpTensorRTParams(self):
33+
self.trt_parameters = SlicePluginTRTTest.TensorRTParam(
34+
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
35+
self.enable_trt = True
36+
self.dynamic_shape_params = SlicePluginTRTDynamicTest.DynamicShapeParam(
37+
{
38+
'data': [1, 1, 1, 1]
39+
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
40+
41+
def setUp(self):
42+
self.setUpSliceParams()
43+
self.setUpTensorRTParams()
44+
with fluid.program_guard(self.main_program, self.startup_program):
45+
data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32")
46+
axes = self.params_axes
47+
starts = self.params_starts
48+
ends = self.params_ends
49+
slice_out = fluid.layers.slice(
50+
data, axes=axes, starts=starts, ends=ends)
51+
52+
self.feeds = {
53+
"data": np.random.random((3, 3, 3, 3)).astype("float32"),
54+
}
55+
self.fetch_list = [slice_out]
56+
57+
def test_check_output(self):
58+
use_gpu = [False]
59+
if core.is_compiled_with_cuda():
60+
use_gpu.append(True)
61+
for i in range(len(use_gpu)):
62+
atol = 1e-5
63+
if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
64+
atol = 1e-3
65+
self.check_output_with_option(use_gpu[i], atol)
66+
67+
68+
class SlicePluginTRTDynamicBoundTest(SlicePluginTRTDynamicTest):
69+
def setUpSliceParams(self):
70+
self.params_axes = [1, 3]
71+
self.params_starts = [0, 1]
72+
self.params_ends = [2, 1000]
73+
74+
def setUpTensorRTParams(self):
75+
self.trt_parameters = SlicePluginTRTDynamicBoundTest.TensorRTParam(
76+
1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False)
77+
self.enable_trt = True
78+
self.dynamic_shape_params = SlicePluginTRTDynamicBoundTest.DynamicShapeParam(
79+
{
80+
'data': [1, 1, 1, 1]
81+
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
82+
83+
84+
class SlicePluginTRTDynamicNegativeBoundTest(SlicePluginTRTDynamicTest):
85+
def setUpSliceParams(self):
86+
self.params_axes = [1, 3]
87+
self.params_starts = [-5, 1]
88+
self.params_ends = [2, 1000]
89+
90+
def setUpTensorRTParams(self):
91+
self.trt_parameters = SlicePluginTRTDynamicNegativeBoundTest.TensorRTParam(
92+
1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False)
93+
self.enable_trt = True
94+
self.dynamic_shape_params = SlicePluginTRTDynamicNegativeBoundTest.DynamicShapeParam(
95+
{
96+
'data': [1, 1, 1, 1]
97+
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
98+
99+
100+
if __name__ == "__main__":
101+
unittest.main()

python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -101,37 +101,5 @@ def setUpTensorRTParams(self):
101101
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
102102

103103

104-
class SlicePluginTRTDynamicBoundTest(SlicePluginTRTTest):
105-
def setUpSliceParams(self):
106-
self.params_axes = [1, 3]
107-
self.params_starts = [0, 1]
108-
self.params_ends = [2, 1000]
109-
110-
def setUpTensorRTParams(self):
111-
self.trt_parameters = SlicePluginTRTDynamicBoundTest.TensorRTParam(
112-
1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False)
113-
self.enable_trt = True
114-
self.dynamic_shape_params = SlicePluginTRTDynamicBoundTest.DynamicShapeParam(
115-
{
116-
'data': [1, 1, 1, 1]
117-
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
118-
119-
120-
class SlicePluginTRTDynamicNegativeBoundTest(SlicePluginTRTTest):
121-
def setUpSliceParams(self):
122-
self.params_axes = [1, 3]
123-
self.params_starts = [-5, 1]
124-
self.params_ends = [2, 1000]
125-
126-
def setUpTensorRTParams(self):
127-
self.trt_parameters = SlicePluginTRTDynamicNegativeBoundTest.TensorRTParam(
128-
1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False)
129-
self.enable_trt = True
130-
self.dynamic_shape_params = SlicePluginTRTDynamicNegativeBoundTest.DynamicShapeParam(
131-
{
132-
'data': [1, 1, 1, 1]
133-
}, {'data': [8, 8, 8, 8]}, {'data': [8, 8, 8, 8]}, False)
134-
135-
136104
if __name__ == "__main__":
137105
unittest.main()

0 commit comments

Comments
 (0)