|
12 | 12 | // See the License for the specific language governing permissions and |
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | | -#include "lite/kernels/arm/beam_search_decode_compute.h" |
16 | | -#include <algorithm> |
17 | | -#include <vector> |
18 | | -#include "lite/api/paddle_place.h" |
19 | | -#include "lite/backends/arm/math/funcs.h" |
20 | | -#include "lite/core/op_registry.h" |
21 | | -#include "lite/core/tensor.h" |
22 | | -#include "lite/core/type_system.h" |
23 | | - |
24 | | -namespace paddle { |
25 | | -namespace lite { |
26 | | -namespace kernels { |
27 | | -namespace arm { |
28 | | - |
29 | | -using LoDTensor = lite::Tensor; |
30 | | -using LoDTensorArray = std::vector<lite::Tensor>; |
31 | | - |
32 | | -// all the lod have 2 levels. |
33 | | -// The first is source level, the second is sentence level. |
34 | | -// source level describe how many prefixes (branchs) for each source sentece |
35 | | -// (beam). sentence level describe how these candidates belong to the prefixes. |
36 | | -const size_t kSourceLevel = 0; |
37 | | -const size_t kSentenceLevel = 1; |
38 | | - |
39 | | -template <typename T> |
40 | | -struct Sentence { |
41 | | - std::vector<int64_t> word_ids; |
42 | | - std::vector<T> scores; |
43 | | -}; |
44 | | - |
45 | | -template <typename T> |
46 | | -using SentenceVector = std::vector<Sentence<T>>; |
47 | | - |
48 | | -template <typename T> |
49 | | -struct BeamSearchDecoder { |
50 | | - BeamSearchDecoder(size_t beam_size, int end_id) |
51 | | - : beam_size_(beam_size), end_id_(end_id) {} |
52 | | - |
53 | | - /** |
54 | | - * convert the result sentence_vector for each source sentence into two |
55 | | - * LodTensor. |
56 | | - * One is all candidate sentences with word id, one is all candidate sentences |
57 | | - * with word score. |
58 | | - * Param: |
59 | | - * sentence_vector_list: sentence_vector for each source sentence. |
60 | | - * id_tensor: result LoDTensor for sentences of id. |
61 | | - * score_tensor: result LoDTensor for sentences of score. |
62 | | - * reverse: whether ids of sentence in sentence_vector_list is reversed |
63 | | - * sort_by_score: whether to sort hypotheses of each sentence by scores. |
64 | | - */ |
65 | | - void ConvertSentenceVectorToLodTensor( |
66 | | - std::vector<SentenceVector<T>> sentence_vector_list, |
67 | | - LoDTensor* id_tensor, |
68 | | - LoDTensor* score_tensor, |
69 | | - bool reverse = true, |
70 | | - bool sort_by_score = true) const { |
71 | | - size_t src_num = sentence_vector_list.size(); |
72 | | - CHECK_GT(src_num, 0) << "src_num should not be 0"; |
73 | | - |
74 | | - std::vector<uint64_t> source_level_lod = {0}; |
75 | | - std::vector<uint64_t> sentence_level_lod = {0}; |
76 | | - std::vector<int64_t> id_data; |
77 | | - std::vector<T> score_data; |
78 | | - |
79 | | - for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { |
80 | | - if (sort_by_score) { |
81 | | - std::stable_sort(sentence_vector_list[src_idx].begin(), |
82 | | - sentence_vector_list[src_idx].end(), |
83 | | - [reverse](const Sentence<T>& a, const Sentence<T>& b) { |
84 | | - if (reverse) |
85 | | - return a.scores.front() > b.scores.front(); |
86 | | - else |
87 | | - return a.scores.back() > b.scores.back(); |
88 | | - }); |
89 | | - } |
90 | | - for (Sentence<T>& sentence : sentence_vector_list[src_idx]) { |
91 | | - if (reverse) { |
92 | | - id_data.insert(id_data.end(), |
93 | | - sentence.word_ids.rbegin(), |
94 | | - sentence.word_ids.rend()); |
95 | | - score_data.insert(score_data.end(), |
96 | | - sentence.scores.rbegin(), |
97 | | - sentence.scores.rend()); |
98 | | - } else { |
99 | | - id_data.insert(id_data.end(), |
100 | | - sentence.word_ids.begin(), |
101 | | - sentence.word_ids.end()); |
102 | | - score_data.insert( |
103 | | - score_data.end(), sentence.scores.begin(), sentence.scores.end()); |
104 | | - } |
105 | | - |
106 | | - sentence_level_lod.push_back(sentence_level_lod.back() + |
107 | | - sentence.word_ids.size()); |
108 | | - } |
109 | | - source_level_lod.push_back(source_level_lod.back() + |
110 | | - sentence_vector_list[src_idx].size()); |
111 | | - } |
112 | | - |
113 | | - LoD lod; |
114 | | - lod.push_back(source_level_lod); |
115 | | - lod.push_back(sentence_level_lod); |
116 | | - |
117 | | - id_tensor->set_lod(lod); |
118 | | - |
119 | | - id_tensor->Resize({static_cast<int64_t>(id_data.size())}); |
120 | | - auto id_ptr = id_tensor->mutable_data<int64_t>(); |
121 | | - TargetCopy( |
122 | | - TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t)); |
123 | | - |
124 | | - score_tensor->set_lod(lod); |
125 | | - score_tensor->Resize({static_cast<int64_t>(score_data.size())}); |
126 | | - auto score_ptr = score_tensor->mutable_data<T>(); |
127 | | - TargetCopy(TARGET(kARM), |
128 | | - score_ptr, |
129 | | - score_data.data(), |
130 | | - score_data.size() * sizeof(T)); |
131 | | - } |
132 | | - |
133 | | - /** |
134 | | - * Gather the hypotheses for each source sentence by backtrace though the |
135 | | - * LoDTensorArray step_ids whose lods reserve the path in the tree. |
136 | | - */ |
137 | | - void Backtrace(const LoDTensorArray& step_ids, |
138 | | - const LoDTensorArray& step_scores, |
139 | | - LoDTensor* id_tensor, |
140 | | - LoDTensor* score_tensor) const { |
141 | | - CHECK(!step_ids.empty()) << "step num should be larger than 0"; |
142 | | - CHECK_EQ(step_ids.size(), step_scores.size()) |
143 | | - << "step_ids and step_scores should be the same"; |
144 | | - const size_t step_num = step_ids.size(); |
145 | | - const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1; |
146 | | - std::vector<SentenceVector<T>> sentence_vector_list( |
147 | | - src_num, SentenceVector<T>(beam_size_)); |
148 | | - std::vector<std::vector<size_t>> prefix_idx_vector_list(src_num); |
149 | | - for (int step_id = step_num - 1; step_id >= 0; --step_id) { |
150 | | - auto& cur_ids = step_ids.at(step_id); |
151 | | - auto& cur_scores = step_scores.at(step_id); |
152 | | - for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { |
153 | | - // for each source sentence |
154 | | - auto& sentence_vector = sentence_vector_list.at(src_idx); |
155 | | - auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx); |
156 | | - size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx]; |
157 | | - size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1]; |
158 | | - if (prefix_idx_vector.empty()) { // be finished and pruned at this step |
159 | | - // or the last time step |
160 | | - for (size_t prefix_idx = src_prefix_start; |
161 | | - prefix_idx < src_prefix_end; |
162 | | - ++prefix_idx) { |
163 | | - size_t candidate_start = |
164 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx]; |
165 | | - size_t candidate_end = |
166 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1]; |
167 | | - for (size_t candidate_idx = candidate_start; |
168 | | - candidate_idx < candidate_end; |
169 | | - ++candidate_idx) { |
170 | | - prefix_idx_vector.push_back(prefix_idx); |
171 | | - size_t idx = prefix_idx_vector.size() - 1; |
172 | | - auto cur_id = cur_ids.data<int64_t>()[candidate_idx]; |
173 | | - auto cur_score = cur_scores.data<T>()[candidate_idx]; |
174 | | - sentence_vector.at(idx).word_ids.push_back(cur_id); |
175 | | - sentence_vector.at(idx).scores.push_back(cur_score); |
176 | | - } |
177 | | - } |
178 | | - } else { // use prefix_idx_vector to backtrace |
179 | | - size_t src_candidate_start = |
180 | | - cur_ids.lod().at(kSentenceLevel)[src_prefix_start]; |
181 | | - size_t prefix_idx = src_prefix_start; |
182 | | - size_t candidate_num = |
183 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] - |
184 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx]; |
185 | | - for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) { |
186 | | - auto candidate_idx = prefix_idx_vector.at(idx); |
187 | | - auto cur_id = cur_ids.data<int64_t>()[candidate_idx]; |
188 | | - auto cur_score = cur_scores.data<T>()[candidate_idx]; |
189 | | - if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) { |
190 | | - // to skip redundant end tokens |
191 | | - sentence_vector.at(idx).word_ids.push_back(cur_id); |
192 | | - sentence_vector.at(idx).scores.push_back(cur_score); |
193 | | - } |
194 | | - |
195 | | - while (src_candidate_start + candidate_num <= |
196 | | - candidate_idx) { // search the corresponding prefix |
197 | | - prefix_idx++; |
198 | | - candidate_num += |
199 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] - |
200 | | - cur_ids.lod().at(kSentenceLevel)[prefix_idx]; |
201 | | - } |
202 | | - prefix_idx_vector.at(idx) = prefix_idx; |
203 | | - } |
204 | | - } |
205 | | - } |
206 | | - } |
207 | | - |
208 | | - ConvertSentenceVectorToLodTensor( |
209 | | - sentence_vector_list, id_tensor, score_tensor, true, true); |
210 | | - } |
211 | | - |
212 | | - size_t beam_size_; |
213 | | - int end_id_; |
214 | | -}; |
215 | | - |
216 | | -struct BeamSearchDecodeFunctor { |
217 | | - BeamSearchDecodeFunctor(const LoDTensorArray& step_ids, |
218 | | - const LoDTensorArray& step_scores, |
219 | | - LoDTensor* id_tensor, |
220 | | - LoDTensor* score_tensor, |
221 | | - size_t beam_size, |
222 | | - int end_id) |
223 | | - : beam_size_(beam_size), |
224 | | - end_id_(end_id), |
225 | | - step_ids_(step_ids), |
226 | | - step_scores_(step_scores), |
227 | | - id_tensor_(id_tensor), |
228 | | - score_tensor_(score_tensor) {} |
229 | | - |
230 | | - template <typename T> |
231 | | - void apply() const { |
232 | | - BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_); |
233 | | - beam_search_decoder.Backtrace( |
234 | | - step_ids_, step_scores_, id_tensor_, score_tensor_); |
235 | | - } |
236 | | - |
237 | | - size_t beam_size_; |
238 | | - int end_id_; |
239 | | - const LoDTensorArray& step_ids_; |
240 | | - const LoDTensorArray& step_scores_; |
241 | | - LoDTensor* id_tensor_; |
242 | | - LoDTensor* score_tensor_; |
243 | | -}; |
244 | | - |
245 | | -template <> |
246 | | -void BeamSearchDecodeFunctor::apply<bool>() const { |
247 | | - LOG(FATAL) << "beam search decode op does not support bool!"; |
248 | | -} |
249 | | - |
250 | | -void BeamSearchDecodeCompute::Run() { |
251 | | - auto& param = this->Param<param_t>(); |
252 | | - auto& ctx = this->ctx_->template As<ARMContext>(); |
253 | | - // inputs |
254 | | - auto ids = param.ids; |
255 | | - auto scores = param.scores; |
256 | | - // outputs |
257 | | - auto sentence_ids = param.sentence_ids; |
258 | | - auto sentence_scores = param.sentence_scores; |
259 | | - |
260 | | - const size_t step_num = ids->size(); |
261 | | - CHECK_GT(step_num, 0UL) << "beam search steps should be larger than 0"; |
262 | | - const size_t source_num = ids->at(0).lod().at(0).size() - 1; |
263 | | - CHECK_GT(source_num, 0UL) << "source num should be larger than 0"; |
264 | | - |
265 | | - for (size_t i = 0; i < step_num; ++i) { |
266 | | - CHECK_EQ(ids->at(i).lod().size(), 2UL) << "Level of LodTensor should be 2"; |
267 | | - } |
268 | | - |
269 | | - //! fixme |
270 | | - // only support float score now |
271 | | - BeamSearchDecodeFunctor func(*ids, |
272 | | - *scores, |
273 | | - sentence_ids, |
274 | | - sentence_scores, |
275 | | - param.beam_size, |
276 | | - param.end_id); |
277 | | - |
278 | | - func.apply<float>(); |
279 | | - |
280 | | - // when decode finish, we clear ids and scores |
281 | | - param.ids->clear(); |
282 | | - param.scores->clear(); |
283 | | -} |
284 | | - |
285 | | -} // namespace arm |
286 | | -} // namespace kernels |
287 | | -} // namespace lite |
288 | | -} // namespace paddle |
| 15 | +#include "lite/kernels/host/beam_search_decode_compute.h" |
289 | 16 |
|
290 | 17 | REGISTER_LITE_KERNEL(beam_search_decode, |
291 | 18 | kARM, |
292 | 19 | kFloat, |
293 | 20 | kNCHW, |
294 | | - paddle::lite::kernels::arm::BeamSearchDecodeCompute, |
| 21 | + paddle::lite::kernels::host::BeamSearchDecodeCompute, |
295 | 22 | def) |
296 | 23 | .BindInput("Ids", |
297 | 24 | {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt64))}) |
|
0 commit comments