diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h
index 168565e89..d6f45fe36 100644
--- a/xllm/core/framework/model/model_args.h
+++ b/xllm/core/framework/model/model_args.h
@@ -124,6 +124,8 @@ struct ModelArgs {
   PROPERTY(int32_t, v_head_dim) = 0;
   PROPERTY(int32_t, q_lora_rank) = 0;
   PROPERTY(int32_t, kv_lora_rank) = 0;
+  // deepseek v3/v3.2 MTP
+  PROPERTY(int32_t, num_nextn_predict_layers) = 0;
 
   // deepseek v3.2 indexer
   PROPERTY(int32_t, index_head_dim) = 0;
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index 3041a7dc6..f89f0d720 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -161,6 +161,17 @@ struct ModelInputParams {
     LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
               << dp_global_token_nums;
   }
+
+  int32_t get_q_seq_len(int32_t seq_idx) const {
+#if defined(USE_NPU)
+    CHECK(seq_idx < q_seq_lens_vec.size()) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx];
+#else
+    CHECK(seq_idx < q_seq_lens_vec.size() - 1) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx + 1] - q_seq_lens_vec[seq_idx];
+#endif
+  }
+
   // whether the kv-cache is empty for all sequences.
   bool empty_kv_cache = true;
diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h
index c45788c00..40f81b693 100644
--- a/xllm/core/layers/common/indexer.h
+++ b/xllm/core/layers/common/indexer.h
@@ -25,7 +25,8 @@ limitations under the License.
 #include "../mlu/attention.h"
 #elif defined(USE_CUDA)
 #include "../cuda/attention.h"
-#endif #include "framework/kv_cache/kv_cache.h"
+#endif
+#include "framework/kv_cache/kv_cache.h"
 #include "framework/model/model_input_params.h"
 #include "framework/parallel_state/parallel_args.h"
 #include "framework/quant_args.h"
diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp
index d5a69a7a1..8b351269a 100644
--- a/xllm/core/runtime/llm_worker_impl.cpp
+++ b/xllm/core/runtime/llm_worker_impl.cpp
@@ -190,7 +190,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
     // should be in same prefill stage, so, to judge empty_kv_cache,
     // just use micro batch 0 here
     if (options_.enable_speculative_decode() && !is_spec_draft_) {
-      if (input.input_params.q_seq_lens_vec[0] > 1) {
+      if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
        output.sample_output.embeddings = hidden_states;
       } else if (sampling_params.sample_idxes.defined()) {
         // auto sample_idxes =
diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp
index 9975be014..beecdab38 100644
--- a/xllm/core/runtime/speculative_worker_impl.cpp
+++ b/xllm/core/runtime/speculative_worker_impl.cpp
@@ -36,6 +36,101 @@ namespace {
                : tensor_;                     \
   } while (0)
 
+// Convert the tensor to int64 on the MLU platform (temporary workaround;
+// MLU will support int32 for masked_scatter_ in the future).
+torch::Tensor ensure_int64_for_certain_platform(torch::Tensor tensor) {
+#if defined(USE_MLU)
+  return tensor.to(torch::kInt64);
+#else
+  return tensor;
+#endif
+}
+
+// Append a cumulative-sum entry to the vector (used for the cumulative format).
+void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
+  if (vec.empty()) {
+    vec.emplace_back(0);
+  }
+  vec.emplace_back(vec.back() + len);
+}
+
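Reviewer note (not part of the patch): the two sequence-length layouts used above can be confusing, so here is a minimal standalone sketch. `push_cumsum` is copied from this diff; everything else is illustrative only and mirrors the behavior of `ModelInputParams::get_q_seq_len` on the direct (NPU) versus cumulative (GPU/MLU) layouts.

```cpp
// Illustrative only: direct vs. cumulative query-length bookkeeping.
#include <cassert>
#include <cstdint>
#include <vector>

void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
  if (vec.empty()) vec.emplace_back(0);
  vec.emplace_back(vec.back() + len);
}

int main() {
  // Two sequences with query lengths 4 and 1.
  std::vector<int32_t> direct = {4, 1};  // NPU layout: one entry per sequence
  std::vector<int32_t> cumulative;       // GPU/MLU layout: prefix sums
  push_cumsum(cumulative, 4);
  push_cumsum(cumulative, 1);            // -> {0, 4, 5}

  // get_q_seq_len(i) yields the same per-sequence length on both layouts.
  assert(direct[0] == cumulative[1] - cumulative[0]);
  assert(direct[1] == cumulative[2] - cumulative[1]);
  return 0;
}
```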
+// Batch-expansion strategy for validation.
+// Processes the validation sequence lengths for each token (used in
+// prepare_validate_inputs). For NPU without ATB: append direct lengths for
+// each token. For MLU: append cumulative lengths for each token.
+void batch_expansion_process_seq_lens(
+    std::vector<int32_t>& kv_seq_lens_vec,
+    std::vector<int32_t>& q_seq_lens_vec,
+    std::vector<std::vector<int32_t>>& block_tables_vec,
+    const Slice<int32_t>& kv_seq_lens_slice,
+    const Slice<int32_t>& block_table_slice,
+    int32_t seq_id,
+    int32_t position_offset,
+    int32_t num_val_tokens) {
+  for (int32_t token_id = position_offset;
+       token_id < num_val_tokens + position_offset;
+       ++token_id) {
+#if defined(USE_MLU)
+    // process kv length and q length in the cumulative-lengths style;
+    // we use the batch-expansion strategy for validation, so q_len is always 1
+    int32_t kv_len =
+        kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + token_id;
+    int32_t q_len = 1;
+    push_cumsum(kv_seq_lens_vec, kv_len);
+    push_cumsum(q_seq_lens_vec, q_len);
+#else
+    // For NPU without ATB: direct format
+    q_seq_lens_vec.emplace_back(1);
+    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id);
+#endif
+    block_tables_vec.emplace_back(block_table_slice);
+  }
+}
+
+// Update kv_seq_lens_vec based on the platform type.
+// For NPU: directly append kv_seq_lens_slice[seq_id] + offset.
+// For others: build the cumulative format.
+void update_kv_seq_lens_vec(std::vector<int32_t>& kv_seq_lens_vec,
+                            const Slice<int32_t>& kv_seq_lens_slice,
+                            int32_t seq_id,
+                            int32_t offset) {
+#if defined(USE_NPU)
+  kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+#else
+  // build cumulative format for kv_seq_lens
+  int32_t offset_kv_len =
+      kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
+  push_cumsum(kv_seq_lens_vec, offset_kv_len);
+#endif
+}
+
+// For GPU and MLU, kv_seq_lens_vec uses the cumulative format, so the maximum
+// sequence length is the largest difference between consecutive elements.
+// For NPU, kv_seq_lens_vec is in direct format (actual lengths), so we simply
+// return the maximum value.
+int32_t get_kv_max_seq_len(std::vector<int32_t>& kv_seq_lens_vec) {
+#if defined(USE_NPU)
+  // NPU: kv_seq_lens_vec is in direct format, return the maximum value
+  // directly.
+  return *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
+#else
+  // GPU/MLU: kv_seq_lens_vec is in cumulative format.
+  // The maximum sequence length is the maximum difference between consecutive
+  // elements.
+  if (kv_seq_lens_vec.size() < 2) {
+    return 0;
+  }
+  int32_t max_seq_len = 0;
+  for (size_t i = 1; i < kv_seq_lens_vec.size(); ++i) {
+    int32_t len = kv_seq_lens_vec[i] - kv_seq_lens_vec[i - 1];
+    if (len > max_seq_len) {
+      max_seq_len = len;
+    }
+  }
+  return max_seq_len;
+#endif
+}
+
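Reviewer note (not part of the patch): a simplified, self-contained sketch of what the batch-expansion helper above produces for one sequence, using plain vectors instead of `Slice`. The toy numbers (`kv_len = 10`, `num_val_tokens = 2`) are assumptions for illustration.

```cpp
// Simplified sketch of batch expansion for validation: one q_len == 1 row per
// validated token, in both the cumulative (MLU/GPU) and direct (NPU) layouts.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t kv_len = 10;         // current kv length of the sequence
  const int32_t num_val_tokens = 2;  // bonus token + one draft token
  const int32_t position_offset = 0;

  std::vector<int32_t> kv_cu = {0}, q_cu = {0};   // cumulative layout
  std::vector<int32_t> kv_direct, q_direct;       // direct layout

  for (int32_t t = position_offset; t < num_val_tokens + position_offset; ++t) {
    kv_cu.push_back(kv_cu.back() + kv_len + t);  // -> {0, 10, 21}: row lens 10, 11
    q_cu.push_back(q_cu.back() + 1);             // -> {0, 1, 2}
    kv_direct.push_back(kv_len + t);             // -> {10, 11}
    q_direct.push_back(1);                       // -> {1, 1}
  }

  // get_kv_max_seq_len() would report 11 for either layout.
  std::cout << "cumulative rows: " << kv_cu.size() - 1
            << ", direct rows: " << kv_direct.size() << std::endl;
  return 0;
}
```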
 int32_t get_new_token_slot_id(const int32_t cur_token_slot_id,
                               const int32_t block_size,
                               const int32_t offset,
@@ -90,6 +185,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
   runtime_options.enable_schedule_overlap(false);
   impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
 
+  // Here we set num_speculative_tokens to 0 to signal to the worker that
+  // this is the draft model when enable_speculative_decode is on.
+  // NOTE: if you modify this part, make sure you also check the usage of
+  // num_speculative_tokens in the draft model.
   runtime_options.num_decoding_tokens(1).num_speculative_tokens(0);
   draft_impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -217,7 +317,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
 
   // prepare input for draft model
   auto& embeddings = output.sample_output.embeddings;
-  auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt);
+  auto next_tokens = ensure_int64_for_certain_platform(
+      safe_to(output.sample_output.next_tokens, torch::kInt));
 
   auto start_idx = 0;
   auto token_start_idx = 0;
@@ -231,6 +332,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
     }
     if (next_tokens.defined()) {
       auto& token_ids = prefill_input.token_ids;
+      token_ids = ensure_int64_for_certain_platform(token_ids);
       auto mask = (token_ids == -1);
       // TODO: support multi stream parallel case
       // token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx,
@@ -288,7 +390,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
   new_token_ids.reserve(input.token_ids.numel());
   for (size_t i = 0; i < input_params.num_sequences; ++i) {
     int32_t q_len = 0;
-    q_len = input_params.q_seq_lens_vec[i];
+    q_len = input_params.get_q_seq_len(i);
     Slice<int32_t> tokens_ids_slice_i =
         tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
     start_idx += q_len;
@@ -363,11 +465,12 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
 
   for (int i = 0; i < options_.num_speculative_tokens(); ++i) {
     ForwardOutput draft_output = draft_outputs[i];
-    auto next_tokens =
-        safe_to(draft_output.sample_output.next_tokens, torch::kInt);
+    auto next_tokens = ensure_int64_for_certain_platform(
+        safe_to(draft_output.sample_output.next_tokens, torch::kInt));
     int32_t start_idx = 0;
     int32_t offset = draft_input.input_params.num_sequences;
     auto& token_ids = validate_input.token_ids;
+    token_ids = ensure_int64_for_certain_platform(token_ids);
     auto mask = (token_ids == -1 * (i + 1));
     token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
     start_idx += offset;
@@ -441,7 +544,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
   Slice<int32_t> new_cache_slots_slice = {new_cache_slots.data_ptr<int32_t>(),
                                           new_cache_slots.numel()};
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+    update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset);
 
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -529,18 +632,21 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
 
     // process kv length and q length
     if (FLAGS_enable_atb_spec_kernel) {
+      // expand the number of decode tokens for each sequence in the batch for
+      // validation
       kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] +
                                    num_speculative_tokens + position_offset);
       q_seq_lens_vec.emplace_back(num_val_tokens);
     } else {
-      for (int32_t token_id = position_offset;
-           token_id < num_val_tokens + position_offset;
-           ++token_id) {
-        q_seq_lens_vec.emplace_back(1);
-        kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id);
-        // repeat block table
-        block_tables_vec.emplace_back(block_table_slice);
-      }
+      // expand the batch size for validation
+      batch_expansion_process_seq_lens(kv_seq_lens_vec,
+                                       q_seq_lens_vec,
+                                       block_tables_vec,
+                                       kv_seq_lens_slice,
+                                       block_table_slice,
+                                       seq_id,
+                                       position_offset,
+                                       num_val_tokens);
     }
 
     // process position related params
@@ -636,6 +742,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
   size_t num_draft_tokens = num_target_tokens - batch_size;
   COUNTER_ADD(speculative_num_draft_tokens_total, num_draft_tokens);
   COUNTER_ADD(speculative_num_accepted_tokens_total, num_draft_tokens - count);
+
   return sample_output;
 }
 
@@ -651,11 +758,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   torch::Tensor positions = safe_to(inputs.positions, torch::kCPU);
   Slice<int32_t> positions_slice = {positions.data_ptr<int32_t>(),
                                     positions.numel()};
 
+  // Get the tokens generated in the last step (flattened for easier indexing)
   torch::Tensor last_token_ids = safe_to(
       last_step_output_.sample_output.next_tokens.flatten(), torch::kCPU);
   Slice<int64_t> last_tokens_ids_slice = {last_token_ids.data_ptr<int64_t>(),
                                           last_token_ids.numel()};
 
+  // Determine how many tokens were decoded in the last step.
+  // If the output is 2D, multiple tokens were generated per sequence.
   int32_t last_step_decode_num = 1;
   if (last_step_output_.sample_output.next_tokens.dim() == 2) {
     last_step_decode_num = last_step_output_.sample_output.next_tokens.size(1);
@@ -676,25 +786,31 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   new_token_ids.reserve(num_sequences);
   new_positions.reserve(num_sequences);
 
-  // update the input_params
-  input_params.kv_max_seq_len =
-      input_params.kv_max_seq_len + last_step_decode_num - 1;
-
   std::vector<int32_t> kv_seq_lens_vec = {};
   std::vector<int32_t> new_token_slot_ids;
   new_token_slot_ids.reserve(num_sequences);
-  // get right token id and position
+
+  // Process each sequence to get the correct token ID and position for the
+  // next step
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     int32_t postion_offset = 0;
     int32_t last_step_token_id = 0;
+
+    // If the token ID is non-negative, it's a direct token ID (not a
+    // placeholder)
    if (tokens_ids_slice[seq_id] >= 0) {
       last_step_token_id = tokens_ids_slice[seq_id];
     } else {
+      // Negative token IDs are placeholders that need to be resolved from
+      // last_step_output_. The absolute value minus 1 gives the index into
+      // the last step's output.
       int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1;
       last_step_index = last_step_index * last_step_decode_num;
       last_step_token_id = last_tokens_ids_slice[last_step_index];
-      for (int i = 1; i < last_step_decode_num; ++i) {
+
+      // If multiple tokens were decoded, find the last valid (non-negative)
+      // token. This handles cases where some speculative tokens were rejected.
+      for (int32_t i = 1; i < last_step_decode_num; ++i) {
         int32_t token_id = last_tokens_ids_slice[last_step_index + i];
         if (token_id >= 0) {
           last_step_token_id = token_id;
@@ -704,15 +820,22 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
     }
 
     new_token_ids.push_back(last_step_token_id);
+
+    // If no position offset, use the same position and cache slot
     if (postion_offset == 0) {
       new_positions.emplace_back(positions_slice[seq_id]);
-      kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id]);
+      update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, 0);
       new_token_slot_ids.emplace_back(new_cache_slots_slice[seq_id]);
       continue;
     }
+
+    // Update position and KV sequence length based on the offset
     new_positions.emplace_back(positions_slice[seq_id] + postion_offset);
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + postion_offset);
+    update_kv_seq_lens_vec(
+        kv_seq_lens_vec, kv_seq_lens_slice, seq_id, postion_offset);
+    // Calculate the new cache slot ID based on the position offset.
+    // This handles cases where we need to move to a different block.
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -724,6 +847,10 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
     new_token_slot_ids.emplace_back(new_token_slot_id);
   }
 
+  // Update the maximum KV sequence length
+  input_params.kv_max_seq_len = get_kv_max_seq_len(kv_seq_lens_vec);
+
+  // Create new tensors with updated values
   torch::TensorOptions int_options = inputs.token_ids.options();
   new_inputs.token_ids = torch::tensor(new_token_ids, int_options);
   new_inputs.positions = torch::tensor(new_positions, int_options);
diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp
index 7e5a61086..408de6de9 100644
--- a/xllm/core/runtime/worker_impl.cpp
+++ b/xllm/core/runtime/worker_impl.cpp
@@ -600,9 +600,26 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) {
     }
   }
 
+#if defined(USE_NPU)
   if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) {
     args.num_speculative_tokens(options_.num_speculative_tokens());
   }
+#else
+  if (options_.enable_speculative_decode()) {
+    args.num_speculative_tokens(options_.num_speculative_tokens());
+    // When running speculative decoding, the draft worker reuses the same
+    // checkpoint as the target DeepSeek V3/V3.2 model. The draft worker needs
+    // to instantiate the MTP variant, so override the model_type here without
+    // mutating the original config.
+    if (options_.num_speculative_tokens() == 0 &&
+        (args.model_type() == "deepseek_v3" ||
+         args.model_type() == "deepseek_v32")) {
+      LOG(INFO) << "Overriding draft model_type from " << args.model_type()
+                << " to deepseek_mtp for speculative decoding";
+      args.model_type("deepseek_mtp");
+    }
+  }
+#endif
 
   // create model context
   dtype_ = dtype;
@@ -1040,12 +1057,25 @@ AlignedTensorCreater::AlignedTensorCreater(
             << ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
 }
 
 bool WorkerImpl::check_is_prefill(const std::vector<int32_t>& q_seq_lens_vec) {
+#if defined(USE_NPU)
+  // On NPU, q_seq_lens_vec directly holds the per-sequence query lengths.
   for (auto q_len : q_seq_lens_vec) {
     if (q_len > 1) {
       return true;
     }
   }
   return false;
+#else
+  // On MLU and GPU, q_seq_lens_vec holds cumulative values (starting from 0).
+  // A sequence is prefill if any per-sequence delta (q_seq_lens_vec[i+1] -
+  // q_seq_lens_vec[i]) > 1.
+  for (size_t i = 0; i + 1 < q_seq_lens_vec.size(); ++i) {
+    if ((q_seq_lens_vec[i + 1] - q_seq_lens_vec[i]) > 1) {
+      return true;
+    }
+  }
+  return false;
+#endif
 }
 
 }  // namespace xllm
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 441b52004..20c822d68 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -421,8 +421,9 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
 #endif
   }
 
-  void load_model(std::unique_ptr<ModelLoader> loader,
-                  std::string prefix = "model." /*llm model weight prefix*/) {
+  virtual void load_model(
+      std::unique_ptr<ModelLoader> loader,
+      std::string prefix = "model." /*llm model weight prefix*/) {
     for (const auto& state_dict : loader->get_state_dicts()) {
       model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
       if (tie_word_embeddings) {
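Reviewer note (not part of the patch): the two branches of `check_is_prefill` above implement the same test on the two layouts. A minimal standalone sketch, with hypothetical helper names, for anyone verifying the equivalence:

```cpp
// Illustrative only: prefill detection on direct vs. cumulative query lengths.
#include <cassert>
#include <cstdint>
#include <vector>

bool is_prefill_direct(const std::vector<int32_t>& q) {      // NPU layout
  for (auto len : q) {
    if (len > 1) return true;
  }
  return false;
}

bool is_prefill_cumulative(const std::vector<int32_t>& q) {  // GPU/MLU layout
  for (size_t i = 0; i + 1 < q.size(); ++i) {
    if (q[i + 1] - q[i] > 1) return true;
  }
  return false;
}

int main() {
  assert(is_prefill_direct({5, 1}));          // batch with a prefill sequence
  assert(!is_prefill_direct({1, 1}));         // pure decode batch
  assert(is_prefill_cumulative({0, 5, 6}));   // same batches, cumulative form
  assert(!is_prefill_cumulative({0, 1, 2}));
  return 0;
}
```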
diff --git a/xllm/models/llm/mlu/deepseek_mtp.h b/xllm/models/llm/mlu/deepseek_mtp.h
new file mode 100644
index 000000000..0da7bcd11
--- /dev/null
+++ b/xllm/models/llm/mlu/deepseek_mtp.h
@@ -0,0 +1,290 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include "core/layers/deepseek_v2_decoder_layer.h"
+#include "models/llm/llm_model_base.h"
+
+// DeepSeek MTP head compatible with huggingface weights
+// ref to:
+// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py
+
+namespace xllm {
+
+class DeepseekMultiTokenPredictorLayerImpl : public torch::nn::Module {
+ public:
+  DeepseekMultiTokenPredictorLayerImpl(const ModelContext& context,
+                                       const int32_t layer_index) {
+    auto options = context.get_tensor_options();
+    auto model_args = context.get_model_args();
+    auto parallel_args = context.get_parallel_args();
+
+    // register submodules
+    enorm_ = register_module("enorm", layer::RmsNorm(context));
+    hnorm_ = register_module("hnorm", layer::RmsNorm(context));
+    // no quantization for eh_proj
+    eh_proj_ =
+        register_module("eh_proj",
+                        layer::ReplicatedLinear(model_args.hidden_size() * 2,
+                                                model_args.hidden_size(),
+                                                /*bias=*/false,
+                                                /*QuantArgs=*/QuantArgs(),
+                                                options));
+    mtp_block_ = register_module(
+        "mtp_block", layer::DeepseekV2DecoderLayer(context, layer_index));
+  }
+
+  torch::Tensor forward(torch::Tensor embed,
+                        torch::Tensor positions,
+                        const layer::AttentionMetadata& attn_metadata,
+                        KVCache& kv_cache,
+                        const ModelInputParams& input_params) {
+    // Layer norm on the token embeddings
+    auto enorm_out = enorm_(embed);
+
+    const auto& embedding_data =
+        input_params.mm_data.get<torch::Tensor>("embedding");
+    CHECK(embedding_data.has_value())
+        << "embedding is not defined or has no value in input_params.mm_data";
+    torch::Tensor previous_hidden_states = embedding_data.value();
+    previous_hidden_states = hnorm_(previous_hidden_states);
+
+    // Concatenate along the last dimension and project
+    auto concat_emb = torch::cat({enorm_out, previous_hidden_states}, -1);
+    auto hidden_states = eh_proj_(concat_emb);
+
+    // Pass through the MTP decoder block
+    hidden_states = mtp_block_(
+        hidden_states, positions, attn_metadata, kv_cache, input_params);
+
+    return hidden_states;
+  }
+
+  void load_state_dict(const StateDict& state_dict) {
+    enorm_->load_state_dict(state_dict.get_dict_with_prefix("enorm."));
+    hnorm_->load_state_dict(state_dict.get_dict_with_prefix("hnorm."));
+    eh_proj_->load_state_dict(state_dict.get_dict_with_prefix("eh_proj."));
+    mtp_block_->load_state_dict(state_dict);
+  }
+
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
+  }
+
+  virtual void update_expert_weight(int32_t layer_id) { return; }
+
+ private:
+  layer::RmsNorm enorm_{nullptr};
+  layer::RmsNorm hnorm_{nullptr};
+  layer::ReplicatedLinear eh_proj_{nullptr};
+  layer::DeepseekV2DecoderLayer mtp_block_{nullptr};
+};
+TORCH_MODULE(DeepseekMultiTokenPredictorLayer);
+
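Reviewer note (not part of the patch): a shape-level sketch of the projection this layer performs, in plain libtorch rather than the xllm `layer::` wrappers. `rms_norm` here is a weight-less stand-in for `layer::RmsNorm`, and `prev_hidden` stands in for the target model's hidden states that arrive via `input_params.mm_data["embedding"]`; the toy sizes are assumptions.

```cpp
// Sketch: hidden = eh_proj(concat(RMSNorm(token_embed), RMSNorm(prev_hidden)))
#include <torch/torch.h>
#include <iostream>

// Minimal RMSNorm stand-in (no learned weight), for illustration only.
torch::Tensor rms_norm(const torch::Tensor& x, double eps = 1e-6) {
  return x * torch::rsqrt(x.pow(2).mean(-1, /*keepdim=*/true) + eps);
}

int main() {
  const int64_t num_tokens = 4, hidden = 8;               // toy sizes
  auto embed = torch::randn({num_tokens, hidden});        // token embeddings
  auto prev_hidden = torch::randn({num_tokens, hidden});  // target-model hidden states

  // eh_proj: Linear(2 * hidden -> hidden), no bias, as in the layer above.
  auto eh_proj = torch::nn::Linear(
      torch::nn::LinearOptions(2 * hidden, hidden).bias(false));

  auto concat = torch::cat({rms_norm(embed), rms_norm(prev_hidden)}, -1);
  auto hidden_states = eh_proj(concat);  // [num_tokens, hidden]
  std::cout << hidden_states.sizes() << std::endl;  // prints [4, 8]
  return 0;
}
```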
+class DeepseekMTPModelImpl : public torch::nn::Module {
+ public:
+  DeepseekMTPModelImpl(const ModelContext& context) {
+    auto options = context.get_tensor_options();
+    auto model_args = context.get_model_args();
+    auto parallel_args = context.get_parallel_args();
+
+    // get mtp start and end layer index
+    mtp_start_layer_idx_ = model_args.n_layers();
+    mtp_end_layer_idx_ =
+        mtp_start_layer_idx_ + model_args.num_nextn_predict_layers();
+    blocks_ = register_module("layers", torch::nn::ModuleList());
+    mtp_layers_.reserve(model_args.num_nextn_predict_layers());
+
+    // create mtp layers
+    for (int32_t i = mtp_start_layer_idx_; i < mtp_end_layer_idx_; ++i) {
+      auto mtp_layer = DeepseekMultiTokenPredictorLayer(context, i);
+      mtp_layers_.push_back(mtp_layer);
+      blocks_->push_back(mtp_layer);
+    }
+    embed_tokens_ =
+        register_module("embed_tokens",
+                        layer::WordEmbedding(model_args.vocab_size(),
+                                             model_args.hidden_size(),
+                                             context.get_parallel_args(),
+                                             options));
+    norm_ = register_module("norm", layer::RmsNorm(context));
+
+    // get dp size and rank
+    dp_size_ = parallel_args.dp_size();
+    std::vector<int32_t> indices;
+    dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
+    dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
+    rank_ = parallel_args.rank();
+    for (size_t i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
+      indices.push_back(i);
+    }
+  }
+
+  // Provide batched signature to satisfy callers that pass vectors
+  torch::Tensor forward(torch::Tensor tokens,
+                        torch::Tensor positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    bool is_prefill = input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(input_params, is_prefill);
+    torch::Tensor hidden_states = embed_tokens_(tokens);
+    // Mask out embeddings where position == 0 (not needed for MTP at
+    // position 0)
+    auto mask = (positions == 0);  // bool tensor
+    if (mask.any().item<bool>()) {
+      // set masked rows to zero
+      hidden_states.index_put_({mask},
+                               torch::zeros_like(hidden_states.index({mask})));
+    }
+
+    for (size_t i = 0; i < mtp_layers_.size(); i++) {
+      auto& layer = mtp_layers_[i];
+      hidden_states = layer(
+          hidden_states, positions, attn_metadata, kv_caches[i], input_params);
+    }
+    auto result = norm_(hidden_states);
+    return result;
+  }
+
+  // load the weights from the checkpoint
+  void load_state_dict(const StateDict& state_dict) {
+    // call each layer's load_state_dict function
+    for (int32_t i = 0; i < mtp_layers_.size(); i++) {
+      int32_t layer_index = mtp_start_layer_idx_ + i;
+      mtp_layers_[i]->load_state_dict(state_dict.get_dict_with_prefix(
+          "layers." + std::to_string(layer_index) + "."));
+      // there is only one shared_head.norm for deepseek models, so we load it
+      // here
+      if (i == mtp_layers_.size() - 1) {
+        norm_->load_state_dict(state_dict.get_dict_with_prefix(
+            "layers." + std::to_string(layer_index) + ".shared_head.norm."));
+      }
+    }
+  }
+
+  layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
+
+  void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    embed_tokens_ = word_embedding;
+  }
+
+ private:
+  torch::nn::ModuleList blocks_{nullptr};
+  std::vector<DeepseekMultiTokenPredictorLayer> mtp_layers_;
+  int32_t mtp_start_layer_idx_;
+  int32_t mtp_end_layer_idx_;
+  int32_t dp_rank_;
+  int32_t rank_;
+  int32_t dp_size_;
+  int32_t dp_local_tp_size_;
+  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::RmsNorm norm_{nullptr};
+};
+TORCH_MODULE(DeepseekMTPModel);
+
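Reviewer note (not part of the patch): a tiny sketch of the checkpoint key prefixes the loaders above end up reading, assuming the DeepSeek-V3 defaults registered below (`num_hidden_layers` = 61, `num_nextn_predict_layers` = 1) and the `"model."` prefix used by `load_model`. Only the MTP-specific prefixes that this code touches are shown; the published checkpoint contains more keys.

```cpp
// Sketch of the MTP weight-name prefixes expected by load_state_dict above.
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  const int32_t n_layers = 61, num_nextn_predict_layers = 1;
  for (int32_t i = 0; i < num_nextn_predict_layers; ++i) {
    // MTP layers live after the regular decoder layers, e.g. layers.61.*
    const std::string prefix =
        "model.layers." + std::to_string(n_layers + i) + ".";
    std::cout << prefix << "enorm.weight\n"
              << prefix << "hnorm.weight\n"
              << prefix << "eh_proj.weight\n"
              << prefix << "shared_head.norm.weight\n";  // source of the final norm
  }
  return 0;
}
```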
+class DeepseekMTPForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekMTPModel> {
+ public:
+  DeepseekMTPForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekMTPModel>(context) {}
+
+  void load_model(
+      std::unique_ptr<ModelLoader> loader,
+      std::string prefix = "model." /*llm model weight prefix*/) override {
+    // no need to load lm_head since it shares the same weights with the main
+    // model
+    for (const auto& state_dict : loader->get_state_dicts()) {
+      model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
+    }
+  }
+};
+TORCH_MODULE(DeepseekMTPForCausalLM);
+
+// register the causal model
+REGISTER_CAUSAL_MODEL(deepseek_mtp, DeepseekMTPForCausalLM);
+
+// example config:
+// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
+REGISTER_MODEL_ARGS(deepseek_mtp, [&] {
+  LOAD_ARG_OR(model_type, "model_type", "deepseek_mtp");
+  LOAD_ARG_OR(dtype, "torch_dtype", "");
+  LOAD_ARG_OR(vocab_size, "vocab_size", 129280);
+  LOAD_ARG_OR(hidden_size, "hidden_size", 7168);
+  LOAD_ARG_OR(n_layers, "num_hidden_layers", 61);
+  LOAD_ARG_OR(n_heads, "num_attention_heads", 128);
+  LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 128);
+  LOAD_ARG_OR(intermediate_size, "intermediate_size", 18432);
+  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 163840);
+  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
+  LOAD_ARG_OR(eos_token_id, "eos_token_id", 1);
+  LOAD_ARG_OR(bos_token_id, "bos_token_id", 0);
+  LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
+  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
+  LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
+  LOAD_ARG_OR(max_window_layers, "max_window_layers", 61);
+
+  LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 0);
+  LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1);
+  LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc");
+  LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256);
+  LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1);
+  LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8);
+  LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 2048);
+  LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 2.5f);
+  LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true);
+  LOAD_ARG_OR(n_group, "n_group", 8);
+  LOAD_ARG_OR(topk_group, "topk_group", 4);
+  LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128);
+  LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64);
+  LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
+  LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536);
+  LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);
+  LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1);
+
+  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
+    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
+  });
+  LOAD_ARG_OR_FUNC(
+      rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); });
+
+  SET_ARG(rope_scaling_rope_type, "deepseek_yarn");
+  LOAD_ARG(rope_scaling_beta_fast, "rope_scaling.beta_fast");
+  LOAD_ARG(rope_scaling_beta_slow, "rope_scaling.beta_slow");
+  LOAD_ARG(rope_scaling_factor, "rope_scaling.factor");
+  LOAD_ARG_OR(
+      rope_extrapolation_factor, "rope_scaling.extrapolation_factor", 1.0f);
+  LOAD_ARG(rope_scaling_mscale, "rope_scaling.mscale");
+  LOAD_ARG(rope_scaling_mscale_all_dim, "rope_scaling.mscale_all_dim");
+  LOAD_ARG(rope_scaling_original_max_position_embeddings,
+           "rope_scaling.original_max_position_embeddings");
+  LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling.attn_factor", 1.0f);
+
+  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({1}));
+
+  // extra parameters for DeepSeek-V3.2-Exp
+  // example config:
+  // https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/config.json
+  // set default value to 0 so as to distinguish from DeepSeek-V3.
+ LOAD_ARG_OR(index_head_dim, "index_head_dim", 128); + LOAD_ARG_OR(index_n_heads, "index_n_heads", 64); + LOAD_ARG_OR(index_topk, "index_topk", 2048); +}); +} // namespace xllm diff --git a/xllm/models/llm/mlu/deepseek_v2.h b/xllm/models/llm/mlu/deepseek_v2.h index 733d3e312..9666b86c9 100644 --- a/xllm/models/llm/mlu/deepseek_v2.h +++ b/xllm/models/llm/mlu/deepseek_v2.h @@ -28,15 +28,13 @@ limitations under the License. namespace xllm { -using torch::indexing::None; -using ISlice = torch::indexing::Slice; - class DeepseekV2DecoderLayerImpl : public torch::nn::Module { public: - DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) { + DeepseekV2DecoderLayerImpl(const ModelContext& context, + const int32_t layer_index) { // register submodules - decoder_layer_ = register_module("decoder_layer", - layer::DeepseekV2DecoderLayer(context, i)); + decoder_layer_ = register_module( + "decoder_layer", layer::DeepseekV2DecoderLayer(context, layer_index)); } torch::Tensor forward(torch::Tensor& x, @@ -72,14 +70,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module { blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); - // register submodules - num_speculative_tokens_ = model_args.num_speculative_tokens(); - - // MTP is not support for now - if (num_speculative_tokens_ > 0) { - LOG(FATAL) << "DeepSeek MTP on MLU is not support for now"; - } - embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(model_args.vocab_size(), @@ -115,12 +105,13 @@ class DeepseekV2ModelImpl : public torch::nn::Module { bool is_prefill = input_params.q_max_seq_len > 1; auto attn_metadata = layer::AttentionMetadata::build(input_params, is_prefill); - torch::Tensor h = embed_tokens_(tokens); + torch::Tensor hidden_states = embed_tokens_(tokens); for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; - h = layer(h, positions, attn_metadata, kv_caches[i], input_params); + hidden_states = layer( + hidden_states, positions, attn_metadata, kv_caches[i], input_params); } - return norm_(h); + return norm_(hidden_states); } // Provide batched signature to satisfy callers that pass vectors @@ -156,7 +147,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module { int32_t rank_; int32_t dp_size_; int32_t dp_local_tp_size_; - int32_t num_speculative_tokens_ = 0; layer::WordEmbedding embed_tokens_{nullptr}; layer::RmsNorm norm_{nullptr}; }; @@ -210,6 +200,7 @@ REGISTER_MODEL_ARGS(deepseek_v2, [&] { LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 0); LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1); LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); diff --git a/xllm/models/llm/mlu/deepseek_v3.h b/xllm/models/llm/mlu/deepseek_v3.h index c96830f4c..308b4f42f 100644 --- a/xllm/models/llm/mlu/deepseek_v3.h +++ b/xllm/models/llm/mlu/deepseek_v3.h @@ -59,6 +59,7 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] { LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536); LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1); LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); diff --git a/xllm/models/llm/mlu/deepseek_v32.h b/xllm/models/llm/mlu/deepseek_v32.h index 996535fbe..cc36768e5 100644 --- 
a/xllm/models/llm/mlu/deepseek_v32.h +++ b/xllm/models/llm/mlu/deepseek_v32.h @@ -59,6 +59,7 @@ REGISTER_MODEL_ARGS(deepseek_v32, [&] { LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536); LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1); LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); diff --git a/xllm/models/models.h b/xllm/models/models.h index 0460d6ff5..60964d172 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -38,6 +38,7 @@ limitations under the License. #include "vlm/qwen3_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep #elif defined(USE_MLU) +#include "llm/mlu/deepseek_mtp.h" // IWYU pragma: keep #include "llm/mlu/deepseek_v2.h" // IWYU pragma: keep #include "llm/mlu/deepseek_v3.h" // IWYU pragma: keep #include "llm/mlu/deepseek_v32.h" // IWYU pragma: keep
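Reviewer note (not part of the patch): the non-NPU branch added to `WorkerImpl::init_model` decides which model class the draft worker instantiates. An illustrative restatement of that decision as a free function, with a hypothetical name (`resolve_model_type`) and toy usage:

```cpp
// Sketch of the draft-worker model_type override added in worker_impl.cpp:
// the draft worker (num_speculative_tokens == 0) loads the same DeepSeek
// V3/V3.2 checkpoint but must instantiate the MTP head.
#include <cassert>
#include <cstdint>
#include <string>

std::string resolve_model_type(bool enable_speculative_decode,
                               int32_t num_speculative_tokens,
                               const std::string& model_type) {
  if (enable_speculative_decode && num_speculative_tokens == 0 &&
      (model_type == "deepseek_v3" || model_type == "deepseek_v32")) {
    return "deepseek_mtp";
  }
  return model_type;
}

int main() {
  assert(resolve_model_type(true, 0, "deepseek_v3") == "deepseek_mtp");  // draft worker
  assert(resolve_model_type(true, 2, "deepseek_v3") == "deepseek_v3");   // target worker
  assert(resolve_model_type(false, 0, "qwen3") == "qwen3");              // unrelated model
  return 0;
}
```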