Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2035,8 +2035,8 @@ void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) {
"whose shape must be [N, 4] "
"N is the number of boxes "
"in last dimension in format [x1, x2, y1, y2]. "));
auto num_boxes = boxes_dim[0];
out->set_dims(phi::make_ddim({num_boxes}));
out->set_dims(phi::make_ddim({-1}));
out->set_dtype(DataType::INT64);
}

void NormInferMeta(const MetaTensor& x,
Expand Down
24 changes: 18 additions & 6 deletions paddle/phi/kernels/cpu/nms_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
#include "paddle/phi/backends/cpu/cpu_context.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/diagonal.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T>
static void NMS(const T* boxes_data,
int64_t* output_data,
float threshold,
int64_t num_boxes) {
static int64_t NMS(const T* boxes_data,
int64_t* output_data,
float threshold,
int64_t num_boxes) {
auto num_masks = CeilDivide(num_boxes, 64);
std::vector<uint64_t> masks(num_masks, 0);

Expand Down Expand Up @@ -54,18 +55,29 @@ static void NMS(const T* boxes_data,
output_data[output_data_idx++] = i;
}

int64_t num_keep_boxes = output_data_idx;

for (; output_data_idx < num_boxes; ++output_data_idx) {
output_data[output_data_idx] = 0;
}

return num_keep_boxes;
}

template <typename T, typename Context>
void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
auto output_data = dev_ctx.template Alloc<int64_t>(output);
NMS<T>(boxes.data<T>(), output_data, threshold, boxes.dims()[0]);
int64_t num_boxes = boxes.dims()[0];
DenseTensor output_tmp;
output_tmp.Resize(phi::make_ddim({num_boxes}));
auto output_tmp_data = dev_ctx.template Alloc<int64_t>(&output_tmp);

int64_t num_keep_boxes =
NMS<T>(boxes.data<T>(), output_tmp_data, threshold, num_boxes);
auto slice_out = output_tmp.Slice(0, num_keep_boxes);
phi::Copy(dev_ctx, slice_out, dev_ctx.GetPlace(), false, output);
}

} // namespace phi
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/gpu/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
const int64_t num_boxes = boxes.dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
dim3 block(threadsPerBlock);
Expand Down Expand Up @@ -93,11 +92,13 @@ void NMSKernel(const Context& dev_ctx,
}
}
}
output->Resize(phi::make_ddim({last_box_num}));
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
paddle::memory::Copy(dev_ctx.GetPlace(),
output_data,
phi::CPUPlace(),
output_host,
sizeof(int64_t) * num_boxes,
sizeof(int64_t) * last_box_num,
dev_ctx.stream());
}
} // namespace phi
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_nms_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def nms(boxes, nms_threshold):
else:
continue

return selected_indices
return selected_indices[:cnt]


class TestNMSOp(OpTest):
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,9 @@ def _nms(boxes, iou_threshold):
import paddle
if category_idxs is None:
sorted_global_indices = paddle.argsort(scores, descending=True)
return _nms(boxes[sorted_global_indices], iou_threshold)
sorted_keep_boxes_indices = _nms(boxes[sorted_global_indices],
iou_threshold)
return sorted_global_indices[sorted_keep_boxes_indices]

if top_k is not None:
assert top_k <= scores.shape[
Expand Down