diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 4d50c809d196..c9744d31f0ef 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -382,6 +382,7 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | +| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ | | `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | | `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 17b1d7b527f6..9a2a1eb5f1a7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -296,6 +296,9 @@ def check_available_online( "random": "ai21labs/Jamba-tiny-random", }, ), + "KimiLinearForCausalLM": _HfExamplesInfo( + "moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True + ), "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"), "Lfm2MoeForCausalLM": _HfExamplesInfo( "LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c24a94091be4..567a9bfe4ced 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -453,6 +453,7 @@ class CompilationConfig: "vllm::linear_attention", "vllm::plamo2_mamba_mixer", "vllm::gdn_attention", + "vllm::kda_attention", "vllm::sparse_attn_indexer", ] diff --git a/vllm/config/model.py b/vllm/config/model.py index e22c218c769d..9919dd4829b8 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1257,6 +1257,7 @@ def is_deepseek_mla(self) -> bool: "deepseek_v32", "deepseek_mtp", "kimi_k2", + "kimi_linear", "longcat_flash", ): return self.hf_text_config.kv_lora_rank is not None diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py index a10847d347d1..700f287ca456 100644 --- a/vllm/model_executor/layers/fla/ops/kda.py +++ b/vllm/model_executor/layers/fla/ops/kda.py @@ -1304,7 +1304,7 @@ def kda_gate_fwd_kernel( tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1)) -def kda_gate_fwd( +def fused_kda_gate( g: torch.Tensor, A: torch.Tensor, head_k_dim: int, diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py new file mode 100644 index 000000000000..c45e7546fac1 --- /dev/null +++ b/vllm/model_executor/layers/kda.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from einops import rearrange +from torch import nn + +from vllm.attention import AttentionBackend +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size, +) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .fla.ops.kda import ( + FusedRMSNormGated, + chunk_kda, + fused_kda_gate, + fused_recurrent_kda, +) +from .linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from .mamba.abstract import MambaBase +from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator +from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from .quantization.base_config import QuantizationConfig + +logger = init_logger(__name__) + + +def kda_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def kda_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="kda_attention", + op_func=kda_attention, + mutates_args=["output"], + fake_impl=kda_attention_fake, +) + + +class KimiDeltaAttention(nn.Module, MambaBase): + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + + return GDNAttentionBackend + + def get_state_dtype( + self, + ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + if self.model_config is None or self.cache_config is None: + raise ValueError("model_config and cache_config must be set") + return MambaStateDtypeCalculator.kda_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape( + self, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.kda_state_shape( + self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size + ) + + def __init__( + self, + layer_idx: int, + hidden_size: int, + quant_config: QuantizationConfig | None = None, + cache_config: CacheConfig | None = None, + model_config: ModelConfig | None = None, + rms_norm_eps: float = 1e-5, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = hidden_size + self.model_config = model_config + self.cache_config = cache_config + if model_config is None: + raise ValueError("model_config must be provided") + kda_config = model_config.linear_attn_config + self.head_dim = kda_config["head_dim"] + self.num_heads = kda_config["num_heads"] + self.layer_idx = layer_idx + self.prefix = prefix + assert self.num_heads % self.tp_size == 0 + self.local_num_heads = divide(self.num_heads, self.tp_size) + + projection_size = self.head_dim * self.num_heads + self.conv_size = kda_config["short_conv_kernel_size"] + + self.q_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.k_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, 
+ bias=False, + quant_config=quant_config, + prefix=f"{prefix}.k_proj", + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.v_proj", + ) + + self.f_a_proj = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.f_a_proj", + ) + + self.f_b_proj = ColumnParallelLinear( + self.head_dim, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.f_b_proj", + ) + self.dt_bias = nn.Parameter( + torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32) + ) + + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.b_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.b_proj", + ) + + self.q_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + params_dtype=torch.float32, + prefix=f"{prefix}.q_conv1d", + ) + self.k_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + params_dtype=torch.float32, + prefix=f"{prefix}.k_conv1d", + ) + self.v_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + params_dtype=torch.float32, + prefix=f"{prefix}.v_conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1) + self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1) + self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1) + + self.A_log = nn.Parameter( + torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32) + ) + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)}) + + self.g_a_proj = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_a_proj", + ) + self.g_b_proj = ColumnParallelLinear( + self.head_dim, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_b_proj", + ) + self.o_norm = FusedRMSNormGated( + self.head_dim, eps=rms_norm_eps, activation="sigmoid" + ) + self.o_proj = RowParallelLinear( + projection_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + return torch.ops.vllm.kda_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + # Mimic the memory allocation in the real run + q = torch.empty_like(hidden_states) + k = torch.empty_like(hidden_states) + v = torch.empty_like(hidden_states) + g = hidden_states.new_empty( + hidden_states.size(0), + self.local_num_heads, + self.head_dim, + 
dtype=torch.float32, + ) + beta = torch.empty( + hidden_states.size(0), self.local_num_heads, dtype=torch.float32 + ) + core_attn_out = torch.empty_like(hidden_states) + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + constant_caches = self.kv_cache[forward_context.virtual_engine] + + (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + # deal with strides + conv_state_q = conv_state_q.transpose(-1, -2) + conv_state_k = conv_state_k.transpose(-1, -2) + conv_state_v = conv_state_v.transpose(-1, -2) + + q_proj_states = self.q_proj(hidden_states)[0] + k_proj_states = self.k_proj(hidden_states)[0] + v_proj_states = self.v_proj(hidden_states)[0] + + q_conv_weights = self.q_conv1d.weight.view( + self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) + ) + k_conv_weights = self.k_conv1d.weight.view( + self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2) + ) + v_conv_weights = self.v_conv1d.weight.view( + self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2) + ) + if attn_metadata.num_prefills > 0: + q_proj_states = q_proj_states.transpose(0, 1) + k_proj_states = k_proj_states.transpose(0, 1) + v_proj_states = v_proj_states.transpose(0, 1) + q = causal_conv1d_fn( + q_proj_states, + q_conv_weights, + self.q_conv1d.bias, + activation="silu", + conv_states=conv_state_q, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + k = causal_conv1d_fn( + k_proj_states, + k_conv_weights, + self.k_conv1d.bias, + activation="silu", + conv_states=conv_state_k, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + v = causal_conv1d_fn( + v_proj_states, + v_conv_weights, + self.v_conv1d.bias, + activation="silu", + conv_states=conv_state_v, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + else: + decode_conv_indices = non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ] + q = causal_conv1d_update( + q_proj_states, + conv_state_q, + q_conv_weights, + self.q_conv1d.bias, + activation="silu", + conv_state_indices=decode_conv_indices, + validate_data=True, + ) + k = causal_conv1d_update( + k_proj_states, + conv_state_k, + k_conv_weights, + self.k_conv1d.bias, + activation="silu", + conv_state_indices=decode_conv_indices, + validate_data=True, + ) + v = causal_conv1d_update( + v_proj_states, + conv_state_v, + v_conv_weights, + self.v_conv1d.bias, + activation="silu", + conv_state_indices=decode_conv_indices, + validate_data=True, + ) + + q, k, v = map( + lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v) + ) + + beta = self.b_proj(hidden_states)[0].float().sigmoid() + + g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0] + g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias) + + beta = beta.unsqueeze(0) + g = g.unsqueeze(0) + + if attn_metadata.num_prefills > 0: + zero_idx = non_spec_state_indices_tensor[~has_initial_state] + 
recurrent_state[zero_idx] = 0 + initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous() + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=non_spec_query_start_loc, + ) + # Init cache + recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state + else: + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + use_qk_l2norm_in_kernel=True, + cu_seqlens=non_spec_query_start_loc, + ssm_state_indices=non_spec_state_indices_tensor, + ) + + g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0] + g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim) + core_attn_out = self.o_norm(core_attn_out_non_spec, g) + core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)") + + output[:] = self.o_proj(core_attn_out)[0] diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 91a45623582d..831dab2fbb01 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -80,6 +80,15 @@ def gated_delta_net_state_dtype( state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) return (state_dtype, state_dtype) + @classmethod + def kda_state_dtype( + cls, + model_dtype: ModelDType | torch.dtype, + mamba_cache_dtype: MambaDType, + ): + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, state_dtype, state_dtype, torch.float32) + class MambaStateShapeCalculator: @classmethod @@ -182,3 +191,35 @@ def gated_delta_net_state_shape( head_v_dim, ) return conv_state_shape, temporal_state_shape + + @classmethod + def kda_state_shape( + cls, + tp_world_size: int, + num_heads: int, + head_dim: int, + num_k_heads: int | None = None, + head_k_dim: int | None = None, + conv_kernel_size: int = 4, + num_spec: int = 0, + ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]: + if num_k_heads is None: + num_k_heads = num_heads + if head_k_dim is None: + head_k_dim = head_dim + + proj_size = num_heads * head_dim + proj_k_size = num_k_heads * head_k_dim + + conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1) + conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1) + recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim) + + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0] + return ( + conv_state_shape, + conv_state_k_shape, + conv_state_k_shape, + recurrent_state_shape, + ) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 34f05f2ee962..c4c44b83ae6b 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -147,9 +147,10 @@ def forward_native( # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim :], k_pe - ) + if self.rotary_emb is not None: + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim :], k_pe + ) if self.indexer and self.is_sparse: _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) diff --git 
a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ac5949cda9de..48ab78a60529 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy +from math import lcm from typing import TYPE_CHECKING import vllm.envs as envs @@ -8,7 +9,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec if TYPE_CHECKING: from vllm.config import VllmConfig @@ -347,12 +348,28 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # get attention page size (for 1 token) - attn_page_size_1_token = FullAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: kernel_block_size 128 alignment + # * Other MLA backends: kernel_block_size 64 alignment + if model_config.use_mla: + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 + attn_page_size_1_token = MLAAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + else: + kernel_block_alignment_size = 16 + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, @@ -372,17 +389,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if mamba_page_size == 0: return - # Attention backend constraints: - # - FlashAttention (FA) requires block size to be multiple of 16 - # - MLA (Multi-head Latent Attention) requires larger alignment: - # * CUTLASS_MLA backend: 128-byte alignment - # * Other MLA backends: 64-byte alignment - if model_config.use_mla: - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" - kernel_block_alignment_size = 128 if use_cutlass_mla else 64 - else: - kernel_block_alignment_size = 16 - if cache_config.enable_prefix_caching: # With prefix caching, select attention block size to # optimize for mamba kernel performance @@ -400,15 +406,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # easily by changing the way we layout chunks in the # mamba2 kernels. 
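+            # Illustrative example (hypothetical sizes, not from a real model):
+            # with base_chunk_size = 256, kernel_block_alignment_size = 64 and a
+            # mamba page equivalent to 576 attention tokens, chunk_size =
+            # lcm(256, 64) = 256 and attn_block_size = 256 * cdiv(576, 256) = 768.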
- from math import gcd - - def lcm(a, b): - return a * b // gcd(a, b) - - base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() - + base_chunk_size = model_config.get_mamba_chunk_size() attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) - chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) cache_config.mamba_block_size = attn_block_size diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py new file mode 100644 index 000000000000..a60a8d764d9d --- /dev/null +++ b/vllm/model_executor/models/kimi_linear.py @@ -0,0 +1,663 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.kda import KimiDeltaAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig + +from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class KimiMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QKVParallelLinear | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class KimiMoE(nn.Module): + def __init__( + self, + config: KimiLinearConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + layer_idx: int = 0, + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + moe_intermediate_size = config.moe_intermediate_size + num_experts = config.num_experts + moe_renormalize = config.moe_renormalize + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.num_shared_experts = config.num_shared_experts + self.layer_idx = layer_idx + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts)) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=config.num_experts_per_token, + hidden_size=hidden_size, + intermediate_size=moe_intermediate_size, + reduce_results=False, + renormalize=moe_renormalize, + quant_config=quant_config, + use_grouped_topk=config.use_grouped_topk, + num_expert_group=config.num_expert_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.moe_router_activation_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + + if self.num_shared_experts is not None: + intermediate_size = moe_intermediate_size * self.num_shared_experts + self.shared_experts = KimiMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class KimiMLAAttention(nn.Module): + """ + Main reference: DeepseekV2 vllm Implementation + """ + + def __init__( + self, + config: KimiLinearConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + rope_theta: float = 10000, + use_nope: bool = False, + rope_scaling: dict[str, Any] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + 
self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.use_nope = use_nope + assert self.use_nope is True + assert self.q_lora_rank is None + assert rope_scaling is None + assert num_heads % tp_size == 0 + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm( + self.kv_lora_rank, + eps=config.rms_norm_eps, + ) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=None, + o_proj=self.o_proj, + fused_qkv_a_proj=None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + q_a_layernorm=None, + q_b_proj=None, + q_proj=self.q_proj, + indexer=None, + is_sparse=False, + topk_indices_buffer=None, + ) + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + output: torch.Tensor, + ) -> None: + output[:] = self.mla_attn(positions, hidden_states) + + +class KimiDecoderLayer(nn.Module): + def __init__( + self, + config: KimiLinearConfig, + layer_idx: int, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + model_config: ModelConfig | None = None, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + self.is_moe = config.is_moe + + if config.is_kda_layer(layer_idx): + self.self_attn = KimiDeltaAttention( + layer_idx=layer_idx, + hidden_size=config.hidden_size, + quant_config=quant_config, + cache_config=cache_config, + model_config=config, + prefix=f"{prefix}.self_attn", + ) + else: + self.self_attn = KimiMLAAttention( + layer_idx=layer_idx, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + quant_config=quant_config, + cache_config=cache_config, + model_config=model_config, + prefix=f"{prefix}.self_attn", + config=config, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + use_nope=config.mla_use_nope, + ) + + if ( + self.is_moe + and config.num_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.block_sparse_moe = KimiMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.mlp = self.block_sparse_moe + else: + self.mlp = KimiMLP( + hidden_size=self.hidden_size, + 
intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_output = torch.empty_like(hidden_states) + self.self_attn( + hidden_states=hidden_states, + positions=positions, + output=attn_output, + ) + hidden_states = attn_output + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class KimiLinearModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_text_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.config = config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + extra_kwargs = {} + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + return KimiDecoderLayer( + config, + layer_idx, + cache_config, + quant_config, + parallel_config, + model_config, + prefix, + **extra_kwargs, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + get_layer, + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + world_size = get_tensor_model_parallel_world_size() + assert config.num_attention_heads % world_size == 0, ( + "num_attention_heads must be divisible by world_size" + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for _, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class KimiLinearForCausalLM( + nn.Module, 
HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.model_config = vllm_config.model_config + self.vllm_config = vllm_config + self.config = self.model_config.hf_config + quant_config = vllm_config.quant_config + self.quant_config = quant_config + self.model = KimiLinearModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.kda_state_dtype( + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.kda_state_shape( + tp_size, + hf_config.linear_attn_config["num_heads"], + hf_config.linear_attn_config["head_dim"], + conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"], + num_spec=num_spec, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + if self.config.is_moe: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_experts, + ) + else: + expert_params_mapping = [] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for args in weights: + name, loaded_weight = args[:2] + kwargs = args[2] if len(args) > 2 else {} + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these 
tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for idx, (param_name, weight_name, expert_id, shard_id) in enumerate( + expert_params_mapping + ): + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") + and name not in params_dict + and not self.config.is_linear_attn + ): # noqa: E501 + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight, **kwargs) + loaded_params.add(name) + + +def get_spec_layer_idx_from_weight_name( + config: KimiLinearConfig, weight_name: str +) -> int | None: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 0027954ac277..8e4413c90cf6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -118,6 +118,7 @@ "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501 "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 34c0429a8067..b1f4e3e2a983 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -79,6 +79,7 @@ def __getitem__(self, key): deepseek_v3="DeepseekV3Config", deepseek_v32="DeepseekV3Config", flex_olmo="FlexOlmoConfig", + kimi_linear="KimiLinearConfig", kimi_vl="KimiVLConfig", Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 
befe9cdae76a..663a8e44d71d 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -19,6 +19,7 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -54,6 +55,7 @@ "MiDashengLMConfig", "MLPSpeculatorConfig", "MoonViTConfig", + "KimiLinearConfig", "KimiVLConfig", "NemotronConfig", "NemotronHConfig", diff --git a/vllm/transformers_utils/configs/kimi_linear.py b/vllm/transformers_utils/configs/kimi_linear.py new file mode 100644 index 000000000000..65ddf48c5249 --- /dev/null +++ b/vllm/transformers_utils/configs/kimi_linear.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class KimiLinearConfig(PretrainedConfig): + model_type = "kimi_linear" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + model_type="kimi_linear", + vocab_size=163840, + hidden_size=4096, + head_dim=None, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + rope_theta=10000.0, + rope_scaling=None, + tie_word_embeddings=False, + moe_intermediate_size: int | None = None, + moe_renormalize: bool = True, + moe_router_activation_func: str = "sigmoid", + num_experts: int | None = None, + num_experts_per_token: int | None = None, + num_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + first_k_dense_replace: int = 0, + moe_layer_freq: int = 1, + use_grouped_topk: bool = True, + num_expert_group: int = 1, + topk_group: int = 1, + q_lora_rank: int | None = None, + kv_lora_rank: int | None = None, + qk_nope_head_dim: int | None = None, + qk_rope_head_dim: int | None = None, + v_head_dim: int | None = None, + mla_use_nope: bool | None = False, + num_nextn_predict_layers: int = 0, + linear_attn_config: dict | None = None, + **kwargs, + ): + self.model_type = model_type + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.head_dim = ( + head_dim if head_dim is not None else hidden_size // num_attention_heads + ) + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.mla_use_nope = mla_use_nope + # moe config + self.num_experts = num_experts + self.num_experts_per_token = 
num_experts_per_token
+        self.moe_renormalize = moe_renormalize
+        self.num_shared_experts = num_shared_experts
+        self.routed_scaling_factor = routed_scaling_factor
+        self.moe_router_activation_func = moe_router_activation_func
+        assert self.moe_router_activation_func in ("softmax", "sigmoid")
+        self.moe_intermediate_size = moe_intermediate_size
+        self.first_k_dense_replace = first_k_dense_replace
+        self.moe_layer_freq = moe_layer_freq
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.num_nextn_predict_layers = num_nextn_predict_layers
+
+        if linear_attn_config is not None:
+            assert linear_attn_config["kda_layers"] is not None
+            assert linear_attn_config["full_attn_layers"] is not None
+        self.linear_attn_config = linear_attn_config
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    @property
+    def is_mla(self):
+        return (
+            self.q_lora_rank is not None
+            or self.kv_lora_rank is not None
+            or self.qk_nope_head_dim is not None
+            or self.qk_rope_head_dim is not None
+            or self.v_head_dim is not None
+            or self.mla_use_nope is True
+        )
+
+    @property
+    def is_moe(self):
+        return self.num_experts is not None
+
+    @property
+    def is_linear_attn(self) -> bool:
+        return not (
+            self.linear_attn_config is None
+            or (
+                isinstance(self.linear_attn_config, dict)
+                and self.linear_attn_config["kda_layers"] is not None
+                and len(self.linear_attn_config["kda_layers"]) == 0
+            )
+        )
+
+    def is_kda_layer(self, layer_idx: int):
+        return (
+            self.linear_attn_config is not None
+            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
+        )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 129d7e54466a..d4ed8281841e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -8,6 +8,7 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from functools import reduce
 from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -4122,26 +4123,18 @@ def _check_and_update_cudagraph_mode(
     def calculate_reorder_batch_threshold(self) -> None:
         """
-        Check that if any backends reorder batches; that the reordering
-        is compatible (e.g., decode threshold is the same)
+        Choose the minimum reorder batch threshold from all attention groups.
+        Backends should be able to support a lower threshold than the one they
+        request, but may pay a performance penalty because that backend then
+        treats decodes as prefills.
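+        For example, if the attention groups report thresholds of [None, 4, 8]
+        (illustrative values), the merged threshold is 4; if every group reports
+        None, it remains None.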
""" - for group in self._attn_group_iterator(): - attn_metadata_builder_i = group.get_metadata_builder() - - # check that if any backends reorder batches; that the reordering - # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold - if reorder_batch_threshold_i is not None: - if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != self.reorder_batch_threshold: - raise ValueError( - f"Attention backend reorders decodes with " - f"threshold {reorder_batch_threshold_i} but other " - f"backend uses threshold " - f"{self.reorder_batch_threshold}" - ) - else: - self.reorder_batch_threshold = reorder_batch_threshold_i + min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) + + reorder_batch_thresholds = [ + group.get_metadata_builder().reorder_batch_threshold + for group in self._attn_group_iterator() + ] + self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) def _find_compatible_block_sizes( self,