2 changes: 1 addition & 1 deletion vllm/model_executor/layers/lightning_attn.py
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
pid_d = tl.program_id(2) # dimension block index

# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)

# Skip if slot_id is -1 (padding)
if slot_id == -1:
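For context beyond the diff: the `.to(tl.int64)` cast protects the pointer arithmetic that follows the load. With a 32-bit slot index, `slot_id * stride` can wrap once the product passes 2**31 - 1 on a large cache, silently aliasing the wrong cache slot. A minimal Triton sketch of the pattern, assuming a hypothetical kernel and argument names:

```python
import triton
import triton.language as tl


@triton.jit
def _gather_slot_kernel(cache_ptr, out_ptr, slot_idx_ptr, slot_stride,
                        BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Promote the slot index to int64 before it enters address arithmetic,
    # so slot_id * slot_stride cannot overflow 32 bits on large caches.
    slot_id = tl.load(slot_idx_ptr + pid).to(tl.int64)
    offs = tl.arange(0, BLOCK)
    vals = tl.load(cache_ptr + slot_id * slot_stride + offs)
    tl.store(out_ptr + pid * BLOCK + offs, vals)
```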
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_utils.py
@@ -5,6 +5,17 @@

class MambaStateShapeCalculator:

@classmethod
def linear_attention_state_shape(
cls,
num_heads: int,
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:

state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape, )

@classmethod
def mamba1_state_shape(
cls,
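For context beyond the diff: the new helper reports a single state tensor per layer, sized per tensor-parallel rank, since each rank keeps `num_heads // tp_size` heads and each head carries a `head_dim x head_dim` recurrent state. A standalone sketch with illustrative numbers (not necessarily the real model config):

```python
def linear_attention_state_shape(
        num_heads: int, tp_size: int,
        head_dim: int) -> tuple[tuple[int, int, int], ...]:
    # Mirrors the classmethod above: each tensor-parallel rank keeps
    # num_heads // tp_size heads, and each head carries one
    # head_dim x head_dim recurrent state matrix.
    return ((num_heads // tp_size, head_dim, head_dim), )


# Illustrative numbers only: 64 heads, head_dim 128, tensor parallel size 8
# -> ((8, 128, 128),)
print(linear_attention_state_shape(num_heads=64, tp_size=8, head_dim=128))
```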
192 changes: 152 additions & 40 deletions vllm/model_executor/models/minimax_text_01.py
@@ -14,8 +14,9 @@
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
@@ -33,6 +34,9 @@
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,8 +45,9 @@
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

@@ -327,7 +332,17 @@ def jit_linear_forward_prefix(q: torch.Tensor,
return rearrange(output.squeeze(0), "h n d -> n (h d)")


class MiniMaxText01LinearAttention(nn.Module):
class MiniMaxText01LinearAttention(nn.Module, MambaBase):

@property
def mamba_type(self) -> str:
return "linear_attention"

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)

def __init__(
self,
@@ -359,6 +374,7 @@ def __init__(
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.prefix = prefix

self.qkv_proj = ColumnParallelLinear(
hidden_size,
@@ -397,6 +413,12 @@ def __init__(
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()

if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self

@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
@@ -434,13 +456,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slot_id = state_indices_tensor[_prefill_idx]
slice_layer_cache = kv_cache[slot_id, ...]

out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata))
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)

if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -465,11 +492,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,

def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
):]
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
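For context on the two branches above: V0 flattens the batch with prefill tokens first and decode tokens last, whereas V1 packs decode tokens first, which is also why `_prefill_and_mix_infer` offsets `query_start_loc` by `num_decode_tokens` and inserts the decode output at the front of `hidden`. A small sketch of the slicing under that assumption, with hypothetical token counts:

```python
import torch

# Hypothetical flattened batch: 2 decode requests (1 token each) plus one
# prefill request of 5 tokens, with a hidden width of 8 per token.
num_decode_tokens, num_prefill_tokens = 2, 5
q = torch.randn(num_decode_tokens + num_prefill_tokens, 8)

# V1 layout: decode tokens sit at the front of the flattened batch.
q_decode_v1, q_prefill_v1 = q[:num_decode_tokens], q[num_decode_tokens:]

# V0 layout: prefill tokens come first, decode tokens trail the batch.
q_prefill_v0, q_decode_v0 = q[:num_prefill_tokens], q[num_prefill_tokens:]

assert q_decode_v1.shape == q_decode_v0.shape
assert q_prefill_v1.shape == q_prefill_v0.shape
```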
@@ -483,17 +516,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor

num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor

decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, attn_metadata)
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)

hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
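A note on the V1 wiring above, as read from the diff: each linear-attention layer registers itself in the compilation config's `static_forward_context` under its `prefix` at construction time, and `forward` later indexes the per-layer attention-metadata dict with that same prefix. A minimal standalone sketch of the prefix-keyed pattern, with hypothetical names rather than vLLM's real classes:

```python
# Hypothetical registry mirroring the prefix-keyed wiring; not vLLM's API.
static_forward_context: dict[str, "TinyLinearAttn"] = {}


class TinyLinearAttn:

    def __init__(self, prefix: str) -> None:
        self.prefix = prefix
        if prefix in static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_forward_context[prefix] = self

    def forward(self, attn_metadata: dict[str, dict]) -> dict:
        # Each layer pulls only its own metadata entry, keyed by its prefix.
        return attn_metadata[self.prefix]


layer = TinyLinearAttn("model.layers.0.self_attn")
per_layer_metadata = {"model.layers.0.self_attn": {"num_decodes": 2}}
print(layer.forward(per_layer_metadata))  # {'num_decodes': 2}
```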
@@ -541,6 +606,7 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.prefix = prefix

self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -575,7 +641,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = attn_metadata.rotary_emb(positions, q, k)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
positions, q, k)
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -595,6 +666,7 @@ def __init__(
) -> None:
self._ilayer = layer_id
self._irank = get_tensor_model_parallel_rank()
self.prefix = prefix
super().__init__()

self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ def layer_fn(prefix):
self._dtype = _dummy.dtype
del _dummy

self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
cache_shape=self.cache_shape)
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)

rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ def forward(self,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
if not envs.VLLM_USE_V1:
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)

minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None

minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ def forward(self,
residual = intermediate_tensors["residual"]

minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if attn_metadata is not None:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if envs.VLLM_USE_V1:
if isinstance(layer.self_attn, MiniMaxText01Attention):
attn_metadata[layer.prefix +
".attn"].rotary_emb = self.rotary_emb
else:
attn_metadata.rotary_emb = self.rotary_emb

_caches = None
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
@@ -1002,8 +1090,7 @@ def forward(self,
return hidden_states


class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
SupportsV0Only):
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

@@ -1321,3 +1408,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,

load_basic_weight(name, loaded_weight, self)
return loaded_params

@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.

Args:
vllm_config: vLLM config
use_v1: Whether to compute the shape for V1 (otherwise V0)

Returns:
Tuple containing:
- state_shape: Shape of the cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config

return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=hf_config.num_attention_heads,
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
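A usage illustration, not taken from the diff: given a populated `VllmConfig`, the classmethod yields the single linear-attention state shape, from which a rough per-slot cache footprint follows. The helper name and the 4-byte element size (the linear-attention cache above is float32) are assumptions of this sketch:

```python
from vllm.config import VllmConfig


def linear_state_bytes_per_slot(vllm_config: VllmConfig,
                                dtype_size: int = 4) -> int:
    """Rough size in bytes of one cache slot for one linear-attention layer."""
    (state_shape, ) = MiniMaxText01ForCausalLM.get_mamba_state_shape_from_config(
        vllm_config)
    numel = 1
    for dim in state_shape:
        numel *= dim
    # Multiply by the element size of the cache dtype (float32 -> 4 bytes).
    return numel * dtype_size
```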