@@ -72,24 +72,35 @@ BatchInputBuilder::BatchInputBuilder(
7272ForwardInput BatchInputBuilder::build_forward_input (
7373 uint32_t num_decoding_tokens,
7474 uint32_t min_decoding_batch_size) {
75+ // Since dont test multithreaded for ForwardInput, set thread_pool_ to
76+ // nullptr.
77+ thread_pool_ = nullptr ;
7578 process_sequences ();
7679 padding_decode_batch_size (num_decoding_tokens, min_decoding_batch_size);
7780
7881 return state_to_forward_input ();
7982}
8083
8184RawForwardInput BatchInputBuilder::build_raw_forward_input () {
82- if (!thread_pool_ || num_sequences_ < thread_pool_->size ()) {
83- process_sequences ();
84- } else {
85- process_sequences_multithreaded ();
86- }
85+ process_sequences ();
8786 return state_to_raw_forward_input ();
8887}
8988
9089void BatchInputBuilder::process_sequences () {
91- for (int32_t i = 0 ; i < num_sequences_; ++i) {
92- process_single_sequence (i);
90+ // when speculative decoding, we need to build raw forward input
91+ // of decode batch for MTP (Eagle).
92+ is_mtp_decode_ = false ;
93+ if (state_.batch_forward_type .is_decode () &&
94+ FLAGS_num_speculative_tokens > 0 ) {
95+ is_mtp_decode_ = true ;
96+ }
97+
98+ if (thread_pool_ && num_sequences_ >= thread_pool_->size ()) {
99+ process_sequences_multithreaded ();
100+ } else {
101+ for (int32_t i = 0 ; i < num_sequences_; ++i) {
102+ process_single_sequence (i);
103+ }
93104 }
94105}
95106
@@ -275,14 +286,15 @@ void BatchInputBuilder::process_single_sequence(
275286 << allowed_max_tokens_[seq_index];
276287
277288 // Update state
289+ int32_t offset = is_mtp_decode_ ? -1 : 0 ;
278290 state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0 );
279- state.max_seq_len = std::max (state.max_seq_len , seq_len);
291+ state.max_seq_len = std::max (state.max_seq_len , seq_len + offset );
280292 state.q_max_seq_len = std::max (state.q_max_seq_len , q_seq_len);
281293#if defined(USE_NPU)
282- state.seq_lens .push_back (seq_len);
294+ state.seq_lens .push_back (seq_len + offset );
283295 state.q_seq_lens .push_back (q_seq_len);
284296#elif defined(USE_MLU) || defined(USE_CUDA)
285- state.seq_lens .push_back (state.seq_lens .back () + seq_len);
297+ state.seq_lens .push_back (state.seq_lens .back () + seq_len + offset );
286298 state.q_seq_lens .push_back (state.q_seq_lens .back () + q_seq_len);
287299#endif
288300 // Process tokens and positions
@@ -338,7 +350,8 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
338350 state.flatten_tokens_vec .push_back (token_ids[j]);
339351
340352 if (!use_mrope_) {
341- state.flatten_positions_vec .push_back (static_cast <int32_t >(j));
353+ int32_t offset = is_mtp_decode_ ? -1 : 0 ;
354+ state.flatten_positions_vec .push_back (static_cast <int32_t >(j + offset));
342355 }
343356
344357 // Handle sampling for last tokens
@@ -422,6 +435,9 @@ void BatchInputBuilder::setup_kv_cache_info(
422435 // update kv cache tokens num
423436 sequence->kv_state ().incr_kv_cache_tokens_num (/* size=*/ q_seq_len);
424437
438+ int32_t offset = is_mtp_decode_ ? -1 : 0 ;
439+ seq_len += offset;
440+ n_kv_cache_tokens += offset;
425441 const auto blocks = sequence->kv_state ().kv_blocks ();
426442 const auto slot_ids =
427443 sequence->kv_state ().kv_cache_slots (n_kv_cache_tokens, seq_len);
@@ -443,6 +459,7 @@ void BatchInputBuilder::setup_kv_cache_info(
443459 (seq_len % block_size == 0 ) ? block_size : seq_len % block_size;
444460 state.paged_kv_last_page_len .push_back (last_page_len);
445461
462+ // calculate the block ids that need to be written
446463 int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
447464 for (auto iter = block_ids.begin () + kv_cache_block_idx;
448465 iter != block_ids.end ();
0 commit comments