2 changes: 2 additions & 0 deletions xllm/core/framework/model/model_args.h
@@ -124,6 +124,8 @@ struct ModelArgs {
PROPERTY(int32_t, v_head_dim) = 0;
PROPERTY(int32_t, q_lora_rank) = 0;
PROPERTY(int32_t, kv_lora_rank) = 0;
// DeepSeek V3/V3.2 MTP (multi-token prediction)
PROPERTY(int32_t, num_nextn_predict_layers) = 0;

// deepseek v3.2 indexer
PROPERTY(int32_t, index_head_dim) = 0;
11 changes: 11 additions & 0 deletions xllm/core/framework/model/model_input_params.h
@@ -161,6 +161,17 @@ struct ModelInputParams {
LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
<< dp_global_token_nums;
}

// Return the query length of sequence `seq_idx`. On NPU, q_seq_lens_vec
// stores per-sequence lengths directly; on other platforms it stores
// cumulative offsets, so the length is the difference between consecutive
// entries.
int32_t get_q_seq_len(int32_t seq_idx) const {
#if defined(USE_NPU)
CHECK(seq_idx < q_seq_lens_vec.size()) << "seq_idx out of range";
return q_seq_lens_vec[seq_idx];
#else
CHECK(seq_idx < q_seq_lens_vec.size() - 1) << "seq_idx out of range";
return q_seq_lens_vec[seq_idx + 1] - q_seq_lens_vec[seq_idx];
#endif
}
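
A minimal standalone sketch (not part of this diff) of how the two q_seq_lens_vec layouts resolve to the same per-sequence query lengths via get_q_seq_len; the sample values are hypothetical:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Direct layout (NPU): one query length per sequence.
  std::vector<int32_t> direct = {3, 1, 5};
  // Cumulative layout (GPU/MLU): prefix sums starting at 0.
  std::vector<int32_t> cumulative = {0, 3, 4, 9};
  for (size_t i = 0; i < direct.size(); ++i) {
    // On NPU get_q_seq_len(i) returns direct[i]; on other platforms it
    // returns the delta between consecutive cumulative entries. Both agree.
    assert(direct[i] == cumulative[i + 1] - cumulative[i]);
  }
  return 0;
}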

// whether the kv-cache is empty for all sequences.
bool empty_kv_cache = true;

3 changes: 2 additions & 1 deletion xllm/core/layers/common/indexer.h
@@ -25,7 +25,8 @@ limitations under the License.
#include "../mlu/attention.h"
#elif defined(USE_CUDA)
#include "../cuda/attention.h"
#endif #include "framework/kv_cache/kv_cache.h"
#endif
#include "framework/kv_cache/kv_cache.h"
#include "framework/model/model_input_params.h"
#include "framework/parallel_state/parallel_args.h"
#include "framework/quant_args.h"
2 changes: 1 addition & 1 deletion xllm/core/runtime/llm_worker_impl.cpp
@@ -190,7 +190,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
// should be in same prefill stage, so, to judge empty_kv_cache,
// just use micro batch 0 here
if (options_.enable_speculative_decode() && !is_spec_draft_) {
if (input.input_params.q_seq_lens_vec[0] > 1) {
if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
output.sample_output.embeddings = hidden_states;
} else if (sampling_params.sample_idxes.defined()) {
// auto sample_idxes =
169 changes: 148 additions & 21 deletions xllm/core/runtime/speculative_worker_impl.cpp
@@ -36,6 +36,101 @@ namespace {
: tensor_; \
} while (0)

// Convert tensor to int64 for MLU platform (temp workaround)
// MLU will support int32 for masked_scatter in the future
torch::Tensor ensure_int64_for_certain_platform(torch::Tensor tensor) {
#if defined(USE_MLU)
return tensor.to(torch::kInt64);
#else
return tensor;
#endif
}

// Push cumulative sum to vector (used for cumulative format)
void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
if (vec.empty()) {
vec.emplace_back(0);
}
vec.emplace_back(vec.back() + len);
}

// Batch expansion strategy for validation: process validation sequence
// lengths for each token (used in prepare_validate_inputs).
// For NPU without ATB: add direct values for each token.
// For MLU: add cumulative values for each token.
void batch_expansion_process_seq_lens(
std::vector<int32_t>& kv_seq_lens_vec,
std::vector<int32_t>& q_seq_lens_vec,
std::vector<std::vector<int32_t>>& block_tables_vec,
const Slice<int32_t>& kv_seq_lens_slice,
const Slice<int32_t>& block_table_slice,
int32_t seq_id,
int32_t position_offset,
int32_t num_val_tokens) {
for (int32_t token_id = position_offset;
token_id < num_val_tokens + position_offset;
++token_id) {
#if defined(USE_MLU)
// process kv length and q length in the cumulative-length format.
// we use the batch expansion strategy for validation, so q_len is always 1.
int32_t kv_len =
kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + token_id;
int32_t q_len = 1;
push_cumsum(kv_seq_lens_vec, kv_len);
push_cumsum(q_seq_lens_vec, q_len);
#else
// For NPU without ATB: direct format
q_seq_lens_vec.emplace_back(1);
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id);
#endif
block_tables_vec.emplace_back(block_table_slice);
}
}
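
As a worked illustration (hypothetical values, not part of this diff): with num_speculative_tokens = 2, batch expansion turns one decode sequence whose current KV length is 10 into three validation rows, each scoring a single token against a progressively longer KV prefix:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t kv_len = 10;         // current KV length of the sequence
  const int32_t num_val_tokens = 3;  // num_speculative_tokens + 1
  std::vector<int32_t> q_seq_lens;
  std::vector<int32_t> kv_seq_lens;
  for (int32_t token_id = 0; token_id < num_val_tokens; ++token_id) {
    q_seq_lens.push_back(1);  // batch expansion: q_len is always 1
    kv_seq_lens.push_back(kv_len + token_id);  // direct format (cumulative on MLU)
  }
  // Expected: q_seq_lens = {1, 1, 1}, kv_seq_lens = {10, 11, 12}
  for (size_t i = 0; i < q_seq_lens.size(); ++i) {
    std::cout << q_seq_lens[i] << " " << kv_seq_lens[i] << "\n";
  }
  return 0;
}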

// Update kv_seq_lens_vec based on platform type
// For NPU: directly add kv_seq_lens_slice[seq_id] + offset
// For others: build cumulative format
void update_kv_seq_lens_vec(std::vector<int32_t>& kv_seq_lens_vec,
const Slice<int32_t>& kv_seq_lens_slice,
int32_t seq_id,
int32_t offset) {
#if defined(USE_NPU)
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
#else
// build cumulative format for kv_seq_lens
int32_t offset_kv_len =
kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
push_cumsum(kv_seq_lens_vec, offset_kv_len);
#endif
}
Review comment (Collaborator): The calculation of kv_max_seq_len should be moved here, and get_kv_max_seq_len should be removed.


// For GPU and MLU, kv_seq_lens_vec uses the cumulative format (prefix
// sums). The maximum sequence length is the largest difference between
// consecutive elements. For NPU, kv_seq_lens_vec is in direct format (actual
// lengths), so we simply return the maximum value.
int32_t get_kv_max_seq_len(const std::vector<int32_t>& kv_seq_lens_vec) {
#if defined(USE_NPU)
// NPU: kv_seq_lens_vec is in direct format, return the maximum value
// directly.
return *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
#else
// GPU/MLU: kv_seq_lens_vec is in cumulative format.
// The maximum sequence length is the maximum difference between consecutive
// elements.
if (kv_seq_lens_vec.size() < 2) {
return 0;
}
int32_t max_seq_len = 0;
for (size_t i = 1; i < kv_seq_lens_vec.size(); ++i) {
int32_t len = kv_seq_lens_vec[i] - kv_seq_lens_vec[i - 1];
if (len > max_seq_len) {
max_seq_len = len;
}
}
return max_seq_len;
#endif
}
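
A minimal sketch of the refactor suggested in the review comment above: track the running maximum inside the update helper so a separate get_kv_max_seq_len pass is unnecessary. The function name and the kv_max_seq_len out-parameter are assumptions for illustration, not part of this PR:

// Hypothetical variant: also track the maximum per-sequence KV length while
// appending, so callers no longer need a second pass over kv_seq_lens_vec.
// (Uses std::max, which needs <algorithm>.)
void update_kv_seq_lens_vec_with_max(std::vector<int32_t>& kv_seq_lens_vec,
                                     const Slice<int32_t>& kv_seq_lens_slice,
                                     int32_t seq_id,
                                     int32_t offset,
                                     int32_t& kv_max_seq_len) {
#if defined(USE_NPU)
  const int32_t kv_len = kv_seq_lens_slice[seq_id] + offset;
  kv_seq_lens_vec.emplace_back(kv_len);
#else
  const int32_t kv_len =
      kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
  push_cumsum(kv_seq_lens_vec, kv_len);
#endif
  kv_max_seq_len = std::max(kv_max_seq_len, kv_len);
}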

int32_t get_new_token_slot_id(const int32_t cur_token_slot_id,
const int32_t block_size,
const int32_t offset,
@@ -90,6 +185,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
runtime_options.enable_schedule_overlap(false);
impl_ =
std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
// Here we set num_speculative_tokens to 0 to signal to the worker that this
// is the draft model when enable_speculative_decode is on.
// NOTE: If you modify this part, also check how num_speculative_tokens is
// used in the draft model.
runtime_options.num_decoding_tokens(1).num_speculative_tokens(0);
draft_impl_ =
std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -217,7 +317,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(

// prepare input for draft model
auto& embeddings = output.sample_output.embeddings;
auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt);
auto next_tokens = ensure_int64_for_certain_platform(
safe_to(output.sample_output.next_tokens, torch::kInt));
auto start_idx = 0;
auto token_start_idx = 0;

@@ -231,6 +332,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
}
if (next_tokens.defined()) {
auto& token_ids = prefill_input.token_ids;
token_ids = ensure_int64_for_certain_platform(token_ids);
auto mask = (token_ids == -1);
// TODO: support multi stream parallel case
// token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx,
@@ -288,7 +390,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
new_token_ids.reserve(input.token_ids.numel());
for (size_t i = 0; i < input_params.num_sequences; ++i) {
int32_t q_len = 0;
q_len = input_params.q_seq_lens_vec[i];
q_len = input_params.get_q_seq_len(i);
Slice<int32_t> tokens_ids_slice_i =
tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
start_idx += q_len;
@@ -363,11 +465,12 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(

for (int i = 0; i < options_.num_speculative_tokens(); ++i) {
ForwardOutput draft_output = draft_outputs[i];
auto next_tokens =
safe_to(draft_output.sample_output.next_tokens, torch::kInt);
auto next_tokens = ensure_int64_for_certain_platform(
safe_to(draft_output.sample_output.next_tokens, torch::kInt));
int32_t start_idx = 0;
int32_t offset = draft_input.input_params.num_sequences;
auto& token_ids = validate_input.token_ids;
token_ids = ensure_int64_for_certain_platform(token_ids);
auto mask = (token_ids == -1 * (i + 1));
token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
start_idx += offset;
@@ -441,7 +544,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
Slice<int32_t> new_cache_slots_slice = {new_cache_slots.data_ptr<int32_t>(),
new_cache_slots.numel()};
for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset);
torch::Tensor block_table = block_tables[seq_id];
Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
block_table.numel()};
@@ -529,18 +632,21 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(

// process kv length and q length
if (FLAGS_enable_atb_spec_kernel) {
// expand the number of decode tokens for each sequence in the batch for
// validation
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] +
num_speculative_tokens + position_offset);
q_seq_lens_vec.emplace_back(num_val_tokens);
} else {
for (int32_t token_id = position_offset;
token_id < num_val_tokens + position_offset;
++token_id) {
q_seq_lens_vec.emplace_back(1);
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id);
// repeat block table
block_tables_vec.emplace_back(block_table_slice);
}
// expand the batch size for validation (one row per validated token)
batch_expansion_process_seq_lens(kv_seq_lens_vec,
q_seq_lens_vec,
block_tables_vec,
kv_seq_lens_slice,
block_table_slice,
seq_id,
position_offset,
num_val_tokens);
}

// process position related params
@@ -636,6 +742,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
size_t num_draft_tokens = num_target_tokens - batch_size;
COUNTER_ADD(speculative_num_draft_tokens_total, num_draft_tokens);
COUNTER_ADD(speculative_num_accepted_tokens_total, num_draft_tokens - count);

return sample_output;
}

@@ -651,11 +758,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
torch::Tensor positions = safe_to(inputs.positions, torch::kCPU);
Slice<int32_t> positions_slice = {positions.data_ptr<int32_t>(),
positions.numel()};
// Get the tokens generated in the last step (flattened for easier indexing)
torch::Tensor last_token_ids = safe_to(
last_step_output_.sample_output.next_tokens.flatten(), torch::kCPU);
Slice<int64_t> last_tokens_ids_slice = {last_token_ids.data_ptr<int64_t>(),
last_token_ids.numel()};

// Determine how many tokens were decoded in the last step
// If the output is 2D, it means multiple tokens were generated per sequence
int32_t last_step_decode_num = 1;
if (last_step_output_.sample_output.next_tokens.dim() == 2) {
last_step_decode_num = last_step_output_.sample_output.next_tokens.size(1);
@@ -676,25 +786,31 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
new_token_ids.reserve(num_sequences);
new_positions.reserve(num_sequences);

// update the input_params
input_params.kv_max_seq_len =
input_params.kv_max_seq_len + last_step_decode_num - 1;

std::vector<int32_t> kv_seq_lens_vec = {};
std::vector<int32_t> new_token_slot_ids;
new_token_slot_ids.reserve(num_sequences);

// get right token id and position
// Process each sequence to get the correct token ID and position for the next
// step
for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
int32_t postion_offset = 0;
int32_t last_step_token_id = 0;

// If the token ID is non-negative, it's a direct token ID (not a
// placeholder)
if (tokens_ids_slice[seq_id] >= 0) {
last_step_token_id = tokens_ids_slice[seq_id];
} else {
// Negative token IDs are placeholders that need to be resolved from
// last_step_output_. The absolute value minus 1 gives the index into the
// last step's output.
int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1;
last_step_index = last_step_index * last_step_decode_num;
last_step_token_id = last_tokens_ids_slice[last_step_index];
for (int i = 1; i < last_step_decode_num; ++i) {

// If multiple tokens were decoded, find the last valid (non-negative)
// token. This handles cases where some speculative tokens were rejected.
for (int32_t i = 1; i < last_step_decode_num; ++i) {
int32_t token_id = last_tokens_ids_slice[last_step_index + i];
if (token_id >= 0) {
last_step_token_id = token_id;
@@ -704,15 +820,22 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
}

new_token_ids.push_back(last_step_token_id);

// If no position offset, use the same position and cache slot
if (postion_offset == 0) {
new_positions.emplace_back(positions_slice[seq_id]);
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id]);
update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, 0);
new_token_slot_ids.emplace_back(new_cache_slots_slice[seq_id]);
continue;
}

// Update position and KV sequence length based on the offset
new_positions.emplace_back(positions_slice[seq_id] + postion_offset);
kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + postion_offset);
update_kv_seq_lens_vec(
kv_seq_lens_vec, kv_seq_lens_slice, seq_id, postion_offset);

// Calculate the new cache slot ID based on the position offset
// This handles cases where we need to move to a different block
torch::Tensor block_table = block_tables[seq_id];
Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
block_table.numel()};
@@ -724,6 +847,10 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
new_token_slot_ids.emplace_back(new_token_slot_id);
}

// Update the maximum KV sequence length
input_params.kv_max_seq_len = get_kv_max_seq_len(kv_seq_lens_vec);

// Create new tensors with updated values
torch::TensorOptions int_options = inputs.token_ids.options();
new_inputs.token_ids = torch::tensor(new_token_ids, int_options);
new_inputs.positions = torch::tensor(new_positions, int_options);
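
A small standalone sketch (hypothetical values) of how the negative placeholder token IDs in update_input_by_last_step_output are resolved against the previous step's flattened output:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Flattened next_tokens from the last step: 2 sequences x 2 decode tokens;
  // -1 marks a rejected speculative token.
  std::vector<int64_t> last_tokens = {101, 102, 201, -1};
  const int32_t last_step_decode_num = 2;

  // Placeholder -2 refers to sequence index 1 of the last step:
  // index = -placeholder - 1, scaled by last_step_decode_num.
  const int32_t placeholder = -2;
  const int32_t base = (-1 * placeholder - 1) * last_step_decode_num;

  // Start with the first token, then advance to the last accepted one.
  int64_t token = last_tokens[base];
  for (int32_t i = 1; i < last_step_decode_num; ++i) {
    if (last_tokens[base + i] >= 0) {
      token = last_tokens[base + i];
    }
  }
  assert(token == 201);  // the second draft token was rejected, so 201 is kept
  return 0;
}
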
30 changes: 30 additions & 0 deletions xllm/core/runtime/worker_impl.cpp
@@ -600,9 +600,26 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) {
}
}

#if defined(USE_NPU)
if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) {
args.num_speculative_tokens(options_.num_speculative_tokens());
}
#else
if (options_.enable_speculative_decode()) {
args.num_speculative_tokens(options_.num_speculative_tokens());
// When running speculative decoding, the draft worker reuses the same
// checkpoint as the target DeepSeek V3/V3.2 model. The draft worker needs to
// instantiate the MTP variant, so override the model_type here without
// mutating the original config.
if (options_.num_speculative_tokens() == 0 &&
(args.model_type() == "deepseek_v3" ||
args.model_type() == "deepseek_v32")) {
LOG(INFO) << "Overriding draft model_type from " << args.model_type()
<< " to deepseek_mtp for speculative decoding";
args.model_type("deepseek_mtp");
}
}
#endif

// create model context
dtype_ = dtype;
@@ -1040,12 +1057,25 @@ AlignedTensorCreater::AlignedTensorCreater(
<< ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
}
bool WorkerImpl::check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
#if defined(USE_NPU)
// On NPU, q_seq_lens_vec directly represents query lengths
for (auto q_len : q_seq_lens_vec) {
if (q_len > 1) {
return true;
}
}
return false;
#else
// On MLU and GPU, q_seq_lens_vec holds cumulative values (starting from 0).
// A sequence is prefill if any per-sequence delta
// (q_seq_lens_vec[i+1] - q_seq_lens_vec[i]) is greater than 1.
for (size_t i = 0; i + 1 < q_seq_lens_vec.size(); ++i) {
if ((q_seq_lens_vec[i + 1] - q_seq_lens_vec[i]) > 1) {
return true;
}
}
return false;
#endif
}

} // namespace xllm
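
A short standalone sketch (hypothetical values, not part of this diff) of the prefill check on the cumulative layout used by MLU and GPU:

#include <cassert>
#include <cstdint>
#include <vector>

// Same logic as the cumulative branch of WorkerImpl::check_is_prefill.
bool is_prefill_cumulative(const std::vector<int32_t>& q_seq_lens_vec) {
  for (size_t i = 0; i + 1 < q_seq_lens_vec.size(); ++i) {
    if (q_seq_lens_vec[i + 1] - q_seq_lens_vec[i] > 1) {
      return true;
    }
  }
  return false;
}

int main() {
  // Sequence 0 has 4 query tokens (prefill), sequence 1 has 1 (decode).
  std::vector<int32_t> mixed_batch = {0, 4, 5};
  // Both sequences have a single query token: pure decode.
  std::vector<int32_t> decode_batch = {0, 1, 2};
  assert(is_prefill_cumulative(mixed_batch));
  assert(!is_prefill_cumulative(decode_batch));
  return 0;
}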