@@ -24,8 +24,18 @@ namespace operators {
2424void BeamSearch::operator ()(const framework::LoDTensor &pre_ids,
2525 framework::LoDTensor *selected_ids,
2626 framework::LoDTensor *selected_scores) {
27+ auto abs_lod = framework::ToAbsOffset (ids_->lod ());
28+ auto &high_level = abs_lod[lod_level_];
29+
2730 auto items = SelectTopBeamSizeItems ();
28- auto selected_items = ToMap (items);
31+ auto selected_items = ToMap (items, high_level.back ());
32+ VLOG (3 ) << " selected_items:" ;
33+ for (size_t i = 0 ; i < selected_items.size (); ++i) {
34+ VLOG (3 ) << " offset:" << i;
35+ for (auto &item : selected_items[i]) {
36+ VLOG (3 ) << ItemToString (item);
37+ }
38+ }
2939 PruneEndidCandidates (pre_ids, &selected_items);
3040 // calculate the output tensor's height
3141 size_t num_instances = std::accumulate (
@@ -63,11 +73,12 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
6373 low_level.push_back (low_offset);
6474
6575 // fill lod
66- auto abs_lod = framework::ToAbsOffset (ids_->lod ());
67- auto &high_level = abs_lod[lod_level_];
6876 framework::LoD lod (2 );
6977 lod[0 ].assign (high_level.begin (), high_level.end ());
7078 lod[1 ].assign (low_level.begin (), low_level.end ());
79+ if (!framework::CheckLoD (lod)) {
80+ PADDLE_THROW (" lod %s is not right" , framework::LoDToString (lod));
81+ }
7182 selected_ids->set_lod (lod);
7283 selected_scores->set_lod (lod);
7384}
@@ -90,13 +101,11 @@ int BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
90101}
91102
92103std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap (
93- const std::vector<std::vector<Item>> &items) {
104+ const std::vector<std::vector<Item>> &items, size_t element_num ) {
94105 std::vector<std::vector<Item>> result;
106+ result.resize (element_num);
95107 for (auto &entries : items) {
96108 for (const auto &item : entries) {
97- if (item.offset >= result.size ()) {
98- result.resize (item.offset + 1 );
99- }
100109 result[item.offset ].push_back (item);
101110 }
102111 }
@@ -122,6 +131,14 @@ BeamSearch::SelectTopBeamSizeItems() {
122131 }
123132 result.emplace_back (items);
124133 }
134+ VLOG (3 ) << " SelectTopBeamSizeItems result size " << result.size ();
135+ for (auto &items : result) {
136+ VLOG (3 ) << " item set:" ;
137+ for (auto &item : items) {
138+ VLOG (3 ) << ItemToString (item);
139+ }
140+ }
141+
125142 return result;
126143}
127144
@@ -159,6 +176,22 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
159176 return true ;
160177}
161178
179+ std::ostream &operator <<(std::ostream &os, const BeamSearch::Item &item) {
180+ os << " {" ;
181+ os << " offset: " << item.offset << " , " ;
182+ os << " id: " << item.id << " , " ;
183+ os << " score: " << item.score << " " ;
184+ os << " }" ;
185+
186+ return os;
187+ }
188+
189+ std::string ItemToString (const BeamSearch::Item &item) {
190+ std::ostringstream stream;
191+ stream << item;
192+ return stream.str ();
193+ }
194+
162195class BeamSearchProtoAndCheckerMaker
163196 : public framework::OpProtoAndCheckerMaker {
164197 public:
@@ -186,8 +219,40 @@ class BeamSearchProtoAndCheckerMaker
186219 }
187220};
188221
222+ class BeamSearchInferShape : public framework ::InferShapeBase {
223+ public:
224+ void operator ()(framework::InferShapeContext *context) const override {
225+ for (const std::string &arg :
226+ std::vector<std::string>({" pre_ids" , " ids" , " scores" })) {
227+ PADDLE_ENFORCE (context->HasInput (arg),
228+ " BeamSearch need input argument '%s'" , arg);
229+ }
230+ for (const std::string &arg :
231+ std::vector<std::string>({" selected_ids" , " selected_scores" })) {
232+ PADDLE_ENFORCE (context->HasOutput (arg),
233+ " BeamSearch need output argument '%s'" , arg);
234+ }
235+ }
236+ };
237+
238+ class BeamSearchInferVarType : public framework ::VarTypeInference {
239+ public:
240+ void operator ()(const framework::OpDesc &op_desc,
241+ framework::BlockDesc *block) const override {
242+ for (auto &o : op_desc.Output (" selected_ids" )) {
243+ block->Var (o)->SetType (framework::proto::VarDesc::LOD_TENSOR);
244+ }
245+ for (auto &o : op_desc.Output (" selected_scores" )) {
246+ block->Var (o)->SetType (framework::proto::VarDesc::LOD_TENSOR);
247+ }
248+ }
249+ };
250+
189251} // namespace operators
190252} // namespace paddle
191253
192- REGISTER_OP_WITHOUT_GRADIENT (beam_search, paddle::operators::BeamSearchOp,
193- paddle::operators::BeamSearchProtoAndCheckerMaker);
254+ REGISTER_OPERATOR (beam_search, paddle::operators::BeamSearchOp,
255+ paddle::operators::BeamSearchProtoAndCheckerMaker,
256+ paddle::operators::BeamSearchInferShape,
257+ paddle::operators::BeamSearchInferVarType,
258+ paddle::framework::EmptyGradOpMaker);
0 commit comments