Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/operators/top_k_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel<T> {
indices->mutable_data<int64_t>(ctx.GetPlace());

// prepare assit
auto dim = input->dims().size();
auto size = input->dims().size();
// dim is the last dimension of input
auto dim = input->dims()[size - 1];
framework::Tensor assist_seq_tensor;
assist_seq_tensor.Resize({2 * dim});
assist_seq_tensor.mutable_data<T>(ctx.GetPlace());
Expand Down
36 changes: 36 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from test_top_k_v2_op_npu import numpy_topk

paddle.enable_static()
SEED = 2021
Expand Down Expand Up @@ -87,5 +88,40 @@ def test_check_output(self):
self.check_output_with_place(self.place)


class TestTopkV3(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "top_k"

self.init_dtype()
self.set_input_data()
self.set_attrs()
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=True)

self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
self.outputs = {'Out': output, 'Indices': indices}

def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True

def init_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output_with_place(self.place)

def set_attrs(self):
self.k = 3
self.axis = 1

def set_input_data(self):
self.input_data = np.random.choice(
10000, size=(10, 20), replace=False).astype(self.dtype)


if __name__ == '__main__':
unittest.main()