Commit 0e9219d (1 parent: 257c867)

feat: add constrained decoding for generative recommendation.

4 files changed: +278 additions, −0 deletions

xllm/core/framework/sampling/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -10,12 +10,14 @@ cc_library(
     rejection_sampler.h
     sampler.h
     beam_searcher.h
+    rec_constrained_decoding.h
   SRCS
     sampling_params.cpp
     logits_utils.cpp
     rejection_sampler.cpp
     sampler.cpp
     beam_searcher.cpp
+    rec_constrained_decoding.cpp
   DEPS
     :common
     glog::glog
xllm/core/framework/sampling/constrained_decoding.h

Lines changed: 36 additions & 0 deletions (new file)

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <c10/core/TensorOptions.h>
#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <vector>

namespace xllm {

// Constrained decoding ensures that the generated content conforms to
// specific formats or rules. Implementations return an additive logits
// mask: 0 for allowed tokens, a large negative bias for disallowed ones.
class ConstrainedDecoding {
 public:
  virtual ~ConstrainedDecoding() = default;

  // Precomputes any state needed for mask generation (e.g. the mask for
  // the first decoding step).
  virtual bool build_mask_cache() = 0;

  // Input generated_token_list: [sequence_num][generated_token_ids].
  // Output: mask tensor of shape [sequence_num, vocab_size].
  virtual torch::Tensor generate_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list) = 0;
};

}  // namespace xllm
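
Note on the mask convention: generate_mask returns an additive bias, not a boolean mask — allowed positions hold 0 and disallowed positions hold a large negative value (PRE_MASK_FACTOR in the implementation below), so a sampler consumes it with a single tensor add before softmax or argmax. A minimal standalone sketch of that convention; apply_constraint_mask and the toy shapes are illustrative assumptions, not part of this commit:

#include <torch/torch.h>

#include <iostream>

// Hypothetical helper: disallowed logits are pushed to roughly -10000 so
// softmax assigns them near-zero probability and argmax never picks them.
torch::Tensor apply_constraint_mask(const torch::Tensor& logits,
                                    const torch::Tensor& mask) {
  // logits and mask are both [sequence_num, vocab_size].
  return logits + mask;
}

int main() {
  auto logits = torch::randn({2, 8});
  // Allow only token ids 1 and 3 in every sequence.
  auto mask = torch::full({2, 8}, -10000.0f);
  mask.index_put_(
      {torch::indexing::Slice(), torch::tensor({1, 3}, torch::kInt64)}, 0.0f);

  auto next = apply_constraint_mask(logits, mask).argmax(/*dim=*/1);
  std::cout << next << std::endl;  // always 1 or 3
  return 0;
}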
xllm/core/framework/sampling/rec_constrained_decoding.cpp

Lines changed: 186 additions & 0 deletions (new file)

#include "rec_constrained_decoding.h"

#include <c10/core/TensorOptions.h>
#include <folly/Unit.h>
#include <folly/futures/Future.h>
#include <glog/logging.h>
#include <glog/stl_logging.h>

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <future>
#include <memory>
#include <mutex>
#include <vector>

#include "common/global_flags.h"
#include "common/version_singleton.h"
#include "framework/state_dict/rec_vocab_dict.h"
#include "util/slice.h"
#include "util/tensor_helper.h"

namespace xllm {

// Large negative bias added to the logits of disallowed tokens.
constexpr float PRE_MASK_FACTOR = -10000.0f;
constexpr int GEN_MASK_THREAD_NUM = 16;

RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
                                               int32_t vocab_size,
                                               torch::ScalarType dtype,
                                               torch::Device device,
                                               bool use_gen_threadpool)
    : model_version_(model_version),
      vocab_size_(vocab_size),
      dtype_(dtype),
      device_(device),
      use_gen_threadpool_(use_gen_threadpool) {
  if (use_gen_threadpool_) {
    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
  }

  build_mask_cache_ = false;
}

bool RecConstrainedDecoding::build_mask_cache() {
  // Precompute the mask for the first decoding step: only tokens that may
  // start a valid sequence get bias 0; all others keep PRE_MASK_FACTOR.
  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);

  std::vector<int32_t> empty_token_ids;
  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
                                     empty_token_ids.size()};

  const std::set<int32_t>& first_token_ids =
      VersionSingleton<RecVocabDict>::GetInstance(
          std::to_string(model_version_))
          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);

  for (auto token_id : first_token_ids) {
    first_token_mask_[token_id] = 0;
  }

  first_token_mask_ = safe_to(first_token_mask_, device_, true);

  build_mask_cache_ = true;

  LOG(INFO) << "build mask cache, first token ids size: "
            << first_token_ids.size();

  return true;
}

torch::Tensor RecConstrainedDecoding::generate_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  if (!build_mask_cache_ || generated_token_list.empty()) {
    return torch::Tensor();
  }

  size_t token_size = generated_token_list[0].size();

  // Generate the mask for the first token from the precomputed cache.
  if (token_size == 0) {
    auto sequence_num = static_cast<int64_t>(generated_token_list.size());
    auto mask = first_token_mask_.unsqueeze(0);
    return mask.repeat({sequence_num, 1});
  }

  // Generate the mask for non-first tokens.
  return generate_decode_mask(generated_token_list);
}

torch::Tensor RecConstrainedDecoding::generate_decode_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  size_t sequence_num = generated_token_list.size();
  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
  auto mask = torch::full({static_cast<int64_t>(sequence_num), vocab_size_},
                          PRE_MASK_FACTOR, options);

  std::mutex global_batch_mutex;
  std::vector<int64_t> global_batch_token_indices;
  std::vector<int64_t> global_batch_vocab_indices;

  const int max_index_num_per_token = 8192;
  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);

  auto update_mask = [&](size_t start_idx, size_t end_idx) {
    std::vector<int64_t> local_token_indices;
    std::vector<int64_t> local_vocab_indices;
    local_token_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));
    local_vocab_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));

    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);

      const std::set<int32_t>& next_token_ids =
          VersionSingleton<RecVocabDict>::GetInstance(
              std::to_string(model_version_))
              ->get_next_tokens_by_prefix_tokens(tokens_slice);

      if (!next_token_ids.empty()) {
        for (int32_t vocab_idx : next_token_ids) {
          local_token_indices.push_back(static_cast<int64_t>(token_idx));
          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
        }
      } else {
        // Requires <glog/stl_logging.h> to stream the token vector.
        LOG(ERROR) << "fail to generate mask for tokens: "
                   << generated_token_list[token_idx];
      }
    }

    // Merge local results into the global batch (thread-safe).
    if (!local_token_indices.empty()) {
      std::lock_guard<std::mutex> lock(global_batch_mutex);
      global_batch_token_indices.insert(global_batch_token_indices.end(),
                                        local_token_indices.begin(),
                                        local_token_indices.end());
      global_batch_vocab_indices.insert(global_batch_vocab_indices.end(),
                                        local_vocab_indices.begin(),
                                        local_vocab_indices.end());
    }
  };

  if (use_gen_threadpool_) {
    const size_t batch_size = std::max<size_t>(
        1, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM);
    const size_t num_batches = (sequence_num + batch_size - 1) / batch_size;

    std::vector<std::future<void>> futures;
    std::vector<std::shared_ptr<std::promise<void>>> promises;

    promises.reserve(num_batches);
    futures.reserve(num_batches);

    for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
      auto promise = std::make_shared<std::promise<void>>();
      futures.push_back(promise->get_future());
      promises.push_back(promise);

      size_t start_idx = batch_idx * batch_size;
      size_t end_idx = std::min(start_idx + batch_size, sequence_num);

      gen_threadpool_->schedule(
          [update_mask, start_idx, end_idx, promise]() mutable {
            update_mask(start_idx, end_idx);
            promise->set_value();
          });
    }

    for (auto& future : futures) {
      future.get();
    }
  } else {
    update_mask(0, sequence_num);
  }

  if (!global_batch_token_indices.empty()) {
    auto token_indices =
        torch::tensor(global_batch_token_indices, torch::kInt64);
    auto vocab_indices =
        torch::tensor(global_batch_vocab_indices, torch::kInt64);
    token_indices = safe_to(token_indices, device_, true);
    vocab_indices = safe_to(vocab_indices, device_, true);
    // Flip all (sequence, allowed-token) positions from PRE_MASK_FACTOR to
    // 0 with a single batched scatter.
    mask.index_put_({token_indices, vocab_indices}, 0.0f);
  }

  return mask;
}

}  // namespace xllm
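
A design note on generate_decode_mask: the worker threads never touch the mask tensor directly. They only gather (sequence, token) index pairs on the host, and a single index_put_ then flips all allowed positions from PRE_MASK_FACTOR to 0 in one batched scatter — far cheaper than per-element tensor writes, especially when the mask lives on an accelerator. A standalone sketch of that gather-then-scatter pattern, with a hard-coded allowed map standing in for RecVocabDict:

#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t sequence_num = 3;
  const int64_t vocab_size = 10;

  // Hypothetical allowed next-token sets per sequence (stand-in for
  // RecVocabDict::get_next_tokens_by_prefix_tokens).
  std::vector<std::vector<int64_t>> allowed = {{0, 4}, {2}, {5, 6, 7}};

  // Gather (row, col) index pairs on the host...
  std::vector<int64_t> rows;
  std::vector<int64_t> cols;
  for (int64_t seq = 0; seq < sequence_num; ++seq) {
    for (int64_t tok : allowed[seq]) {
      rows.push_back(seq);
      cols.push_back(tok);
    }
  }

  // ...then unblock them all with one batched scatter instead of
  // per-element writes — the same pattern as generate_decode_mask above.
  auto mask = torch::full({sequence_num, vocab_size}, -10000.0f);
  mask.index_put_({torch::tensor(rows), torch::tensor(cols)}, 0.0f);
  std::cout << mask << std::endl;
  return 0;
}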
xllm/core/framework/sampling/rec_constrained_decoding.h

Lines changed: 54 additions & 0 deletions (new file)

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <memory>
#include <vector>

#include "constrained_decoding.h"
#include "util/threadpool.h"

namespace xllm {

class RecConstrainedDecoding : public ConstrainedDecoding {
 public:
  RecConstrainedDecoding(uint64_t model_version,
                         int32_t vocab_size,
                         torch::ScalarType dtype,
                         torch::Device device,
                         bool use_gen_threadpool = true);
  ~RecConstrainedDecoding() override = default;

  bool build_mask_cache() override;

  torch::Tensor generate_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list) override;

 private:
  torch::Tensor generate_decode_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list);

 private:
  // Declared in constructor-initialization order to avoid -Wreorder.
  uint64_t model_version_;
  int32_t vocab_size_;
  torch::ScalarType dtype_;
  torch::Device device_;
  bool use_gen_threadpool_;
  bool build_mask_cache_;
  torch::Tensor first_token_mask_;
  std::unique_ptr<ThreadPool> gen_threadpool_;
};

}  // namespace xllm
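
For orientation, the typical call flow is sketched below. This is a sketch under assumptions: the RecVocabDict registration for this model version and the sampler loop live elsewhere in xLLM and are not part of this commit, and the model version and vocab size shown are placeholders.

// Assumes a RecVocabDict instance has already been registered under
// VersionSingleton for model version 1 (not shown in this commit).
xllm::RecConstrainedDecoding decoding(/*model_version=*/1,
                                      /*vocab_size=*/32000,
                                      torch::kFloat32,
                                      torch::Device(torch::kCPU),
                                      /*use_gen_threadpool=*/true);
decoding.build_mask_cache();

// Step 0: no tokens generated yet for either sequence, so the cached
// first-token mask is broadcast to one row per sequence.
std::vector<std::vector<int32_t>> generated = {{}, {}};
torch::Tensor mask = decoding.generate_mask(generated);  // [2, 32000]

// Later steps: each row holds the tokens generated so far for that sequence.
generated = {{11, 42}, {11, 7}};
mask = decoding.generate_mask(generated);
// logits += mask;  // bias disallowed tokens by -10000 before sampling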
