Skip to content

Commit 29cdeb6

Browse files
committed
[Host] add beam_search_decode; test=develop
1 parent b9117a6 commit 29cdeb6

5 files changed

Lines changed: 308 additions & 280 deletions

File tree

lite/kernels/arm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ add_kernel(sum_compute ARM extra SRCS sum_compute.cc DEPS ${lite_kernel_deps} ma
9292
# for OCR specific
9393
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
9494
add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
95-
add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm)
95+
add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} beam_search_decode_compute_host)
9696
add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm)
9797
add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm)
9898
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)

lite/kernels/arm/beam_search_decode_compute.cc

Lines changed: 2 additions & 275 deletions
Original file line numberDiff line numberDiff line change
@@ -12,286 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "lite/kernels/arm/beam_search_decode_compute.h"
16-
#include <algorithm>
17-
#include <vector>
18-
#include "lite/api/paddle_place.h"
19-
#include "lite/backends/arm/math/funcs.h"
20-
#include "lite/core/op_registry.h"
21-
#include "lite/core/tensor.h"
22-
#include "lite/core/type_system.h"
23-
24-
namespace paddle {
25-
namespace lite {
26-
namespace kernels {
27-
namespace arm {
28-
29-
using LoDTensor = lite::Tensor;
30-
using LoDTensorArray = std::vector<lite::Tensor>;
31-
32-
// all the lod have 2 levels.
33-
// The first is source level, the second is sentence level.
34-
// source level describes how many prefixes (branches) for each source sentence
35-
// (beam). sentence level describes how these candidates belong to the prefixes.
36-
const size_t kSourceLevel = 0;
37-
const size_t kSentenceLevel = 1;
38-
39-
template <typename T>
40-
struct Sentence {
41-
std::vector<int64_t> word_ids;
42-
std::vector<T> scores;
43-
};
44-
45-
template <typename T>
46-
using SentenceVector = std::vector<Sentence<T>>;
47-
48-
template <typename T>
49-
struct BeamSearchDecoder {
50-
BeamSearchDecoder(size_t beam_size, int end_id)
51-
: beam_size_(beam_size), end_id_(end_id) {}
52-
53-
/**
54-
* convert the result sentence_vector for each source sentence into two
55-
* LodTensor.
56-
* One is all candidate sentences with word id, one is all candidate sentences
57-
* with word score.
58-
* Param:
59-
* sentence_vector_list: sentence_vector for each source sentence.
60-
* id_tensor: result LoDTensor for sentences of id.
61-
* score_tensor: result LoDTensor for sentences of score.
62-
* reverse: whether ids of sentence in sentence_vector_list is reversed
63-
* sort_by_score: whether to sort hypotheses of each sentence by scores.
64-
*/
65-
void ConvertSentenceVectorToLodTensor(
66-
std::vector<SentenceVector<T>> sentence_vector_list,
67-
LoDTensor* id_tensor,
68-
LoDTensor* score_tensor,
69-
bool reverse = true,
70-
bool sort_by_score = true) const {
71-
size_t src_num = sentence_vector_list.size();
72-
CHECK_GT(src_num, 0) << "src_num should not be 0";
73-
74-
std::vector<uint64_t> source_level_lod = {0};
75-
std::vector<uint64_t> sentence_level_lod = {0};
76-
std::vector<int64_t> id_data;
77-
std::vector<T> score_data;
78-
79-
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
80-
if (sort_by_score) {
81-
std::stable_sort(sentence_vector_list[src_idx].begin(),
82-
sentence_vector_list[src_idx].end(),
83-
[reverse](const Sentence<T>& a, const Sentence<T>& b) {
84-
if (reverse)
85-
return a.scores.front() > b.scores.front();
86-
else
87-
return a.scores.back() > b.scores.back();
88-
});
89-
}
90-
for (Sentence<T>& sentence : sentence_vector_list[src_idx]) {
91-
if (reverse) {
92-
id_data.insert(id_data.end(),
93-
sentence.word_ids.rbegin(),
94-
sentence.word_ids.rend());
95-
score_data.insert(score_data.end(),
96-
sentence.scores.rbegin(),
97-
sentence.scores.rend());
98-
} else {
99-
id_data.insert(id_data.end(),
100-
sentence.word_ids.begin(),
101-
sentence.word_ids.end());
102-
score_data.insert(
103-
score_data.end(), sentence.scores.begin(), sentence.scores.end());
104-
}
105-
106-
sentence_level_lod.push_back(sentence_level_lod.back() +
107-
sentence.word_ids.size());
108-
}
109-
source_level_lod.push_back(source_level_lod.back() +
110-
sentence_vector_list[src_idx].size());
111-
}
112-
113-
LoD lod;
114-
lod.push_back(source_level_lod);
115-
lod.push_back(sentence_level_lod);
116-
117-
id_tensor->set_lod(lod);
118-
119-
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
120-
auto id_ptr = id_tensor->mutable_data<int64_t>();
121-
TargetCopy(
122-
TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t));
123-
124-
score_tensor->set_lod(lod);
125-
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
126-
auto score_ptr = score_tensor->mutable_data<T>();
127-
TargetCopy(TARGET(kARM),
128-
score_ptr,
129-
score_data.data(),
130-
score_data.size() * sizeof(T));
131-
}
132-
133-
/**
134-
* Gather the hypotheses for each source sentence by backtracking through the
135-
* LoDTensorArray step_ids whose lods reserve the path in the tree.
136-
*/
137-
void Backtrace(const LoDTensorArray& step_ids,
138-
const LoDTensorArray& step_scores,
139-
LoDTensor* id_tensor,
140-
LoDTensor* score_tensor) const {
141-
CHECK(!step_ids.empty()) << "step num should be larger than 0";
142-
CHECK_EQ(step_ids.size(), step_scores.size())
143-
<< "step_ids and step_scores should be the same";
144-
const size_t step_num = step_ids.size();
145-
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
146-
std::vector<SentenceVector<T>> sentence_vector_list(
147-
src_num, SentenceVector<T>(beam_size_));
148-
std::vector<std::vector<size_t>> prefix_idx_vector_list(src_num);
149-
for (int step_id = step_num - 1; step_id >= 0; --step_id) {
150-
auto& cur_ids = step_ids.at(step_id);
151-
auto& cur_scores = step_scores.at(step_id);
152-
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
153-
// for each source sentence
154-
auto& sentence_vector = sentence_vector_list.at(src_idx);
155-
auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx);
156-
size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx];
157-
size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
158-
if (prefix_idx_vector.empty()) { // be finished and pruned at this step
159-
// or the last time step
160-
for (size_t prefix_idx = src_prefix_start;
161-
prefix_idx < src_prefix_end;
162-
++prefix_idx) {
163-
size_t candidate_start =
164-
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
165-
size_t candidate_end =
166-
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1];
167-
for (size_t candidate_idx = candidate_start;
168-
candidate_idx < candidate_end;
169-
++candidate_idx) {
170-
prefix_idx_vector.push_back(prefix_idx);
171-
size_t idx = prefix_idx_vector.size() - 1;
172-
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
173-
auto cur_score = cur_scores.data<T>()[candidate_idx];
174-
sentence_vector.at(idx).word_ids.push_back(cur_id);
175-
sentence_vector.at(idx).scores.push_back(cur_score);
176-
}
177-
}
178-
} else { // use prefix_idx_vector to backtrace
179-
size_t src_candidate_start =
180-
cur_ids.lod().at(kSentenceLevel)[src_prefix_start];
181-
size_t prefix_idx = src_prefix_start;
182-
size_t candidate_num =
183-
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
184-
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
185-
for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) {
186-
auto candidate_idx = prefix_idx_vector.at(idx);
187-
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
188-
auto cur_score = cur_scores.data<T>()[candidate_idx];
189-
if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) {
190-
// to skip redundant end tokens
191-
sentence_vector.at(idx).word_ids.push_back(cur_id);
192-
sentence_vector.at(idx).scores.push_back(cur_score);
193-
}
194-
195-
while (src_candidate_start + candidate_num <=
196-
candidate_idx) { // search the corresponding prefix
197-
prefix_idx++;
198-
candidate_num +=
199-
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
200-
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
201-
}
202-
prefix_idx_vector.at(idx) = prefix_idx;
203-
}
204-
}
205-
}
206-
}
207-
208-
ConvertSentenceVectorToLodTensor(
209-
sentence_vector_list, id_tensor, score_tensor, true, true);
210-
}
211-
212-
size_t beam_size_;
213-
int end_id_;
214-
};
215-
216-
struct BeamSearchDecodeFunctor {
217-
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
218-
const LoDTensorArray& step_scores,
219-
LoDTensor* id_tensor,
220-
LoDTensor* score_tensor,
221-
size_t beam_size,
222-
int end_id)
223-
: beam_size_(beam_size),
224-
end_id_(end_id),
225-
step_ids_(step_ids),
226-
step_scores_(step_scores),
227-
id_tensor_(id_tensor),
228-
score_tensor_(score_tensor) {}
229-
230-
template <typename T>
231-
void apply() const {
232-
BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
233-
beam_search_decoder.Backtrace(
234-
step_ids_, step_scores_, id_tensor_, score_tensor_);
235-
}
236-
237-
size_t beam_size_;
238-
int end_id_;
239-
const LoDTensorArray& step_ids_;
240-
const LoDTensorArray& step_scores_;
241-
LoDTensor* id_tensor_;
242-
LoDTensor* score_tensor_;
243-
};
244-
245-
template <>
246-
void BeamSearchDecodeFunctor::apply<bool>() const {
247-
LOG(FATAL) << "beam search decode op does not support bool!";
248-
}
249-
250-
void BeamSearchDecodeCompute::Run() {
251-
auto& param = this->Param<param_t>();
252-
auto& ctx = this->ctx_->template As<ARMContext>();
253-
// inputs
254-
auto ids = param.ids;
255-
auto scores = param.scores;
256-
// outputs
257-
auto sentence_ids = param.sentence_ids;
258-
auto sentence_scores = param.sentence_scores;
259-
260-
const size_t step_num = ids->size();
261-
CHECK_GT(step_num, 0UL) << "beam search steps should be larger than 0";
262-
const size_t source_num = ids->at(0).lod().at(0).size() - 1;
263-
CHECK_GT(source_num, 0UL) << "source num should be larger than 0";
264-
265-
for (size_t i = 0; i < step_num; ++i) {
266-
CHECK_EQ(ids->at(i).lod().size(), 2UL) << "Level of LodTensor should be 2";
267-
}
268-
269-
//! fixme
270-
// only support float score now
271-
BeamSearchDecodeFunctor func(*ids,
272-
*scores,
273-
sentence_ids,
274-
sentence_scores,
275-
param.beam_size,
276-
param.end_id);
277-
278-
func.apply<float>();
279-
280-
// when decode finish, we clear ids and scores
281-
param.ids->clear();
282-
param.scores->clear();
283-
}
284-
285-
} // namespace arm
286-
} // namespace kernels
287-
} // namespace lite
288-
} // namespace paddle
15+
#include "lite/kernels/host/beam_search_decode_compute.h"
28916

29017
REGISTER_LITE_KERNEL(beam_search_decode,
29118
kARM,
29219
kFloat,
29320
kNCHW,
294-
paddle::lite::kernels::arm::BeamSearchDecodeCompute,
21+
paddle::lite::kernels::host::BeamSearchDecodeCompute,
29522
def)
29623
.BindInput("Ids",
29724
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt64))})

lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ add_kernel(tril_triu_compute_host Host extra SRCS tril_triu_compute.cc DEPS ${li
6868
add_kernel(topk_v2_compute_host Host extra SRCS topk_v2_compute.cc DEPS ${lite_kernel_deps})
6969
add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
7070
add_kernel(linspace_compute_host Host extra SRCS linspace_compute.cc DEPS ${lite_kernel_deps})
71+
add_kernel(beam_search_decode_compute_host Host extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps})
7172

7273
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
7374
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)

0 commit comments

Comments
 (0)