1313// limitations under the License.
1414
1515#include " paddle/fluid/inference/tensorrt/op_teller.h"
16-
16+ # include < bitset >
1717#include " paddle/fluid/framework/block_desc.h"
1818#include " paddle/fluid/framework/data_layout.h"
1919
@@ -316,11 +316,36 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
316316 if (op_type == " transpose2" || op_type == " transpose" ) {
317317 if (!desc.HasAttr (" axis" )) {
318318 return false ;
319- } else {
320- std::vector<int > axis =
321- BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" axis" ));
322- if (!with_dynamic_shape && axis[0 ] != 0 ) return false ;
323- if (axis.size () >= nvinfer1::Dims::MAX_DIMS) return false ;
319+ }
320+ std::vector<int > axis =
321+ BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" axis" ));
322+ if (!with_dynamic_shape && axis[0 ] != 0 ) return false ;
323+ if (axis.size () >= nvinfer1::Dims::MAX_DIMS) return false ;
324+ if (axis[0 ] == 0 && axis.size () == 2 ) return false ;
325+
326+ auto * block = desc.Block ();
327+ auto x_var_name = desc.Input (" X" )[0 ];
328+ auto * x_var_desc = block->FindVar (x_var_name);
329+ const auto x_shape = x_var_desc->GetShape ();
330+ int dims = x_shape.size ();
331+ std::vector<int > perm (nvinfer1::Dims::MAX_DIMS);
332+ for (int i = 0 ; i < dims; i++) {
333+ perm[i] = axis[i];
334+ }
335+ auto is_valid_permutation = [&](int dims,
336+ const std::vector<int >& permutation) {
337+ std::bitset<nvinfer1::Dims::MAX_DIMS> found;
338+ for (int i = 0 ; i < dims; ++i) {
339+ const int x = permutation[i];
340+ if ((x < 0 ) || (x >= dims) || found[x])
341+ return false ; // Out of bounds or duplicate
342+ found.set (x);
343+ }
344+ return true ;
345+ };
346+ if (!is_valid_permutation (dims, perm)) {
347+ VLOG (3 ) << " Invalid permutation dimensions for trt transpose op "
348+ " converter: duplicate or out of bound." ;
324349 }
325350 }
326351 if (op_type == " flatten2" || op_type == " flatten" ) {
@@ -475,6 +500,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
475500 return false ;
476501 }
477502 }
503+ if ((scale <= 0 .f ) && with_dynamic_shape) {
504+ VLOG (3 ) << " dynamic shape not support scale not set." ;
505+ return false ;
506+ }
478507 }
479508 }
480509
@@ -547,21 +576,95 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
547576 << desc.Input (" X" ).size () << " ." ;
548577 return false ;
549578 }
579+ auto split_inputs = desc.Inputs ();
580+ if (split_inputs.find (" AxisTensor" ) != split_inputs.end ()) {
581+ if (desc.Input (" AxisTensor" ).size () >= 1 ) {
582+ return false ;
583+ }
584+ }
585+ if (split_inputs.find (" SectionsTensorList" ) != split_inputs.end ()) {
586+ if (desc.Input (" SectionsTensorList" ).size () >= 1 ) {
587+ return false ;
588+ }
589+ }
550590 if (!desc.HasAttr (" axis" )) {
551591 return false ;
552- } else {
553- int axis = BOOST_GET_CONST (int , desc.GetAttr (" axis" ));
554- if (axis == 0 ) {
555- VLOG (3 ) << " Invalid split axis. Split on batch is not supported in "
556- " TensorRT" ;
592+ }
593+ int axis = BOOST_GET_CONST (int , desc.GetAttr (" axis" ));
594+
595+ if (axis == 0 ) {
596+ VLOG (3 ) << " Invalid split axis. Split on batch is not supported in "
597+ " TensorRT" ;
598+ return false ;
599+ }
600+ auto * block = desc.Block ();
601+ auto x_var_name = desc.Input (" X" )[0 ];
602+ auto * x_var_desc = block->FindVar (x_var_name);
603+ const auto x_shape = x_var_desc->GetShape ();
604+ size_t output_num = desc.Output (" Out" ).size ();
605+ std::vector<int > output_lengths;
606+ int num = 0 ;
607+ if (desc.HasAttr (" num" )) {
608+ num = BOOST_GET_CONST (int , desc.GetAttr (" num" ));
609+ }
610+ if (desc.HasAttr (" sections" )) {
611+ output_lengths =
612+ BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" sections" ));
613+ }
614+ if (output_lengths.size () == 0 && num == 0 ) {
615+ VLOG (3 ) << " sections and num cannot be equal to 0 at the same time" ;
616+ return false ;
617+ }
618+ if (with_dynamic_shape) {
619+ #if IS_TRT_VERSION_GE(6000)
620+ #else
621+ VLOG (3 ) << " You are running the TRT Dynamic Shape mode, need to "
622+ " confirm that "
623+ " your TRT version is no less than 6.0" ;
624+ return false ;
625+ #endif
626+ }
627+ axis += (axis < 0 ) ? x_shape.size () : 0 ;
628+ if (x_shape[axis] == -1 ) {
629+ VLOG (3 ) << " The (" << axis << " ) dim of input should not be -1" ;
630+ return false ;
631+ }
632+ if (output_lengths.size () == 0 ) {
633+ if (num > 0 ) {
634+ int64_t in_axis_dim = x_shape[axis];
635+ if (in_axis_dim % num != 0 ) {
636+ VLOG (3 ) << " Invalid number to split. Tensor split does not result"
637+ " in an equal division of dimensions. Axis dim = "
638+ << in_axis_dim << " num = " << num << " != 0" ;
639+ return false ;
640+ }
641+ size_t out_axis_dim = in_axis_dim / num;
642+ for (int i = 0 ; i < num; ++i) {
643+ output_lengths.push_back (out_axis_dim);
644+ }
645+ }
646+ }
647+ if (output_lengths.size () != output_num) {
648+ VLOG (3 ) << " The output_length should be equal to the output size." ;
649+ return false ;
650+ }
651+ }
652+ if (op_type == " scale" ) {
653+ auto scale_inputs = desc.Inputs ();
654+ if (scale_inputs.find (" ScaleTensor" ) != scale_inputs.end ()) {
655+ if (desc.Input (" ScaleTensor" ).size () >= 1 ) {
557656 return false ;
558657 }
559658 }
659+ auto * block = desc.Block ();
660+ auto x_var_name = desc.Input (" X" )[0 ];
661+ auto * x_var_desc = block->FindVar (x_var_name);
662+ const auto x_shape = x_var_desc->GetShape ();
663+ if (!with_dynamic_shape && x_shape.size () == 1 ) return false ;
560664 }
561-
562665 if (op_type == " slice" ) {
563666 if (!desc.HasAttr (" axes" ) || !desc.HasAttr (" starts" ) ||
564- !desc.HasAttr (" ends" )) {
667+ !desc.HasAttr (" ends" ) || !desc. HasAttr ( " decrease_axis " ) ) {
565668 return false ;
566669 } else {
567670 std::vector<int > axes =
@@ -570,9 +673,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
570673 BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" starts" ));
571674 std::vector<int > ends =
572675 BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" ends" ));
676+ std::vector<int > decrease_axis =
677+ BOOST_GET_CONST (std::vector<int >, desc.GetAttr (" decrease_axis" ));
573678 if (axes.size () != starts.size () || axes.size () != ends.size ()) {
574679 return false ;
575680 }
681+ if (decrease_axis.size () > 0 ) {
682+ VLOG (3 ) << " Invalid slice decrease_axis. decrease_axis.size() > 0"
683+ " is not supported in TensorRT" ;
684+ return false ;
685+ }
576686 if (!with_dynamic_shape) {
577687 for (size_t i = 0 ; i < axes.size (); i++) {
578688 if (axes[i] == 0 ) {
0 commit comments