Commit 39904d1

refactor: move draft input preparation of decode batch from worker to batch builder.
1 parent 25e16fa commit 39904d1

14 files changed: +185 -280 lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 28 additions & 11 deletions
@@ -72,24 +72,35 @@ BatchInputBuilder::BatchInputBuilder(
 ForwardInput BatchInputBuilder::build_forward_input(
     uint32_t num_decoding_tokens,
     uint32_t min_decoding_batch_size) {
+  // ForwardInput is not tested with the multithreaded path, so set
+  // thread_pool_ to nullptr.
+  thread_pool_ = nullptr;
   process_sequences();
   padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);

   return state_to_forward_input();
 }

 RawForwardInput BatchInputBuilder::build_raw_forward_input() {
-  if (!thread_pool_ || num_sequences_ < thread_pool_->size()) {
-    process_sequences();
-  } else {
-    process_sequences_multithreaded();
-  }
+  process_sequences();
   return state_to_raw_forward_input();
 }

 void BatchInputBuilder::process_sequences() {
-  for (int32_t i = 0; i < num_sequences_; ++i) {
-    process_single_sequence(i);
+  // When speculative decoding is enabled, we also need to build the raw
+  // forward input of the decode batch as the draft input for MTP (EAGLE).
+  is_mtp_decode_ = false;
+  if (state_.batch_forward_type.is_decode() &&
+      FLAGS_num_speculative_tokens > 0) {
+    is_mtp_decode_ = true;
+  }
+
+  if (thread_pool_ && num_sequences_ >= thread_pool_->size()) {
+    process_sequences_multithreaded();
+  } else {
+    for (int32_t i = 0; i < num_sequences_; ++i) {
+      process_single_sequence(i);
+    }
   }
 }

@@ -275,14 +286,15 @@ void BatchInputBuilder::process_single_sequence(
       << allowed_max_tokens_[seq_index];

   // Update state
+  int32_t offset = is_mtp_decode_ ? -1 : 0;
   state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
-  state.max_seq_len = std::max(state.max_seq_len, seq_len);
+  state.max_seq_len = std::max(state.max_seq_len, seq_len + offset);
   state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
 #if defined(USE_NPU)
-  state.seq_lens.push_back(seq_len);
+  state.seq_lens.push_back(seq_len + offset);
   state.q_seq_lens.push_back(q_seq_len);
 #elif defined(USE_MLU) || defined(USE_CUDA)
-  state.seq_lens.push_back(state.seq_lens.back() + seq_len);
+  state.seq_lens.push_back(state.seq_lens.back() + seq_len + offset);
   state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
 #endif
   // Process tokens and positions

@@ -338,7 +350,8 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
     state.flatten_tokens_vec.push_back(token_ids[j]);

     if (!use_mrope_) {
-      state.flatten_positions_vec.push_back(static_cast<int32_t>(j));
+      int32_t offset = is_mtp_decode_ ? -1 : 0;
+      state.flatten_positions_vec.push_back(static_cast<int32_t>(j + offset));
     }

     // Handle sampling for last tokens

@@ -422,6 +435,9 @@ void BatchInputBuilder::setup_kv_cache_info(
   // update kv cache tokens num
   sequence->kv_state().incr_kv_cache_tokens_num(/*size=*/q_seq_len);

+  int32_t offset = is_mtp_decode_ ? -1 : 0;
+  seq_len += offset;
+  n_kv_cache_tokens += offset;
   const auto blocks = sequence->kv_state().kv_blocks();
   const auto slot_ids =
       sequence->kv_state().kv_cache_slots(n_kv_cache_tokens, seq_len);

@@ -443,6 +459,7 @@
       (seq_len % block_size == 0) ? block_size : seq_len % block_size;
   state.paged_kv_last_page_len.push_back(last_page_len);

+  // calculate the block ids that need to be written
   int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
   for (auto iter = block_ids.begin() + kv_cache_block_idx;
        iter != block_ids.end();
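
The -1 offset introduced above reflects how the draft pass of MTP (EAGLE) consumes a decode batch: the draft model re-runs the step for the token the target model just produced, so sequence lengths, positions, and the KV-cache write location are all shifted back by one. A minimal sketch of that shift on a simplified state; the DecodeLengths struct and build_decode_lengths helper are illustrative names, not xllm APIs.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Illustrative only: a simplified stand-in for the seq_len/position
    // bookkeeping done by BatchInputBuilder.
    struct DecodeLengths {
      std::vector<int32_t> seq_lens;   // per-sequence KV length seen by the kernel
      std::vector<int32_t> positions;  // position of the single decoded token
      int32_t max_seq_len = 0;
    };

    DecodeLengths build_decode_lengths(const std::vector<int32_t>& full_seq_lens,
                                       bool is_mtp_decode) {
      // For the draft (MTP/EAGLE) decode pass, drop the most recent token:
      // every length and position is shifted back by one, mirroring the
      // offset applied by the batch builder above.
      const int32_t offset = is_mtp_decode ? -1 : 0;
      DecodeLengths out;
      for (const int32_t seq_len : full_seq_lens) {
        out.seq_lens.push_back(seq_len + offset);
        out.positions.push_back(seq_len - 1 + offset);
        out.max_seq_len = std::max(out.max_seq_len, seq_len + offset);
      }
      return out;
    }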

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ class BatchInputBuilder {
   // thread pool for multithreaded processing, not owned
   ThreadPool* thread_pool_ = nullptr;
   uint64_t batch_id_;
+
+  // Whether to prepare the draft input for MTP (EAGLE) at the decode phase.
+  bool is_mtp_decode_ = false;
 };

 }  // namespace xllm

xllm/core/framework/request/sequence.cpp

Lines changed: 0 additions & 1 deletion
@@ -230,7 +230,6 @@ void Sequence::update_embeddings(const torch::Tensor& embeddings) {
     if (output_embedding_.dim() == 1) {
       output_embedding_ = output_embedding_.unsqueeze(0);
     }
-    mm_data_ = MMData(MMType::EMBEDDING, {{"embedding", output_embedding_}});
   }
 }

xllm/core/runtime/acl_graph_executor_impl.cpp

Lines changed: 6 additions & 12 deletions
@@ -160,10 +160,9 @@ void GraphPersistentParam::update(const torch::Tensor& tokens,
   slice_persistent_block_tables.copy_(params.block_tables,
                                       /*non_blocking=*/true);

-  // Update persistent embedding from mm_data if available
-  const auto& embedding_res = params.mm_data.get<torch::Tensor>("embedding");
-  if (embedding_res) {
-    const torch::Tensor& embedding = embedding_res.value();
+  // Update persistent embedding from input_embedding if available
+  const auto& embedding = params.input_embedding;
+  if (embedding.defined()) {
     const int64_t embedding_tokens = embedding.size(0);

     // Initialize persistent_embedding_ if needed and not already initialized

@@ -643,17 +642,12 @@ bool AclGraph::capture(CausalLM* model,
   graph_params.graph_buffer.tiling_data = persistent_param_.tiling_data();

   // Set persistent embedding if available and original input has embedding
-  const auto& original_embedding =
-      params.mm_data.get<torch::Tensor>("embedding");
-  if (original_embedding.has_value()) {
+  const auto& original_embedding = params.input_embedding;
+  if (original_embedding.defined()) {
     torch::Tensor persistent_embedding =
         persistent_param_.persistent_embedding(num_tokens_);
     if (persistent_embedding.numel() > 0) {
-      // graph_params.input_embedding = persistent_embedding;
-      // Replace embedding in mm_data with persistent embedding using update
-      // method
-      graph_params.mm_data.update<torch::Tensor>(
-          MMType::EMBEDDING, "embedding", persistent_embedding);
+      graph_params.input_embedding = persistent_embedding;
     }
   }
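
The change above swaps the optional-returning mm_data lookup for a plain torch::Tensor member, so presence is now signalled by whether the tensor is defined. A minimal sketch of that pattern, assuming only libtorch; the ParamsSketch struct and maybe_copy_embedding helper are illustrative, not xllm types.

    #include <torch/torch.h>

    // Illustrative sketch: a default-constructed torch::Tensor is "undefined",
    // so a plain member can replace an optional-style map lookup.
    struct ParamsSketch {
      torch::Tensor input_embedding;  // stays undefined until a batch provides one
    };

    void maybe_copy_embedding(const ParamsSketch& params,
                              torch::Tensor& persistent_embedding) {
      const torch::Tensor& embedding = params.input_embedding;
      if (!embedding.defined()) {
        return;  // no per-token embeddings were supplied for this batch
      }
      // Copy only the rows that belong to the current batch into the
      // persistent buffer captured by the graph.
      const int64_t embedding_tokens = embedding.size(0);
      persistent_embedding
          .narrow(/*dim=*/0, /*start=*/0, /*length=*/embedding_tokens)
          .copy_(embedding, /*non_blocking=*/true);
    }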

xllm/core/runtime/forward_params.h

Lines changed: 11 additions & 0 deletions
@@ -100,6 +100,17 @@ struct ForwardInput {
     inputs.acc_logprob = safe_to(acc_logprob, device, true);
     return inputs;
   }
+
+  void print() const {
+    LOG(INFO) << " token_ids: " << token_ids << std::endl;
+    LOG(INFO) << " positions: " << positions << std::endl;
+    input_params.print();
+    LOG(INFO) << " params.selected_token_idxes "
+              << sampling_params.selected_token_idxes;
+    LOG(INFO) << " params.sample_idxes " << sampling_params.sample_idxes;
+    LOG(INFO) << " params.do_sample " << sampling_params.do_sample;
+  }
+
   // flatten token ids
   torch::Tensor token_ids;
   // flatten positions
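
The new print() is a debugging aid that dumps the flattened inputs and sampling parameters through LOG(INFO). A possible call site is sketched below; the VLOG guard is an assumption for illustration, not something this commit adds.

    // Hypothetical debugging call site before dispatching a worker step.
    if (VLOG_IS_ON(2)) {
      input.print();  // logs token_ids, positions, input_params and sampling params
    }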

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 4 additions & 26 deletions
@@ -170,34 +170,12 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
     output.beam_search_output = beam_search_output;
   }

-  // if running in multi_stream_parallel step, all micro batches
-  // should be in same prefill stage, so, to judge empty_kv_cache,
-  // just use micro batch 0 here
-  if (options_.enable_speculative_decode() && !is_spec_draft_) {
-    if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
+  if (options_.enable_speculative_decode()) {
+    if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) {
       output.sample_output.embeddings = hidden_states;
-    } else if (sampling_params.sample_idxes.defined()) {
-      // auto sample_idxes =
-      //     concated_sampling_params.selected_token_idxes.index_select(
-      //         /*dim=*/0, concated_sampling_params.sample_idxes);
+    } else if (sampling_params.selected_token_idxes.defined()) {
       auto embeddings = hidden_states.index_select(
-          /*dim=*/0, sampling_params.sample_idxes);
-      output.sample_output.embeddings = embeddings;
-    }
-  }
-
-  // if running in multi_stream_parallel step, all micro batches
-  // should be in same prefill stage, so, to judge empty_kv_cache,
-  // just use micro batch 0 here
-  if (options_.enable_speculative_decode() && !is_spec_draft_) {
-    if (input.input_params.q_seq_lens_vec[0] > 1) {
-      output.sample_output.embeddings = hidden_states;
-    } else if (sampling_params.sample_idxes.defined()) {
-      // auto sample_idxes =
-      //     concated_sampling_params.selected_token_idxes.index_select(
-      //         /*dim=*/0, concated_sampling_params.sample_idxes);
-      auto embeddings = hidden_states.index_select(
-          /*dim=*/0, sampling_params.sample_idxes);
+          /*dim=*/0, sampling_params.selected_token_idxes);
       output.sample_output.embeddings = embeddings;
     }
   }
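With the simplification above, the worker keeps the full hidden states for non-decode (prefill) batches and, for decode batches, gathers only the rows named by selected_token_idxes so they can be fed to the draft (MTP/EAGLE) model. A standalone libtorch sketch of that gather, with made-up shapes and indices:

    #include <torch/torch.h>

    int main() {
      // [num_tokens, hidden_size] flattened hidden states for the whole batch
      // (shapes here are assumptions for illustration).
      torch::Tensor hidden_states = torch::randn({8, 16});
      // Row indices of the tokens whose logits were actually sampled.
      torch::Tensor selected_token_idxes = torch::tensor({3, 7}, torch::kLong);

      // Same gather as the worker: pick one hidden-state row per sampled token.
      torch::Tensor embeddings =
          hidden_states.index_select(/*dim=*/0, selected_token_idxes);
      // embeddings now has shape [2, 16].
      return 0;
    }
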
xllm/core/runtime/params_utils.cpp

Lines changed: 1 addition & 2 deletions
@@ -231,8 +231,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
     }
     torch::Tensor embeddings =
         create_2d_tensor(embeddings_vec, torch::kBFloat16);
-    input_params.mm_data =
-        MMData(MMType::EMBEDDING, {{"embedding", embeddings}});
+    input_params.input_embedding = embeddings;
   }

   CHECK_EQ(sampling_params.size(), selected_token_idxes.size());
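
After this change the de-serialized embeddings are assigned straight to input_params.input_embedding instead of being wrapped in an MMData map. A rough stand-in for what create_2d_tensor does, assuming only libtorch; the helper name, the flat float buffer, and the bfloat16 cast are assumptions based on the call above, not the repository's implementation.

    #include <torch/torch.h>
    #include <vector>

    // Illustrative stand-in for create_2d_tensor(): pack a flat float buffer
    // into a [rows, cols] tensor, then cast it to bfloat16.
    torch::Tensor make_embedding_tensor(const std::vector<float>& flat,
                                        int64_t rows, int64_t cols) {
      torch::Tensor t = torch::from_blob(const_cast<float*>(flat.data()),
                                         {rows, cols}, torch::kFloat)
                            .clone();  // own the memory instead of aliasing `flat`
      return t.to(torch::kBFloat16);
    }

    // Hypothetical usage mirroring the diff:
    // input_params.input_embedding =
    //     make_embedding_tensor(embeddings_vec, num_tokens, hidden_size);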

0 commit comments