From 497f22bfad3d9a133984561c6e38e09c150a69ad Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 4 Aug 2025 17:29:12 -0400 Subject: [PATCH 1/5] Initial enablement Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/lightning_attn.py | 2 +- vllm/model_executor/models/minimax_text_01.py | 191 ++++++++++++++---- vllm/v1/attention/backends/linear_attn.py | 75 +++++++ vllm/v1/attention/backends/mamba_selectors.py | 3 + 4 files changed, 228 insertions(+), 43 deletions(-) create mode 100644 vllm/v1/attention/backends/linear_attn.py diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 978086d1909d..8ffc700ca5cd 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -532,7 +532,7 @@ def _linear_attn_decode_kernel( pid_d = tl.program_id(2) # dimension block index # Load slot index for the current batch - slot_id = tl.load(slot_idx + pid_b) + slot_id = tl.load(slot_idx + pid_b).to(tl.int64) # Skip if slot_id is -1 (padding) if slot_id == -1: diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index f2773af490c5..18e1b0b94559 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -14,8 +14,9 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig +from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, @@ -33,6 +34,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -41,8 +43,9 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata -from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .interfaces import HasInnerState, IsHybrid from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -327,7 +330,16 @@ def jit_linear_forward_prefix(q: torch.Tensor, return rearrange(output.squeeze(0), "h n d -> n (h d)") -class MiniMaxText01LinearAttention(nn.Module): +class MiniMaxText01LinearAttention(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "MiniMaxText01LinearAttention" + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + state_shape = (self.num_heads // self.tp_size, self.head_dim, + self.head_dim) + return (state_shape, ) def __init__( self, @@ -359,6 +371,7 @@ def __init__( self.tp_heads = self.total_num_heads // self.tp_size self.qkv_size = self.num_heads * self.head_dim self.tp_hidden = self.head_dim * self.tp_heads + self.prefix = prefix self.qkv_proj = ColumnParallelLinear( hidden_size, @@ -397,6 +410,12 @@ def __init__( self.tp_heads:(self.tp_rank + 1) * self.tp_heads].contiguous() + if envs.VLLM_USE_V1: 
+ compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + @staticmethod def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: @@ -429,18 +448,28 @@ def get_slopes_power_of_2(n): def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): hidden = [] + if attn_metadata.num_decode_tokens > 0: + hidden.append( + self._decode_infer(q, k, v, kv_cache, state_indices_tensor, + attn_metadata)) for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): if _prefill_idx >= len(attn_metadata.query_start_loc): break if _prefill_idx >= len(state_indices_tensor): break - _start = attn_metadata.query_start_loc[_prefill_idx] - _end = attn_metadata.query_start_loc[_prefill_idx + 1] - slot_id = state_indices_tensor[_prefill_idx] + + if envs.VLLM_USE_V1: + # prefills are packed at end of batch + offset = getattr(attn_metadata, "num_decodes", 0) + else: + offset = 0 + + _start = attn_metadata.query_start_loc[offset + _prefill_idx] + _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] + slot_id = state_indices_tensor[offset + _prefill_idx] qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() - slot_id = state_indices_tensor[_prefill_idx] slice_layer_cache = kv_cache[slot_id, ...] out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( @@ -452,10 +481,6 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, self.BLOCK, layer_idx=self.layer_idx) hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden.append( - self._decode_infer(q, k, v, kv_cache, state_indices_tensor, - attn_metadata)) if not hidden: return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) @@ -465,11 +490,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 - ):] + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decode_tokens] hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, slot_id, 32) return hidden @@ -483,17 +507,52 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor + + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor + + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if 
num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, + "num_decode_tokens", 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx + + 1] + query_len = q_end - q_start + context_len = attn_metadata.seq_lens[ + num_decode_tokens + prefill_idx] - query_len + if context_len == 0: + block_to_clear = state_indices_tensor[ + num_decode_tokens + prefill_idx] + kv_cache[block_to_clear, ...] = 0 + else: + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) + + if attn_metadata is None: + hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), + device=q.device, + dtype=q.dtype) else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, attn_metadata) + + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) hidden = self.norm._forward(hidden) gate, _ = self.output_gate(hidden_states) @@ -541,6 +600,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.sliding_window = sliding_window + self.prefix = prefix self.qkv_proj = QKVParallelLinear( hidden_size, @@ -575,7 +635,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = attn_metadata.rotary_emb(positions, q, k) + if envs.VLLM_USE_V1: + if attn_metadata is not None: + q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb( + positions, q, k) + else: + q, k = attn_metadata.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -595,6 +660,8 @@ def __init__( ) -> None: self._ilayer = layer_id self._irank = get_tensor_model_parallel_rank() + self.prefix = prefix + super().__init__() self.hidden_size = config.hidden_size @@ -876,8 +943,9 @@ def layer_fn(prefix): self._dtype = _dummy.dtype del _dummy - self.minimax_cache = MinimaxCacheManager(dtype=torch.float32, - cache_shape=self.cache_shape) + if not envs.VLLM_USE_V1: + self.minimax_cache = MinimaxCacheManager( + dtype=torch.float32, cache_shape=self.cache_shape) rope_theta = getattr(config, "rope_theta", 10000) head_dim = getattr(config, "head_dim", None) @@ -944,23 +1012,26 @@ def forward(self, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if attn_metadata is None: - return None + if "request_ids_to_seq_ids" not in kwargs: kwargs["request_ids_to_seq_ids"] = {} if "finished_requests_ids" not in kwargs: kwargs["finished_requests_ids"] = [] - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) - if getattr(attn_metadata, "num_prefills", 0) > 0: - self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, - **kwargs) + if not envs.VLLM_USE_V1: + ( + minimax_cache_tensors, + state_indices_tensor, + ) = self.minimax_cache.current_run_tensors(**kwargs) + if getattr(attn_metadata, "num_prefills", 0) > 0: + 
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, + **kwargs) + + minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, + state_indices_tensor) + else: + minimax_cache_params = None - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) if get_pp_group().is_first_rank: if inputs_embeds is None: hidden_states = self.embed_scale * self.embed_tokens(input_ids) @@ -973,11 +1044,22 @@ def forward(self, residual = intermediate_tensors["residual"] minimax_cache_index = 0 - attn_metadata.rotary_emb = self.rotary_emb + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + if attn_metadata is not None: + # TODO (tdoublep): this whole thing with the rotary_emb is + # weird. we shouldn't be passing it via attn_metadata imo. + if envs.VLLM_USE_V1: + if isinstance(layer.self_attn, MiniMaxText01Attention): + attn_metadata[layer.prefix + + ".attn"].rotary_emb = self.rotary_emb + else: + attn_metadata.rotary_emb = self.rotary_emb + _caches = None - if isinstance(layer.self_attn, MiniMaxText01LinearAttention): + if not envs.VLLM_USE_V1 and isinstance( + layer.self_attn, MiniMaxText01LinearAttention): current_state_layer = minimax_cache_index _caches = minimax_cache_params.at_layer_idx( current_state_layer) @@ -1002,8 +1084,7 @@ def forward(self, return hidden_states -class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, - SupportsV0Only): +class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -1321,3 +1402,29 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor, load_basic_weight(name, loaded_weight, self) return loaded_params + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + state_shape = (hf_config.num_attention_heads // + parallel_config.tensor_parallel_size, + hf_config.head_dim, hf_config.head_dim) + + return [state_shape] diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py new file mode 100644 index 000000000000..22ed85ae46cf --- /dev/null +++ b/vllm/v1/attention/backends/linear_attn.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class LinearAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: + return LinearAttentionMetadataBuilder + + +@dataclass +class LinearAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + + state_indices_tensor: torch.Tensor # shape: [batch,] + + +class LinearAttentionMetadataBuilder( + AttentionMetadataBuilder[LinearAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> LinearAttentionMetadata: + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + print("query_start_loc: ", query_start_loc) + print("seq_lens: ", seq_lens) + print("state_indices_tensor: ", state_indices_tensor) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + + print("num_prefills: ", num_prefills) + print("num_prefill_tokens: ", num_prefill_tokens) + print("num_decodes: ", num_decodes) + + attn_metadata = LinearAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py index 80021a216556..abe08a3f7885 100644 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ b/vllm/v1/attention/backends/mamba_selectors.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend from 
vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: if mamba_type == "mamba2": return Mamba2AttentionBackend + elif mamba_type == "MiniMaxText01LinearAttention": + return LinearAttentionBackend raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " "supported yet.") From d72e474cda873f77a55e151b621c44d42c0099e0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 4 Aug 2025 17:56:25 -0400 Subject: [PATCH 2/5] Remove debug print Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/linear_attn.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 22ed85ae46cf..f08b6d7f177c 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -51,18 +51,10 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - print("query_start_loc: ", query_start_loc) - print("seq_lens: ", seq_lens) - print("state_indices_tensor: ", state_indices_tensor) - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills(common_attn_metadata, decode_threshold=1)) - print("num_prefills: ", num_prefills) - print("num_prefill_tokens: ", num_prefill_tokens) - print("num_decodes: ", num_decodes) - attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, From 9a6d1fe2838f8850dcfb92e5326f5921d7c2afcd Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 4 Aug 2025 18:45:52 -0400 Subject: [PATCH 3/5] Cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 18e1b0b94559..0b4d5bfc85b0 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -448,22 +448,13 @@ def get_slopes_power_of_2(n): def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): hidden = [] - if attn_metadata.num_decode_tokens > 0: - hidden.append( - self._decode_infer(q, k, v, kv_cache, state_indices_tensor, - attn_metadata)) for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): if _prefill_idx >= len(attn_metadata.query_start_loc): break if _prefill_idx >= len(state_indices_tensor): break - - if envs.VLLM_USE_V1: - # prefills are packed at end of batch - offset = getattr(attn_metadata, "num_decodes", 0) - else: - offset = 0 - + # prefills are packed at end of batch in V1 + offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 _start = attn_metadata.query_start_loc[offset + _prefill_idx] _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] slot_id = state_indices_tensor[offset + _prefill_idx] @@ -482,6 +473,15 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, layer_idx=self.layer_idx) hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden_decode = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + if envs.VLLM_USE_V1: + hidden.insert(0, hidden_decode) + else: + hidden.append(hidden_decode) + if not hidden: return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) @@ -490,10 +490,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, 
state_indices_tensor, def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decode_tokens] + if not envs.VLLM_USE_V1: + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + num_prefills = getattr(attn_metadata, "num_prefills", 0) + slot_id = state_indices_tensor[num_prefills:] + else: + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decodes] hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, slot_id, 32) return hidden @@ -1012,7 +1019,8 @@ def forward(self, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - + if not envs.VLLM_USE_V1 and attn_metadata is None: + return None if "request_ids_to_seq_ids" not in kwargs: kwargs["request_ids_to_seq_ids"] = {} if "finished_requests_ids" not in kwargs: From be5a318feacacf7b051af515fb420d06ce4d72bc Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 4 Aug 2025 18:49:09 -0400 Subject: [PATCH 4/5] minor whitespace Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 0b4d5bfc85b0..f2f62b538119 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -472,7 +472,6 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, self.BLOCK, layer_idx=self.layer_idx) hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: hidden_decode = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, @@ -514,7 +513,6 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ -545,13 +543,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, state_indices_tensor = kv_caches.state_indices_tensor decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if attn_metadata is None: hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype) else: - if not decode_only: hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, state_indices_tensor, @@ -668,7 +664,6 @@ def __init__( self._ilayer = layer_id self._irank = get_tensor_model_parallel_rank() self.prefix = prefix - super().__init__() self.hidden_size = config.hidden_size From 94b5282aa34fc329242f590808842bafe8b60f46 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 6 Aug 2025 08:09:24 -0400 Subject: [PATCH 5/5] Fix some stuff with get_mamba_state_shape_from_config Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 9 ++++----- 1 file changed, 4 
insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index f2f62b538119..be544065dc2f 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1411,8 +1411,8 @@ def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, - ) -> tuple[tuple[int, int], tuple[int, int, int]]: - """Calculate shapes for Mamba's convolutional and state caches. + ) -> tuple[tuple[int, ...], ...]: + """Calculate shape for MiniMaxText01LinearAttention cache. Args: vllm_config: vLLM config @@ -1420,8 +1420,7 @@ def get_mamba_state_shape_from_config( Returns: Tuple containing: - - conv_state_shape: Shape for convolutional state cache - - temporal_state_shape: Shape for state space model cache + - state_shape: Shape of the cache """ parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config @@ -1430,4 +1429,4 @@ def get_mamba_state_shape_from_config( parallel_config.tensor_parallel_size, hf_config.head_dim, hf_config.head_dim) - return [state_shape] + return (state_shape, )
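
---

A minimal standalone sketch (not part of the patches above) of the V1 batch layout that the new `_prefill_and_mix_infer` / `_decode_infer` indexing relies on: decode requests are packed at the front of the batch and prefills at the back, and with `decode_threshold=1` each decode contributes exactly one token, so `num_decodes == num_decode_tokens`. The tensors below are hand-made toy values standing in for a real `LinearAttentionMetadata`; only the indexing logic mirrors the patch.

```python
# Toy reproduction of the V1 decode-first / prefill-last layout assumed above.
import torch

# Mixed batch: 2 decode requests (1 new token each) followed by
# 2 prefill requests (4 and 3 new tokens).
query_start_loc = torch.tensor([0, 1, 2, 6, 9])  # cumulative new-token offsets
seq_lens = torch.tensor([5, 7, 4, 8])            # total tokens per request
num_decodes, num_prefills = 2, 2
num_decode_tokens = 2                            # == num_decodes here

# Decode path: the first num_decode_tokens rows of q/k/v belong to decodes,
# matching q[:attn_metadata.num_decode_tokens] in _decode_infer.
decode_token_slice = slice(0, num_decode_tokens)

# Prefill path: offset query_start_loc by num_decode_tokens, matching the
# V1 branch in _prefill_and_mix_infer and the cache-clearing loop in forward().
for prefill_idx in range(num_prefills):
    start = int(query_start_loc[num_decode_tokens + prefill_idx])
    end = int(query_start_loc[num_decode_tokens + prefill_idx + 1])
    query_len = end - start
    context_len = int(seq_lens[num_decode_tokens + prefill_idx]) - query_len
    # context_len == 0 means this prefill has no prior state, which is the
    # case where forward() zeroes the request's slot in the linear-attn cache.
    print(f"prefill {prefill_idx}: tokens [{start}:{end}), "
          f"query_len={query_len}, context_len={context_len}")
```

Running this prints `context_len=0` for the first prefill (fresh request, state block cleared) and `context_len=5` for the second (a chunked-prefill continuation, state kept), which is the distinction the cache-clearing loop in `MiniMaxText01LinearAttention.forward` makes.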