Skip to content

Commit 125caa1

Browse files
authored
opt argmax op,test=develop (#9384)
1 parent 4793daf commit 125caa1

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

lite/backends/arm/math/argmax.cc

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,18 @@ void argmax_func(const lite::Tensor *input,
4242
for (int n = 0; n < out_stride; n++) {
4343
for (int k = 0; k < in_stride; k++) {
4444
const InType *in_ptr = input->data<InType>() + n * in_channel + k;
45-
std::vector<std::pair<InType, OutType>> vec;
46-
vec.resize(size);
47-
for (int i = 0; i < size; i++) {
48-
vec[i] = std::make_pair(in_ptr[i * in_stride], i);
45+
std::pair<InType, OutType> max_pair;
46+
max_pair.first = in_ptr[0];
47+
max_pair.second = 0;
48+
for (int i = 1; i < size; i++) {
49+
if (in_ptr[i * in_stride] > max_pair.first) {
50+
max_pair.first = in_ptr[i * in_stride];
51+
max_pair.second = i;
52+
}
4953
}
50-
// sort
51-
std::partial_sort(vec.begin(),
52-
vec.begin() + 1,
53-
vec.end(),
54-
std::greater<std::pair<InType, OutType>>());
55-
5654
// out
5755
OutType *out_ptr = output->mutable_data<OutType>() + n * out_channel + k;
58-
*out_ptr = vec[0].second;
56+
*out_ptr = max_pair.second;
5957
}
6058
}
6159
}

0 commit comments

Comments
 (0)