File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -42,20 +42,18 @@ void argmax_func(const lite::Tensor *input,
4242 for (int n = 0 ; n < out_stride; n++) {
4343 for (int k = 0 ; k < in_stride; k++) {
4444 const InType *in_ptr = input->data <InType>() + n * in_channel + k;
45- std::vector<std::pair<InType, OutType>> vec;
46- vec.resize (size);
47- for (int i = 0 ; i < size; i++) {
48- vec[i] = std::make_pair (in_ptr[i * in_stride], i);
45+ std::pair<InType, OutType> max_pair;
46+ max_pair.first = in_ptr[0 ];
47+ max_pair.second = 0 ;
48+ for (int i = 1 ; i < size; i++) {
49+ if (in_ptr[i * in_stride] > max_pair.first ) {
50+ max_pair.first = in_ptr[i * in_stride];
51+ max_pair.second = i;
52+ }
4953 }
50- // sort
51- std::partial_sort (vec.begin (),
52- vec.begin () + 1 ,
53- vec.end (),
54- std::greater<std::pair<InType, OutType>>());
55-
5654 // out
5755 OutType *out_ptr = output->mutable_data <OutType>() + n * out_channel + k;
58- *out_ptr = vec[ 0 ] .second ;
56+ *out_ptr = max_pair .second ;
5957 }
6058 }
6159}
You can’t perform that action at this time.
0 commit comments