Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions paddle/fluid/inference/tensorrt/convert/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,27 @@ class GatherOpConverter : public OpConverter {
std::string input_name = op_desc.Input("X").front();
std::string index_name = op_desc.Input("Index").front();
std::string output_name = op_desc.Output("Out").front();

const auto input_tensor = engine_->GetITensor(input_name);
const auto index_tensor = engine_->GetITensor(index_name);

const int axis = 0;
int axis = 0;
if (op_desc.HasAttr("axis")) {
axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
}

auto layer = TRT_ENGINE_ADD_LAYER(engine_, Gather, *input_tensor,
*index_tensor, axis);
auto reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *index_tensor);

auto odim = layer->getOutput(0)->getDimensions();
nvinfer1::Dims index_shape{};
index_shape.nbDims = 1;
index_shape.d[0] = -1;

auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
reshape_layer->setReshapeDimensions(index_shape);

nvinfer1::Dims target_shape{};
target_shape.nbDims = odim.nbDims - 1;
for (int i = 0; i < axis; ++i) {
target_shape.d[i] = odim.d[i];
}
target_shape.d[axis] = 0;
for (int i = axis + 1; i < target_shape.nbDims; ++i) {
target_shape.d[i] = odim.d[i + 1];
}

reshape_layer->setReshapeDimensions(target_shape);
auto layer = TRT_ENGINE_ADD_LAYER(engine_, Gather, *input_tensor,
*reshape_layer->getOutput(0), axis);
layer->setNbElementWiseDims(0);

RreplenishLayerAndOutput(reshape_layer, "gather", {output_name}, test_mode);
RreplenishLayerAndOutput(layer, "gather", {output_name}, test_mode);
}
};

Expand Down
32 changes: 17 additions & 15 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}

if (op_type == "gather") {
if (!with_dynamic_shape) return false;

if (with_dynamic_shape) {
auto gather_inputs = desc.Inputs();
if (gather_inputs.find("Axis") != gather_inputs.end()) {
if (desc.Input("Axis").size() >= 1) {
return false;
}
}
if (!with_dynamic_shape) {
return false;
} else {
auto* block = desc.Block();
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
Expand All @@ -373,13 +379,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
}

auto inputs = desc.InputArgumentNames();
for (auto& input : inputs) {
if (input == "Axis" && desc.Input("Axis").size() > 0) return false;
}
// current not support axis from input, use default 0
if (desc.GetAttrIfExists<int>("axis")) return false;
}

if (op_type == "gather_nd") {
Expand Down Expand Up @@ -1085,13 +1084,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#if IS_TRT_VERSION_GE(7000)
if (op_type == "tile") {
// Paddle-TRT does not support the input tensors.
auto inputs = desc.InputArgumentNames();
for (auto& input : inputs) {
if (input == "repeat_times_tensor" &&
desc.Input("repeat_times_tensor").size() > 0)
auto tile_inputs = desc.Inputs();
if (tile_inputs.find("repeat_times_tensor") != tile_inputs.end()) {
if (desc.Input("repeat_times_tensor").size() >= 1) {
return false;
if (input == "RepeatTimes" && desc.Input("RepeatTimes").size() > 0)
}
}
if (tile_inputs.find("RepeatTimes") != tile_inputs.end()) {
if (desc.Input("RepeatTimes").size() >= 1) {
return false;
}
}
if (with_dynamic_shape) return false;
if (!with_dynamic_shape && !desc.HasAttr("repeat_times")) return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,78 @@
from paddle.fluid.core import AnalysisConfig


class TRTGatherTest(InferencePassTest):
class TRTGatherTest1(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name='data', shape=[-1, 512], dtype='float32')
index = fluid.data(name='index', shape=[-1], dtype='int32')
scale_out = self.append_gather(data, index)
out = fluid.layers.batch_norm(scale_out, is_test=True)

index = np.arange(self.num_gather, dtype='int32')
np.random.shuffle(index)
data = fluid.data(name='data', shape=[-1, 128], dtype='float32')
index = fluid.data(name='index', shape=[-1, 1], dtype='int32')
scale_out = fluid.layers.gather(data, index=index)
out = fluid.layers.softmax(input=scale_out)

self.feeds = {
"data": np.random.random([self.bs, 512]).astype("float32"),
"index": index,
"data": np.random.random([self.bs, 128]).astype("float32"),
"index": self.index
}

self.enable_trt = True
self.trt_parameters = TRTGatherTest.TensorRTParam(
self.trt_parameters = TRTGatherTest1.TensorRTParam(
1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TRTGatherTest1.DynamicShapeParam({
'data': [1, 1],
'index': [1, 1]
}, {'data': [32, 128],
'index': [3, 1]}, {'data': [32, 128],
'index': [3, 1]}, False)
self.fetch_list = [out]

def set_params(self):
self.num_gather = 16
self.bs = 32

def append_gather(self, data, index):
return fluid.layers.gather(data, index=index)
self.index = np.array([[1], [2], [3]], dtype='int32')
self.bs = 4

def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, flatten=True)
self.check_output_with_option(use_gpu, flatten=False)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


class TRTGatherTest1(TRTGatherTest):
class TRTGatherTest2(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name='data', shape=[16, 64], dtype='float32')
index = fluid.data(name='index', shape=[2], dtype='int32')
scale_out = fluid.layers.gather(data, index=index)
out = fluid.layers.softmax(input=scale_out)

self.feeds = {
"data": np.random.random([self.bs, 64]).astype("float32"),
"index": self.index
}

self.enable_trt = True
self.trt_parameters = TRTGatherTest2.TensorRTParam(
1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TRTGatherTest2.DynamicShapeParam({
'data': [2, 4],
'index': [1]
}, {'data': [256, 256],
'index': [4]}, {'data': [64, 32],
'index': [2]}, False)
self.fetch_list = [out]

def set_params(self):
self.num_gather = 32
self.bs = 32
self.index = np.array([1, 4], dtype='int32')
self.bs = 16

def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, flatten=False)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, methodName='runTest'):
max_batch_size=4,
min_subgraph_size=0,
precision=paddle_infer.PrecisionType.Float32,
use_static=True,
use_static=False,
use_calib_mode=False)
self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
self.num_percent_cases = float(
Expand Down Expand Up @@ -109,7 +109,9 @@ def assert_tensors_near(self,
for key, arr in tensor.items():
self.assertTrue(
baseline[key].shape == arr.shape,
"The output shape of GPU and TensorRT are not equal.")
"The output shape of GPU and TensorRT are not equal, the baseline shape is "
+ str(baseline[key].shape) + ', but the trt shape is ' +
str(arr.shape))
self.assertTrue(
np.allclose(
baseline[key], arr, atol=atol, rtol=rtol),
Expand Down Expand Up @@ -259,9 +261,9 @@ def run_test(self, quant=False):
if not skip_flag:
self.assert_op_size(nodes_num[0], nodes_num[1])
# deserialize test
if nodes_num[0] > 0:
self.run_test_config(model, params, prog_config,
pred_config_deserialize, feed_data)
#if nodes_num[0] > 0:
# self.run_test_config(model, params, prog_config,
# pred_config_deserialize, feed_data)
except Exception as e:
self.fail_log(
str(prog_config) + ' vs ' + self.inference_config_str(
Expand Down