diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py
index 978086d1909d..8ffc700ca5cd 100644
--- a/vllm/model_executor/layers/lightning_attn.py
+++ b/vllm/model_executor/layers/lightning_attn.py
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
     pid_d = tl.program_id(2)  # dimension block index

     # Load slot index for the current batch
-    slot_id = tl.load(slot_idx + pid_b)
+    slot_id = tl.load(slot_idx + pid_b).to(tl.int64)

     # Skip if slot_id is -1 (padding)
     if slot_id == -1:
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index 42c815b08f04..ad1401791238 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -5,6 +5,17 @@

 class MambaStateShapeCalculator:

+    @classmethod
+    def linear_attention_state_shape(
+        cls,
+        num_heads: int,
+        tp_size: int,
+        head_dim: int,
+    ) -> tuple[tuple[int, int, int], ...]:
+
+        state_shape = (num_heads // tp_size, head_dim, head_dim)
+        return (state_shape, )
+
     @classmethod
     def mamba1_state_shape(
         cls,
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index f2773af490c5..1f9f7f60cabf 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -14,8 +14,9 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
@@ -33,6 +34,9 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,8 +45,9 @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

-from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+from .interfaces import HasInnerState, IsHybrid
 from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

@@ -327,7 +332,17 @@ def jit_linear_forward_prefix(q: torch.Tensor,
     return rearrange(output.squeeze(0), "h n d -> n (h d)")


-class MiniMaxText01LinearAttention(nn.Module):
+class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+    @property
+    def mamba_type(self) -> str:
+        return "linear_attention"
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=self.num_heads,
+            tp_size=self.tp_size,
+            head_dim=self.head_dim)

     def __init__(
         self,
@@ -359,6 +374,7 @@ def __init__(
         self.tp_heads = self.total_num_heads // self.tp_size
         self.qkv_size = self.num_heads * self.head_dim
         self.tp_hidden = self.head_dim * self.tp_heads
+        self.prefix = prefix

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -397,6 +413,12 @@ def __init__(
                                         self.tp_heads:(self.tp_rank + 1) *
                                         self.tp_heads].contiguous()

+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
     @staticmethod
     def weight_direct_load(param: torch.Tensor,
                            loaded_weight: torch.Tensor) -> None:
@@ -434,13 +456,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            _start = attn_metadata.query_start_loc[_prefill_idx]
-            _end = attn_metadata.query_start_loc[_prefill_idx + 1]
-            slot_id = state_indices_tensor[_prefill_idx]
+            # prefills are packed at end of batch in V1
+            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+            slot_id = state_indices_tensor[offset + _prefill_idx]
             qs = q[_start:_end].transpose(0, 1).contiguous()
             ks = k[_start:_end].transpose(0, 1).contiguous()
             vs = v[_start:_end].transpose(0, 1).contiguous()
-            slot_id = state_indices_tensor[_prefill_idx]
             slice_layer_cache = kv_cache[slot_id, ...]

             out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 layer_idx=self.layer_idx)
             hidden.append(out_slice.contiguous())
         if attn_metadata.num_decode_tokens > 0:
-            hidden.append(
-                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
-                                   attn_metadata))
+            hidden_decode = self._decode_infer(q, k, v, kv_cache,
+                                               state_indices_tensor,
+                                               attn_metadata)
+            if envs.VLLM_USE_V1:
+                hidden.insert(0, hidden_decode)
+            else:
+                hidden.append(hidden_decode)

         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -465,11 +492,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,

     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
-                                               ):]
+        if not envs.VLLM_USE_V1:
+            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            slot_id = state_indices_tensor[num_prefills:]
+        else:
+            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
         hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                               slot_id, 32)
         return hidden
@@ -483,17 +516,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        kv_cache = kv_caches.minimax_cache
-        state_indices_tensor = kv_caches.state_indices_tensor
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                assert isinstance(attn_metadata, LinearAttentionMetadata)
+                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+
+                num_prefills = getattr(attn_metadata, "num_prefills", 0)
+                if num_prefills > 0:
+                    num_decode_tokens = getattr(attn_metadata,
+                                                "num_decode_tokens", 0)
+                    for prefill_idx in range(num_prefills):
+                        q_start = attn_metadata.query_start_loc[
+                            num_decode_tokens + prefill_idx]
+                        q_end = attn_metadata.query_start_loc[num_decode_tokens
+                                                              + prefill_idx +
+                                                              1]
+                        query_len = q_end - q_start
+                        context_len = attn_metadata.seq_lens[
+                            num_decode_tokens + prefill_idx] - query_len
+                        if context_len == 0:
+                            block_to_clear = state_indices_tensor[
+                                num_decode_tokens + prefill_idx]
+                            kv_cache[block_to_clear, ...] = 0
+        else:
+            kv_cache = kv_caches.minimax_cache
+            state_indices_tensor = kv_caches.state_indices_tensor

         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        if not decode_only:
-            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                 state_indices_tensor,
-                                                 attn_metadata)
+        if attn_metadata is None:
+            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+                                 device=q.device,
+                                 dtype=q.dtype)
         else:
-            hidden = self._decode_infer(q, k, v, kv_cache,
-                                        state_indices_tensor, attn_metadata)
+            if not decode_only:
+                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                     state_indices_tensor,
+                                                     attn_metadata)
+            else:
+                hidden = self._decode_infer(q, k, v, kv_cache,
+                                            state_indices_tensor,
+                                            attn_metadata)

         hidden = self.norm._forward(hidden)
         gate, _ = self.output_gate(hidden_states)
@@ -541,6 +606,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
+        self.prefix = prefix

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -575,7 +641,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = attn_metadata.rotary_emb(positions, q, k)
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
+                    positions, q, k)
+        else:
+            q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -595,6 +666,7 @@ def __init__(
     ) -> None:
         self._ilayer = layer_id
         self._irank = get_tensor_model_parallel_rank()
+        self.prefix = prefix
         super().__init__()

         self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ def layer_fn(prefix):
         self._dtype = _dummy.dtype
         del _dummy

-        self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
-                                                 cache_shape=self.cache_shape)
+        if not envs.VLLM_USE_V1:
+            self.minimax_cache = MinimaxCacheManager(
+                dtype=torch.float32, cache_shape=self.cache_shape)

         rope_theta = getattr(config, "rope_theta", 10000)
         head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ def forward(self,
                 **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        if attn_metadata is None:
+        if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
         if "request_ids_to_seq_ids" not in kwargs:
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []

-        (
-            minimax_cache_tensors,
-            state_indices_tensor,
-        ) = self.minimax_cache.current_run_tensors(**kwargs)
-        if getattr(attn_metadata, "num_prefills", 0) > 0:
-            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
-                                      **kwargs)
+        if not envs.VLLM_USE_V1:
+            (
+                minimax_cache_tensors,
+                state_indices_tensor,
+            ) = self.minimax_cache.current_run_tensors(**kwargs)
+            if getattr(attn_metadata, "num_prefills", 0) > 0:
+                self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
+                                          **kwargs)
+
+            minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
+                                                      state_indices_tensor)
+        else:
+            minimax_cache_params = None

-        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
-                                                  state_indices_tensor)
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
                 hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ def forward(self,
             residual = intermediate_tensors["residual"]

         minimax_cache_index = 0
-        attn_metadata.rotary_emb = self.rotary_emb
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if attn_metadata is not None:
+                # TODO (tdoublep): this whole thing with the rotary_emb is
+                # weird. we shouldn't be passing it via attn_metadata imo.
+                if envs.VLLM_USE_V1:
+                    if isinstance(layer.self_attn, MiniMaxText01Attention):
+                        attn_metadata[layer.prefix +
+                                      ".attn"].rotary_emb = self.rotary_emb
+                else:
+                    attn_metadata.rotary_emb = self.rotary_emb
+
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
+            if not envs.VLLM_USE_V1 and isinstance(
+                    layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(
                     current_state_layer)
@@ -1002,8 +1090,7 @@ def forward(self,

         return hidden_states


-class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
-                               SupportsV0Only):
+class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -1321,3 +1408,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                 load_basic_weight(name, loaded_weight, self)

         return loaded_params
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, ...], ...]:
+        """Calculate shape for MiniMaxText01LinearAttention cache.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - state_shape: Shape of the cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=hf_config.num_attention_heads,
+            tp_size=parallel_config.tensor_parallel_size,
+            head_dim=hf_config.head_dim,
+        )
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
new file mode 100644
index 000000000000..f08b6d7f177c
--- /dev/null
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import ClassVar
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+
+
+class LinearAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
+        return LinearAttentionMetadataBuilder
+
+
+@dataclass
+class LinearAttentionMetadata:
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+
+    state_indices_tensor: torch.Tensor  # shape: [batch,]
+
+
+class LinearAttentionMetadataBuilder(
+        AttentionMetadataBuilder[LinearAttentionMetadata]):
+
+    reorder_batch_threshold: ClassVar[int] = 1
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        assert isinstance(kv_cache_spec, MambaSpec)
+        self.kv_cache_spec = kv_cache_spec
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> LinearAttentionMetadata:
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+
+        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
+        attn_metadata = LinearAttentionMetadata(
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            state_indices_tensor=state_indices_tensor,
+        )
+        return attn_metadata
diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py
index f56f2fb7bf69..852e0dfe1b31 100644
--- a/vllm/v1/attention/backends/mamba_selectors.py
+++ b/vllm/v1/attention/backends/mamba_selectors.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend

@@ -8,9 +9,10 @@
 def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
     if mamba_type == "mamba1":
         return Mamba1AttentionBackend
-
     if mamba_type == "mamba2":
         return Mamba2AttentionBackend
+    if mamba_type == "linear_attention":
+        return LinearAttentionBackend

     raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
                               "supported yet.")
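
A note on the batch layout the V1 code paths above assume: with `reorder_batch_threshold = 1`, decode requests (one token each) sit at the front of the token buffer and prefill requests follow, while `query_start_loc` holds cumulative per-request token offsets. The short standalone sketch below (plain PyTorch, not vLLM code; every value is made up for illustration) walks through the same index arithmetic used by `_prefill_and_mix_infer` and the state-clearing loop in `MiniMaxText01LinearAttention.forward`.

import torch

# Illustrative batch: two single-token decode requests followed by one prefill
# request contributing five new tokens. All values here are invented.
num_decode_tokens = 2   # equals the number of decode requests (1 token each)
num_prefills = 1
query_start_loc = torch.tensor([0, 1, 2, 7])   # cumulative token starts, len = batch + 1
seq_lens = torch.tensor([9, 4, 5])             # total sequence length per request

# Decode tokens sit at the front: the V1 branch of _decode_infer slices
# q[:num_decode_tokens] and state_indices_tensor[:num_decodes].

# Prefill i spans query_start_loc[offset + i] : query_start_loc[offset + i + 1],
# with offset = num_decode_tokens, because prefills are packed at the end.
offset = num_decode_tokens
for prefill_idx in range(num_prefills):
    start = int(query_start_loc[offset + prefill_idx])
    end = int(query_start_loc[offset + prefill_idx + 1])
    query_len = end - start
    context_len = int(seq_lens[offset + prefill_idx]) - query_len
    # context_len == 0 marks the first chunk of a request; that is the case in
    # which forward() zeroes the request's state block before the prefill runs.
    print(f"prefill {prefill_idx}: token rows [{start}:{end}), "
          f"query_len={query_len}, context_len={context_len}")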
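
For completeness, a tiny usage sketch of the two hooks this patch wires together, assuming the patch is applied and vLLM is importable; the head count, TP size, and head dim are arbitrary example numbers, not MiniMax-Text-01's actual configuration.

from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

# Per-TP-rank linear-attention state: (num_heads // tp_size, head_dim, head_dim).
(state_shape, ) = MambaStateShapeCalculator.linear_attention_state_shape(
    num_heads=64, tp_size=8, head_dim=128)
print(state_shape)  # (8, 128, 128)

# The new mamba_type string resolves to the V1 backend added in this PR.
print(get_mamba_attn_backend("linear_attention"))  # LinearAttentionBackend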