Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/lightning_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
pid_d = tl.program_id(2) # dimension block index

# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)

# Skip if slot_id is -1 (padding)
if slot_id == -1:
Expand Down
189 changes: 149 additions & 40 deletions vllm/model_executor/models/minimax_text_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
Expand All @@ -33,6 +34,7 @@
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
Expand All @@ -41,8 +43,9 @@
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

Expand Down Expand Up @@ -327,7 +330,16 @@ def jit_linear_forward_prefix(q: torch.Tensor,
return rearrange(output.squeeze(0), "h n d -> n (h d)")


class MiniMaxText01LinearAttention(nn.Module):
class MiniMaxText01LinearAttention(nn.Module, MambaBase):

@property
def mamba_type(self) -> str:
return "MiniMaxText01LinearAttention"

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
state_shape = (self.num_heads // self.tp_size, self.head_dim,
self.head_dim)
return (state_shape, )

def __init__(
self,
Expand Down Expand Up @@ -359,6 +371,7 @@ def __init__(
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.prefix = prefix

self.qkv_proj = ColumnParallelLinear(
hidden_size,
Expand Down Expand Up @@ -397,6 +410,12 @@ def __init__(
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()

if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self

@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
Expand Down Expand Up @@ -434,13 +453,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slot_id = state_indices_tensor[_prefill_idx]
slice_layer_cache = kv_cache[slot_id, ...]

out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
Expand All @@ -453,9 +473,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata))
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)

if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
Expand All @@ -465,11 +489,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,

def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
):]
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
Expand All @@ -483,17 +513,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor

num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor

decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, attn_metadata)
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)

hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
Expand Down Expand Up @@ -541,6 +603,7 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.prefix = prefix

self.qkv_proj = QKVParallelLinear(
hidden_size,
Expand Down Expand Up @@ -575,7 +638,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = attn_metadata.rotary_emb(positions, q, k)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
positions, q, k)
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
Expand All @@ -595,6 +663,7 @@ def __init__(
) -> None:
self._ilayer = layer_id
self._irank = get_tensor_model_parallel_rank()
self.prefix = prefix
super().__init__()

self.hidden_size = config.hidden_size
Expand Down Expand Up @@ -876,8 +945,9 @@ def layer_fn(prefix):
self._dtype = _dummy.dtype
del _dummy

self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
cache_shape=self.cache_shape)
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)

rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
Expand Down Expand Up @@ -944,23 +1014,27 @@ def forward(self,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
if not envs.VLLM_USE_V1:
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)

minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None

minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
Expand All @@ -973,11 +1047,22 @@ def forward(self,
residual = intermediate_tensors["residual"]

minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if attn_metadata is not None:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if envs.VLLM_USE_V1:
if isinstance(layer.self_attn, MiniMaxText01Attention):
attn_metadata[layer.prefix +
".attn"].rotary_emb = self.rotary_emb
else:
attn_metadata.rotary_emb = self.rotary_emb

_caches = None
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
Expand All @@ -1002,8 +1087,7 @@ def forward(self,
return hidden_states


class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
SupportsV0Only):
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

Expand Down Expand Up @@ -1321,3 +1405,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,

load_basic_weight(name, loaded_weight, self)
return loaded_params

@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.

Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)

Returns:
Tuple containing:
- state_shape: Shape of the cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config

state_shape = (hf_config.num_attention_heads //
parallel_config.tensor_parallel_size,
hf_config.head_dim, hf_config.head_dim)

return (state_shape, )
Loading