Skip to content

Commit fb703ac

Browse files
committed
fix
1 parent 81a84be commit fb703ac

File tree

1 file changed

+123
-13
lines changed

1 file changed

+123
-13
lines changed

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 123 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/inference/tensorrt/op_teller.h"
16-
16+
#include <bitset>
1717
#include "paddle/fluid/framework/block_desc.h"
1818
#include "paddle/fluid/framework/data_layout.h"
1919

@@ -316,11 +316,36 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
316316
if (op_type == "transpose2" || op_type == "transpose") {
317317
if (!desc.HasAttr("axis")) {
318318
return false;
319-
} else {
320-
std::vector<int> axis =
321-
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis"));
322-
if (!with_dynamic_shape && axis[0] != 0) return false;
323-
if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false;
319+
}
320+
std::vector<int> axis =
321+
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis"));
322+
if (!with_dynamic_shape && axis[0] != 0) return false;
323+
if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false;
324+
if (axis[0] == 0 && axis.size() == 2) return false;
325+
326+
auto* block = desc.Block();
327+
auto x_var_name = desc.Input("X")[0];
328+
auto* x_var_desc = block->FindVar(x_var_name);
329+
const auto x_shape = x_var_desc->GetShape();
330+
int dims = x_shape.size();
331+
std::vector<int> perm(nvinfer1::Dims::MAX_DIMS);
332+
for (int i = 0; i < dims; i++) {
333+
perm[i] = axis[i];
334+
}
335+
auto is_valid_permutation = [&](int dims,
336+
const std::vector<int>& permutation) {
337+
std::bitset<nvinfer1::Dims::MAX_DIMS> found;
338+
for (int i = 0; i < dims; ++i) {
339+
const int x = permutation[i];
340+
if ((x < 0) || (x >= dims) || found[x])
341+
return false; // Out of bounds or duplicate
342+
found.set(x);
343+
}
344+
return true;
345+
};
346+
if (!is_valid_permutation(dims, perm)) {
347+
VLOG(3) << "Invalid permutation dimensions for trt transpose op "
348+
"converter: duplicate or out of bound.";
324349
}
325350
}
326351
if (op_type == "flatten2" || op_type == "flatten") {
@@ -475,6 +500,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
475500
return false;
476501
}
477502
}
503+
if ((scale <= 0.f) && with_dynamic_shape) {
504+
VLOG(3) << "dynamic shape not support scale not set.";
505+
return false;
506+
}
478507
}
479508
}
480509

@@ -547,21 +576,95 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
547576
<< desc.Input("X").size() << ".";
548577
return false;
549578
}
579+
auto split_inputs = desc.Inputs();
580+
if (split_inputs.find("AxisTensor") != split_inputs.end()) {
581+
if (desc.Input("AxisTensor").size() >= 1) {
582+
return false;
583+
}
584+
}
585+
if (split_inputs.find("SectionsTensorList") != split_inputs.end()) {
586+
if (desc.Input("SectionsTensorList").size() >= 1) {
587+
return false;
588+
}
589+
}
550590
if (!desc.HasAttr("axis")) {
551591
return false;
552-
} else {
553-
int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
554-
if (axis == 0) {
555-
VLOG(3) << "Invalid split axis. Split on batch is not supported in "
556-
"TensorRT";
592+
}
593+
int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
594+
595+
if (axis == 0) {
596+
VLOG(3) << "Invalid split axis. Split on batch is not supported in "
597+
"TensorRT";
598+
return false;
599+
}
600+
auto* block = desc.Block();
601+
auto x_var_name = desc.Input("X")[0];
602+
auto* x_var_desc = block->FindVar(x_var_name);
603+
const auto x_shape = x_var_desc->GetShape();
604+
size_t output_num = desc.Output("Out").size();
605+
std::vector<int> output_lengths;
606+
int num = 0;
607+
if (desc.HasAttr("num")) {
608+
num = BOOST_GET_CONST(int, desc.GetAttr("num"));
609+
}
610+
if (desc.HasAttr("sections")) {
611+
output_lengths =
612+
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("sections"));
613+
}
614+
if (output_lengths.size() == 0 && num == 0) {
615+
VLOG(3) << "sections and num cannot be equal to 0 at the same time";
616+
return false;
617+
}
618+
if (with_dynamic_shape) {
619+
#if IS_TRT_VERSION_GE(6000)
620+
#else
621+
VLOG(3) << "You are running the TRT Dynamic Shape mode, need to "
622+
"confirm that "
623+
"your TRT version is no less than 6.0";
624+
return false;
625+
#endif
626+
}
627+
axis += (axis < 0) ? x_shape.size() : 0;
628+
if (x_shape[axis] == -1) {
629+
VLOG(3) << "The (" << axis << ") dim of input should not be -1";
630+
return false;
631+
}
632+
if (output_lengths.size() == 0) {
633+
if (num > 0) {
634+
int64_t in_axis_dim = x_shape[axis];
635+
if (in_axis_dim % num != 0) {
636+
VLOG(3) << "Invalid number to split. Tensor split does not result"
637+
" in an equal division of dimensions. Axis dim = "
638+
<< in_axis_dim << " num = " << num << "!= 0";
639+
return false;
640+
}
641+
size_t out_axis_dim = in_axis_dim / num;
642+
for (int i = 0; i < num; ++i) {
643+
output_lengths.push_back(out_axis_dim);
644+
}
645+
}
646+
}
647+
if (output_lengths.size() != output_num) {
648+
VLOG(3) << "The output_length should be equal to the output size.";
649+
return false;
650+
}
651+
}
652+
if (op_type == "scale") {
653+
auto scale_inputs = desc.Inputs();
654+
if (scale_inputs.find("ScaleTensor") != scale_inputs.end()) {
655+
if (desc.Input("ScaleTensor").size() >= 1) {
557656
return false;
558657
}
559658
}
659+
auto* block = desc.Block();
660+
auto x_var_name = desc.Input("X")[0];
661+
auto* x_var_desc = block->FindVar(x_var_name);
662+
const auto x_shape = x_var_desc->GetShape();
663+
if (!with_dynamic_shape && x_shape.size() == 1) return false;
560664
}
561-
562665
if (op_type == "slice") {
563666
if (!desc.HasAttr("axes") || !desc.HasAttr("starts") ||
564-
!desc.HasAttr("ends")) {
667+
!desc.HasAttr("ends") || !desc.HasAttr("decrease_axis")) {
565668
return false;
566669
} else {
567670
std::vector<int> axes =
@@ -570,9 +673,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
570673
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("starts"));
571674
std::vector<int> ends =
572675
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("ends"));
676+
std::vector<int> decrease_axis =
677+
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
573678
if (axes.size() != starts.size() || axes.size() != ends.size()) {
574679
return false;
575680
}
681+
if (decrease_axis.size() > 0) {
682+
VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
683+
"is not supported in TensorRT";
684+
return false;
685+
}
576686
if (!with_dynamic_shape) {
577687
for (size_t i = 0; i < axes.size(); i++) {
578688
if (axes[i] == 0) {

0 commit comments

Comments
 (0)