Skip to content

Commit d91352c

Browse files
author
Pei Yang
authored
[Paddle-TRT]Fix flatten converter when batch_size > 1 (#33768)
* fix trt flatten converter when batch_size > 1 * change ut to same dynamic shape
1 parent 0f59d4e commit d91352c

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

paddle/fluid/inference/tensorrt/convert/flatten_op.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,19 @@ class FlattenOpConverter : public OpConverter {
5353
layer->setReshapeDimensions(flatten_dim);
5454
} else {
5555
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
56+
nvinfer1::Dims start_dim, size_dim, stride_dim;
57+
start_dim.nbDims = 1;
58+
size_dim.nbDims = 1;
59+
stride_dim.nbDims = 1;
60+
start_dim.d[0] = 1;
61+
size_dim.d[0] = dims - 1;
62+
stride_dim.d[0] = 1;
63+
auto* slice_layer =
64+
TRT_ENGINE_ADD_LAYER(engine_, Slice, *(shape_layer->getOutput(0)),
65+
start_dim, size_dim, stride_dim);
5666
uint32_t reduce_dim = 1;
57-
5867
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
59-
engine_, Reduce, *(shape_layer->getOutput(0)),
68+
engine_, Reduce, *(slice_layer->getOutput(0)),
6069
nvinfer1::ReduceOperation::kPROD, reduce_dim, true);
6170
int32_t* constant_weight_data = new int32_t[1];
6271
constant_weight_data[0] = -1;

paddle/fluid/inference/tensorrt/helper.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ inline size_t ProductDim(const nvinfer1::Dims& dims) {
134134
return v;
135135
}
136136

137+
inline void PrintITensorShape(nvinfer1::ITensor* X) {
138+
auto dims = X->getDimensions();
139+
auto name = X->getName();
140+
std::cout << "ITensor " << name << " shape: [";
141+
for (int i = 0; i < dims.nbDims; i++) {
142+
if (i == dims.nbDims - 1)
143+
std::cout << dims.d[i];
144+
else
145+
std::cout << dims.d[i] << ", ";
146+
}
147+
std::cout << "]\n";
148+
}
149+
137150
} // namespace tensorrt
138151
} // namespace inference
139152
} // namespace paddle

python/paddle/fluid/tests/unittests/ir/inference/test_trt_flatten_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def setUp(self):
6363
self.trt_parameters = TRTFlattenDynamicTest.TensorRTParam(
6464
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
6565
self.dynamic_shape_params = TRTFlattenDynamicTest.DynamicShapeParam({
66-
'data': [1, 6, 8, 8],
67-
'flatten_0.tmp_0': [1, 6 * 8 * 8]
68-
}, {'data': [3, 6, 128, 128],
69-
'flatten_0.tmp_0': [3, 6 * 128 * 128]}, {
66+
'data': [2, 6, 64, 64],
67+
'flatten_0.tmp_0': [2, 6 * 64 * 64]
68+
}, {'data': [2, 6, 64, 64],
69+
'flatten_0.tmp_0': [2, 6 * 64 * 64]}, {
7070
'data': [2, 6, 64, 64],
7171
'flatten_0.tmp_0': [2, 6 * 64 * 64]
7272
}, False)

0 commit comments

Comments
 (0)