diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc index 81426c328bdfc..b62575f71abd4 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc @@ -54,6 +54,11 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { std::vector, int>> all_selected_indices; int total_num_saved_outputs = 0; + // safe downcast max_output_boxes_per_class to int as cub::DeviceSelect::Flagged() does not support int64_t + int int_max_output_boxes_per_class = max_output_boxes_per_class > std::numeric_limits::max() + ? std::numeric_limits::max() + : static_cast(max_output_boxes_per_class); + for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { IAllocatorUniquePtr d_selected_indices{}; @@ -66,7 +71,7 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { GetCenterPointBox(), batch_index, class_index, - max_output_boxes_per_class, + int_max_output_boxes_per_class, iou_threshold, score_threshold, d_selected_indices, @@ -130,4 +135,4 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { } } // namespace cuda -}; // namespace onnxruntime +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu index 8d04fc2d4d657..8364753b2c63b 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu @@ -55,13 +55,6 @@ constexpr int kNmsBlockDim = 16; constexpr int kNmsBlockDimMax = 128; constexpr int kNmsChunkSize = 2000; -template -__device__ inline void Swap(T& a, T& b) { - T c(a); - a = b; - b = c; -} - // Check whether two boxes have an IoU greater than threshold. template __device__ inline bool OverThreshold(const Box* a, const Box* b, @@ -88,10 +81,6 @@ __device__ inline bool OverThreshold(const Box* a, const Box* b, return aa >= bt; } -__device__ inline void Flipped(Box& box) { - if (box.x1 > box.x2) Swap(box.x1, box.x2); - if (box.y1 > box.y2) Swap(box.y1, box.y2); -} template __device__ inline bool CheckBit(T* bit_mask, int bit) { constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1; @@ -104,7 +93,7 @@ __device__ inline bool CheckBit(T* bit_mask, int bit) { // generated by NMSKernel Abort early if max_boxes boxes are selected. Bitmask // is num_boxes*bit_mask_len bits indicating whether to keep or remove a box. __global__ void NMSReduce(const int* bitmask, const int bit_mask_len, - const int num_boxes, const int64_t max_boxes, + const int num_boxes, const int max_boxes, char* result_mask) { extern __shared__ int local[]; @@ -247,7 +236,7 @@ Status NmsGpu(std::function(size_t)> allocator, const float iou_threshold, int* d_selected_indices, int* h_nkeep, - const int64_t max_boxes) { + const int max_boxes) { // Making sure we respect the __align(16)__ // we promised to the compiler. auto iptr = reinterpret_cast(d_sorted_boxes_float_ptr); @@ -337,7 +326,7 @@ Status NonMaxSuppressionImpl( const int64_t center_point_box, int64_t batch_index, int64_t class_index, - int64_t max_output_boxes_per_class, + int max_output_boxes_per_class, float iou_threshold, float score_threshold, IAllocatorUniquePtr& selected_indices, @@ -427,7 +416,7 @@ Status NonMaxSuppressionImpl( CUDA_RETURN_IF_ERROR(cudaGetLastError()); // STEP 4. map back to sorted indices - *h_number_selected = std::min(*h_number_selected, (int)max_output_boxes_per_class); + *h_number_selected = std::min(*h_number_selected, max_output_boxes_per_class); int num_to_keep = *h_number_selected; if (num_to_keep > 0) { IAllocatorUniquePtr d_output_indices_ptr{allocator(num_to_keep * sizeof(int))}; diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.h b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.h index c10c508377ca6..493c115e52c50 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.h +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.h @@ -19,7 +19,7 @@ Status NonMaxSuppressionImpl( const int64_t center_point_box, int64_t batch_index, int64_t class_index, - int64_t max_output_boxes_per_class, + int max_output_boxes_per_class, float iou_threshold, float score_threshold, IAllocatorUniquePtr& selected_indices, diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 1a796d9da0929..4b8475e3f6a49 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -344,6 +344,25 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) { test.Run(); } +TEST(NonMaxSuppressionOpTest, BigIntMaxOutputBoxesPerClass) { + OpTester test("NonMaxSuppression", 10, kOnnxDomain); + test.AddInput("boxes", {1, 6, 4}, + {0.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.1f, 1.0f, 1.1f, + 0.0f, -0.1f, 1.0f, 0.9f, + 0.0f, 10.0f, 1.0f, 11.0f, + 0.0f, 10.1f, 1.0f, 11.1f, + 0.0f, 100.0f, 1.0f, 101.0f}); + test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); + test.AddInput("max_output_boxes_per_class", {}, {9223372036854775807L}); + test.AddInput("iou_threshold", {}, {0.5f}); + test.AddInput("score_threshold", {}, {0.4f}); + test.AddOutput("selected_indices", {2, 3}, + {0L, 0L, 3L, + 0L, 0L, 0L}); + test.Run(); +} + TEST(NonMaxSuppressionOpTest, WithIOUThresholdOpset11) { OpTester test("NonMaxSuppression", 11, kOnnxDomain); test.AddInput("boxes", {1, 6, 4},