#include "rec_constrained_decoding.h"

#include <c10/core/TensorOptions.h>
#include <folly/Unit.h>
#include <folly/futures/Future.h>
#include <glog/logging.h>

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <future>
#include <mutex>

#include "common/global_flags.h"
#include "common/version_singleton.h"
#include "framework/state_dict/rec_vocab_dict.h"
#include "util/slice.h"
#include "util/tensor_helper.h"

namespace xllm {

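// PRE_MASK_FACTOR is the additive bias applied to disallowed logits: large
// and negative enough to drive their softmax probability to ~0, while staying
// finite (presumably to avoid NaN propagation in low-precision arithmetic).
// GEN_MASK_THREAD_NUM is the worker count for parallel mask generation.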
constexpr float PRE_MASK_FACTOR = -10000.0f;
constexpr int GEN_MASK_THREAD_NUM = 16;

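// The constructor only records the model/vocab configuration and optionally
// starts the mask-generation thread pool; the first-token mask is built
// separately via build_mask_cache().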
RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
                                               const int32_t vocab_size,
                                               torch::ScalarType dtype,
                                               torch::Device device,
                                               bool use_gen_threadpool)
    : model_version_(model_version),
      vocab_size_(vocab_size),
      dtype_(dtype),
      device_(device),
      use_gen_threadpool_(use_gen_threadpool) {
  if (use_gen_threadpool_) {
    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
  }

  build_mask_cache_ = false;
}

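// Precomputes the additive mask for the first decoding step: querying the
// versioned RecVocabDict with an empty prefix yields every legal first
// token; those entries are zeroed while all others keep PRE_MASK_FACTOR.
// The tensor is moved to the target device once and reused for all requests.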
bool RecConstrainedDecoding::build_mask_cache() {
  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);

  std::vector<int32_t> empty_token_ids;
  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
                                     empty_token_ids.size()};

  const std::set<int32_t>& first_token_ids =
      VersionSingleton<RecVocabDict>::GetInstance(
          std::to_string(model_version_))
          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);

  for (auto token_id : first_token_ids) {
    first_token_mask_[token_id] = 0;
  }

  first_token_mask_ = safe_to(first_token_mask_, device_, true);

  build_mask_cache_ = true;

  LOG(INFO) << "build mask cache, first token ids size:"
            << first_token_ids.size();

  return true;
}

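// Returns an additive logit mask of shape [sequence_num, vocab_size], or an
// undefined tensor when the cache has not been built or the batch is empty.
// Only generated_token_list[0] is inspected to detect the first step, so all
// sequences in a batch are assumed to sit at the same decoding position.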
torch::Tensor RecConstrainedDecoding::generate_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  if (!build_mask_cache_ || 0 == generated_token_list.size()) {
    return torch::Tensor();
  }

  size_t token_size = generated_token_list[0].size();

  // generate mask for first token
  if (0 == token_size) {
    size_t sequence_num = generated_token_list.size();
    auto mask = first_token_mask_.unsqueeze(0);
    // cast avoids a narrowing conversion from size_t in the IntArrayRef list
    return mask.repeat({static_cast<int64_t>(sequence_num), 1});
  }

  // generate mask for non-first token
  return generate_decode_mask(generated_token_list);
}

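// Builds the mask for non-first steps: the legal next tokens for each
// sequence's prefix are looked up in the vocabulary trie and collected as
// (sequence, vocab) index pairs, which are applied to the device-resident
// mask in a single batched index_put_ rather than one update per sequence.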
torch::Tensor RecConstrainedDecoding::generate_decode_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  size_t sequence_num = generated_token_list.size();
  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
  auto mask = torch::full({static_cast<int64_t>(sequence_num), vocab_size_},
                          PRE_MASK_FACTOR,
                          options);

  std::mutex global_batch_mutex;
  std::vector<int64_t> global_batch_token_indices;
  std::vector<int64_t> global_batch_vocab_indices;

  // Heuristic upper bound on allowed tokens per prefix, used only to
  // pre-reserve the index buffers and avoid repeated reallocation.
  int max_index_num_per_token = 8192;
  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);

  // Collect the (sequence, vocab) coordinates of every allowed next token
  // for rows [start_idx, end_idx).
  auto update_mask = [&](size_t start_idx, size_t end_idx) {
    std::vector<int64_t> local_token_indices;
    std::vector<int64_t> local_vocab_indices;
    local_token_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));
    local_vocab_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));

    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);

      const std::set<int32_t>& next_token_ids =
          VersionSingleton<RecVocabDict>::GetInstance(
              std::to_string(model_version_))
              ->get_next_tokens_by_prefix_tokens(tokens_slice);

      if (next_token_ids.size() > 0) {
        for (int32_t vocab_idx : next_token_ids) {
          local_token_indices.push_back(static_cast<int64_t>(token_idx));
          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
        }
      } else {
        // A prefix with no continuation means the trie lookup failed; the
        // row keeps PRE_MASK_FACTOR everywhere.
        LOG(ERROR) << "fail to generate mask for tokens:"
                   << generated_token_list[token_idx];
      }
    }

    // merge local results to global batch (thread-safe)
    if (!local_token_indices.empty()) {
      std::lock_guard<std::mutex> lock(global_batch_mutex);
      global_batch_token_indices.insert(global_batch_token_indices.end(),
                                        local_token_indices.begin(),
                                        local_token_indices.end());
      global_batch_vocab_indices.insert(global_batch_vocab_indices.end(),
                                        local_vocab_indices.begin(),
                                        local_vocab_indices.end());
    }
  };

  if (use_gen_threadpool_) {
    // Split the sequences into roughly equal batches, one per worker, and
    // wait for all of them through promise/future pairs.
    const size_t batch_size = std::max(
        1UL, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM);
    const size_t num_batches = (sequence_num + batch_size - 1) / batch_size;

    std::vector<std::future<void>> futures;
    std::vector<std::shared_ptr<std::promise<void>>> promises;

    promises.reserve(num_batches);
    futures.reserve(num_batches);

    for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
      auto promise = std::make_shared<std::promise<void>>();
      futures.push_back(promise->get_future());
      promises.push_back(promise);

      size_t start_idx = batch_idx * batch_size;
      size_t end_idx = std::min(start_idx + batch_size, sequence_num);

      gen_threadpool_->schedule(
          [update_mask, start_idx, end_idx, promise]() mutable {
            update_mask(start_idx, end_idx);
            promise->set_value();
          });
    }

    for (auto& future : futures) {
      future.get();
    }
  } else {
    update_mask(0, sequence_num);
  }

  // Apply all allowed positions in one batched update on the device.
  if (!global_batch_token_indices.empty()) {
    auto token_indices =
        torch::tensor(global_batch_token_indices, torch::kInt64);
    auto vocab_indices =
        torch::tensor(global_batch_vocab_indices, torch::kInt64);
    token_indices = safe_to(token_indices, device_, true);
    vocab_indices = safe_to(vocab_indices, device_, true);
    mask.index_put_({token_indices, vocab_indices}, 0.0f);
  }

  return mask;
}
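
// A minimal usage sketch (hypothetical call site; the sampler wiring and
// request handling around this class are not part of this file):
//
//   RecConstrainedDecoding decoder(/*model_version=*/1,
//                                  /*vocab_size=*/32000,
//                                  torch::kFloat16,
//                                  torch::Device(torch::kCUDA, 0),
//                                  /*use_gen_threadpool=*/true);
//   decoder.build_mask_cache();
//   // Per decoding step: the returned mask is additive, so it is presumably
//   // summed with the raw logits before sampling.
//   auto mask = decoder.generate_mask(generated_token_list);
//   if (mask.defined()) {
//     logits += mask;
//   }
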
}  // namespace xllm