From e6a41baeeb8257afde52a7fed0bcc7c0ebe01c9d Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 29 Aug 2025 12:35:16 +0000 Subject: [PATCH 001/105] Initial APC for mamba Signed-off-by: Stanislaw Wozniak --- vllm/config.py | 8 + .../layers/mamba/mamba_mixer2.py | 96 ++++++++++-- .../layers/mamba/ops/mamba_ssm.py | 13 +- .../layers/mamba/ops/ssd_combined.py | 20 ++- vllm/model_executor/models/config.py | 37 ++++- .../model_executor/models/granitemoehybrid.py | 4 - vllm/v1/attention/backends/mamba_attn.py | 20 ++- vllm/v1/core/single_type_kv_cache_manager.py | 140 ++++++++++++++++-- vllm/v1/kv_cache_interface.py | 14 +- vllm/v1/worker/gpu_model_runner.py | 14 +- 10 files changed, 318 insertions(+), 48 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3bcbbe60652b..93e013a7351b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1796,6 +1796,14 @@ class CacheConfig: mamba_page_size_padded: Optional[int] = None """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" + mamba_block_size: Optional[int] = None + """Size of a contiguous cache block in number of tokens for mamba cache.""" + mamba_cache_strategy: str = "all" + """Logic for mamba cache: + * disabled - turn of prefix caching + * all - keep states for all prefixes + * last - keep the states of the last full blocks after each request + """ # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 36edac2375d0..1e8b4c441348 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, cdiv from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -461,6 +461,7 @@ def forward_cuda( # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata + cache_enabled = False if envs.VLLM_USE_V1: if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ -478,6 +479,9 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx chunk_indices_p = attn_metadata.chunk_indices chunk_offsets_p = attn_metadata.chunk_offsets + mamba_block_size = attn_metadata.cache_spec.block_size + cache_strategy = attn_metadata.cache_spec.cache_strategy + cache_enabled = (cache_strategy != 'disabled') else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -559,6 +563,19 @@ def forward_cuda( [num_decodes, num_prefills], dim=0, ) + + # Note: Eventually this will be moved to mamba2 metadata builder: + seq_lens_pending = (torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] + seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) + last_computed_token_block_idx = seq_lens_completed // mamba_block_size - 1 # e.g. 
16 blocks computed; 0th based indexing -> state[15] + last_computed_token_block_idx = last_computed_token_block_idx.clamp(min=0) #in case it's non-computed it's -1 and causes later issues with indexing + current_first_token_block_idx = cdiv(seq_lens_completed + 1, mamba_block_size) - 1 + current_last_token_block_idx = cdiv(seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + + last_computed_idx_d, last_computed_idx_p = torch.split(last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) + current_first_idx_d, current_first_idx_p = torch.split(current_first_token_block_idx, [num_decodes, num_prefills], dim=0) + current_last_idx_d, current_last_idx_p = torch.split(current_last_token_block_idx, [num_decodes, num_prefills], dim=0) + query_start_loc_p = ( attn_metadata.query_start_loc[-num_prefills - 1:] - num_decodes if has_prefill else None) @@ -595,6 +612,16 @@ def forward_cuda( if mamba2_metadata.cu_seqlen is None: mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) + + kernel_conv1d_indices = state_indices_tensor_p + if cache_enabled: + # Kernel expects to have the initial state here and overwrites it -> use final state location + if has_initial_states_p is not None and has_initial_states_p.sum() > 0: + conv_state_idx_input = state_indices_tensor_p.index_select(1, last_computed_idx_p).diag().unsqueeze(1) + conv_state_idx_output = state_indices_tensor_p.index_select(1, current_last_idx_p).diag().unsqueeze(1) + conv_state[conv_state_idx_output[has_initial_states_p]] = conv_state[conv_state_idx_input[has_initial_states_p]] + kernel_conv1d_indices = state_indices_tensor_p.index_select(1, current_last_idx_p).diag() + hidden_states_B_C_p = causal_conv1d_fn( x, conv_weights, @@ -602,10 +629,14 @@ def forward_cuda( activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_states_p, - cache_indices=state_indices_tensor_p, + cache_indices=kernel_conv1d_indices, metadata=mamba2_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] + + if cache_enabled: + #TODO + pass hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -615,15 +646,19 @@ def forward_cuda( if (has_initial_states_p is not None and prep_initial_states): # making a copy of the states if envs.VLLM_USE_V1: + kernel_ssm_indices = state_indices_tensor_p + if cache_enabled: + kernel_ssm_indices = state_indices_tensor_p. 
\ + index_select(1, last_computed_idx_p).diag() initial_states = torch.where( has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) + ssm_state[kernel_ssm_indices], 0) else: initial_states = torch.where( has_initial_states_p[:num_prefills, None, None, None], ssm_state[state_indices_tensor_p], 0) - scan_output, varlen_state = mamba_chunk_scan_combined( + mamba_outputs = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), @@ -642,21 +677,63 @@ def forward_cuda( chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, initial_states=initial_states, + return_intermediate_states=cache_enabled, return_varlen_states=True, return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), ) - # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor_p] = varlen_state + if cache_enabled: + scan_output, states, varlen_state = mamba_outputs + + # update ssm states + # - varlen state at FINAL chunk is a (num_prefills, nheads, headdim, dstate) tensor + # states (num_prefills, states_at_INTERMEDIATE_chunks, nheads, headdim, dstate) tensor + # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: + all_states = torch.concat([states[:,1:], varlen_state.unsqueeze(1)], 1) # first from returned states is zero + state_stride = mamba_block_size // chunk_size + # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to + # states at blocks 0,0,1,1,2 (block_size=512). For first blocks, stride(=2). For last block can't strid + + # initial state: state_indices_tensor_p[, last_computed_idx_p[]] + # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + + # Note: Currently works for 1 request only: + states_at_blocks = torch.concat([all_states[:,state_stride-1:(current_last_idx_p[0]-current_first_idx_p[0])*state_stride:state_stride], varlen_state.unsqueeze(1)], 1) + if cache_strategy == "all": + ssm_state[state_indices_tensor_p[:,current_first_idx_p[0]:current_last_idx_p[0]+1]] = states_at_blocks + elif cache_strategy == "last": + ssm_state[state_indices_tensor_p[:,current_last_idx_p[0]-1:]] = states_at_blocks[:,-2:] + else: + scan_output, varlen_state = mamba_outputs + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + ssm_state[state_indices_tensor_p] = varlen_state # - reshape ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) # Process decode requests if has_decode: + + if cache_enabled: + # if at_block_boundary, load states from previous blocks: + at_block_boundary = mamba2_metadata.seq_lens % mamba_block_size == 0 + finished_blocks = attn_metadata.seq_lens[0] // mamba_block_size #e.g. 1024 -> 2 blocks ; 1025 -> 2 blocks + input_block = cdiv(attn_metadata.seq_lens[0], mamba_block_size) #e.g. 1024 -> 2nd block, 1025 -> 3rd block + output_block = cdiv(attn_metadata.seq_lens[0]+1, mamba_block_size) #e.g. 
1023 -> 2nd block, 1024 -> 3rd block + state_indices_tensor_d_input = state_indices_tensor_d[:,input_block-1] + state_indices_tensor_d_output = state_indices_tensor_d[:,output_block-1] + + # copy initial state to new location, as update kernel works in place + if output_block > input_block: + conv_state[state_indices_tensor_d_output] = conv_state[state_indices_tensor_d_input] + else: + # Without caching, read and write in-place to the same blocks: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, @@ -664,7 +741,7 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d_output) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( hidden_states_B_C_d) @@ -696,7 +773,8 @@ def forward_cuda( z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, ) if envs.VLLM_USE_V1: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3f67fc35afdf..6507fca34833 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -52,6 +52,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + dst_state_batch_indices_ptr, pad_slot_id, # Matrix dimensions batch, @@ -107,11 +108,16 @@ def _selective_scan_update_kernel( # is taken from the state_batch_indices_ptr Otherwise, the state coordinate # is the same as the batch id. 
if HAS_STATE_BATCH_INDICES: + dst_state_batch_indices_ptr += pid_b + dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) + dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch + + pid_h * stride_state_head) state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += (state_batch_idx * stride_state_batch + pid_h * stride_state_head) else: + dst_state_ptr = state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head @@ -131,6 +137,8 @@ def _selective_scan_update_kernel( offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: @@ -185,7 +193,7 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, state, mask=mask) + tl.store(dst_state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -205,6 +213,7 @@ def selective_state_update(state, dt_bias=None, dt_softplus=False, state_batch_indices=None, + dst_state_batch_indices=None, pad_slot_id=PAD_SLOT_ID): """ Argument: @@ -264,6 +273,7 @@ def selective_state_update(state, assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch, ) + assert dst_state_batch_indices.shape == (batch, ) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else @@ -289,6 +299,7 @@ def selective_state_update(state, z, out, state_batch_indices, + dst_state_batch_indices, pad_slot_id, batch, nheads, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index b121275e9eb3..3251c3f19826 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -180,6 +180,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + return_intermediate_states=False, return_final_states=False, return_varlen_states=False): """ @@ -222,11 +223,16 @@ def mamba_chunk_scan_combined(x, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) - if not return_varlen_states: - return out if not return_final_states else (out, final_states) - else: + if return_varlen_states: varlen_states = rest[0] - return (out, - varlen_states) if not return_final_states else (out, - final_states, - varlen_states) + if return_final_states: + return (out, final_states, varlen_states) + elif return_intermediate_states: + return (out, states, varlen_states) + else: + return (out, varlen_states) + else: + if return_final_states: + return (out, final_states) + else: + return (out, states) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f50b1753098..d4b54d3172bb 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -265,12 +265,37 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> 
None:
             block_size=model_config.max_model_len,
         ).page_size_bytes
 
-        # some attention backends (e.g. FA) only support setting
-        # block size to multiple of 16, so let's suggest a value
-        # that would work (note: FA is currently not compatible
-        # with mamba layers, use FlashInfer instead).
-        attn_block_size = 16 * cdiv(mamba_page_size,
-                                    16 * attn_page_size_1_token)
+        if cache_config.enable_prefix_caching:
+            # With prefix caching, select the attention block size to
+            # optimize for mamba kernel performance.
+
+            # The mamba SSD kernel uses a chunk_size, e.g. 256. Align the block to the kernel:
+            # use the lowest multiple of 256 attention tokens that would fit mamba_page_size,
+            # e.g. mamba page size of 788kB; attn_1_token 2kB -> fits ~394 tokens,
+            # then round up to a multiple of 256 -> 512 tokens:
+            # attn_block_size = 512
+            # mamba_block_size = 512 (aligned to a multiple of the kernel chunk_size)
+            chunk_size = model_config.get_mamba_chunk_size()
+            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
+            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
+            cache_config.mamba_block_size = attn_block_size
+
+            # This might be redundant now:
+            if model_config.max_model_len % attn_block_size != 0:
+                # Currently HybridCacheManager uses max_model_len for the Mamba block
+                # and requires it to be a multiple of the attention block size.
+                model_config.max_model_len -= model_config.max_model_len % attn_block_size
+                print("Adjusting max_model_len to", model_config.max_model_len)
+        else:
+            # Without prefix caching, select the minimum valid attention
+            # block size to minimize mamba state padding.
+
+            # some attention backends (e.g. FA) only support setting
+            # block size to multiple of 16, so let's suggest a value
+            # that would work (note: FA is currently not compatible
+            # with mamba layers, use FlashInfer instead).
+            attn_block_size = 16 * cdiv(mamba_page_size,
+                                        16 * attn_page_size_1_token)
 
         # override attention block size if either (a) the
         # user has not set it or (b) the user has set it
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 59c1dce48ee7..9145b4ea5279 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -567,10 +567,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
         scheduler_config = vllm_config.scheduler_config
-        if cache_config.enable_prefix_caching:
-            raise RuntimeError(
-                "GraniteMoeHybrid currently does not support prefix caching")
-
         self.quant_config = vllm_config.quant_config
         self.config = config
         self.scheduler_config = scheduler_config
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index dca5de46c065..8b1ee8a0c726 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -82,6 +82,7 @@ class Mamba2AttentionMetadata:
     cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.tensor] = None
     token_chunk_offset_ptr: Optional[torch.tensor] = None
+    cache_spec: Optional[MambaSpec] = None
 
 
 class Mamba2AttentionMetadataBuilder(
@@ -116,7 +117,23 @@ def build(self,
             has_initial_states = None
             prep_initial_states = False
 
-        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+        if self.kv_cache_spec.cache_strategy == "disabled":
+            # Always return just a single block per each request:
+            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+        else:
+            # Return a tensor of shape (#requests, #blocks for longest request)
+            # filled in with cached and newly allocated blocks for each request
+            cache_block_size = self.kv_cache_spec.block_size
+            seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+            block_table_bounds_cpu = (seq_lens_cpu + cache_block_size - 1) // cache_block_size
+            max_num_blocks = block_table_bounds_cpu.max()
+            paged_kv_indices = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
+            if self.kv_cache_spec.cache_strategy == "last":
+                # TODO: The "last" strategy is not fully implemented yet
+                # In the "last" strategy, the allocator puts 2 blocks in front
+                # For ease of handling, we move them to be the last two in the list
+                paged_kv_indices = torch.roll(paged_kv_indices, max_num_blocks.item()-2, -1)
+            state_indices_tensor = paged_kv_indices
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(common_attn_metadata,
@@ -162,6 +179,7 @@ def build(self,
             has_initial_states=has_initial_states,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
+            cache_spec=self.kv_cache_spec,
             seq_idx=seq_idx,
             chunk_indices=chunk_indices,
             chunk_offsets=chunk_offsets,
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index e8a44c7773a7..cefc1fa47ab5 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -532,29 +532,149 @@ def find_longest_cache_hit(
         assert isinstance(
             kv_cache_spec,
             MambaSpec), ("MambaManager can only be used for mamba groups")
-        # Prefix caching is not supported for mamba now. Always return empty
-        # list.
         computed_blocks: tuple[list[KVCacheBlock], ...]
= tuple( [] for _ in range(len(kv_cache_group_ids))) + if kv_cache_spec.cache_strategy == "disabled": + return computed_blocks #return empty list if cache is disabled + + max_num_blocks = max_length // kv_cache_spec.block_size + # Search from right to left and early stop when a match is found. + for i in range(max_num_blocks - 1, -1, -1): + if cached_block := block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids): + for computed, cached in zip(computed_blocks, cached_block): + # the hit length logic later assumes: + # hit_length = len(hit_blocks_other_attn[0]) + # * self.other_block_size + # so we insert dummy blocks at the beginning: + if i > 0: + computed.extend([block_pool.null_block] * i) + computed.append(cached) + break # we just need the last match - early stopping + return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: - # Each request will always have 1 block at this moment, so no need to - # remove blocks. + # TODO: For "all" strategy, and potentially for "last" + # we should already start removing initial blocks pass def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: - return 0 + if self.kv_cache_spec.cache_strategy == "disabled": + return 0 + + # Same as full attention logic: + blocks = self.req_to_blocks[request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks + num_tokens: int) -> list[KVCacheBlock]: + if self.kv_cache_spec.cache_strategy == "disabled": + new_blocks = super().allocate_new_blocks(request_id, num_tokens) + assert len(self.req_to_blocks[request_id]) == 1, ( + "MambaManager should only allocate 1 block for each request.") + return new_blocks + + req_blocks = self.req_to_blocks[request_id] + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = num_required_blocks - len(req_blocks) + if num_new_blocks <= 0: + return [] + else: + if num_new_blocks > 2 and self.kv_cache_spec.cache_strategy == "last": + # for the last strategy only - allocate 2 blocks: + # one for block_size aligned state + # and one for the last temporary state + new_blocks = self.block_pool.get_new_blocks(2) + else: + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + return new_blocks + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks. 
+ """ + + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null + for blk in new_computed_blocks) + return num_new_blocks + num_evictable_computed_blocks + + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], + num_tokens: int) -> None: + """ + Cache the blocks for the request. + + Args: + request: The request. + block_hashes: The block hashes of the request. + num_tokens: The total number of tokens that need to be cached + (including tokens that are already cached). + """ + #TODO: Just copied parent class implementation here to verify logic. + num_cached_blocks = self.num_cached_block[request.request_id] + num_full_blocks = num_tokens // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=self.block_size, + kv_cache_group_id=self.kv_cache_group_id, + hash_fn=self.caching_hash_fn, + ) + + self.num_cached_block[request.request_id] = num_full_blocks + + def free(self, request_id: str) -> None: + """ + Free the blocks for the request. + + Args: + request_id: The request ID. + """ + #TODO: Just copied parent class implementation here to verify logic. + + # Default to [] in case a request is freed (aborted) before alloc. + req_blocks = self.req_to_blocks.pop(request_id, []) + + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(req_blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request_id, None) spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 1da5230116d2..e58f7705dd16 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -201,6 +201,7 @@ class MambaSpec(KVCacheSpec): dtype: torch.dtype page_size_padded: Optional[int] = None mamba_type: str = "mamba2" + cache_strategy: str = "disabled" def __post_init__(self): self.num_elements = sum(prod(shape) for shape in self.shapes) @@ -217,10 +218,15 @@ def page_size_bytes(self) -> int: return self.page_size_padded return page_size - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # We allocate 1 block for each request now, so max_memory_usage_bytes is - # the same as page_size_bytes. - # Need to update this when supporting prefix caching. 
+ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + if self.cache_strategy == "last": + # Keeps the last full block and one non-full block state: + return 2 * self.page_size_bytes + elif self.cache_strategy == "all": + # Keeps a state at every block boundary: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + # By default keeps the last state only: return self.page_size_bytes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fc55d09fc97e..efbda0b7e073 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2899,20 +2899,22 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len + mamba_block_size = self.vllm_config.cache_config.mamba_block_size + else: + # Set block_size to max_model_len, so that mamba model will always + # have only one block + mamba_block_size = self.vllm_config.model_config.max_model_len + self.vllm_config.cache_config.mamba_cache_strategy = "disabled" page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, - block_size=max_model_len, + block_size=mamba_block_size, + cache_strategy=self.vllm_config.cache_config.mamba_cache_strategy, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type) From 87dd0a043fe57bf983900dd0696da30a8892dc00 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 29 Aug 2025 12:37:21 +0000 Subject: [PATCH 002/105] Conv kernel state handling Signed-off-by: Thomas Ortner --- .../layers/mamba/mamba_mixer2.py | 28 +++++++++++++++++-- .../layers/mamba/ops/causal_conv1d.py | 4 ++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 1e8b4c441348..016f650fd213 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -635,9 +635,33 @@ def forward_cuda( 0, 1)[:num_prefill_tokens] if cache_enabled: - #TODO - pass + def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end): + conv_state[conv_state_block_idx, :, 0] = torch.transpose(x[:, x_offset-3:x_end:mamba_block_size], 1, 0) + conv_state[conv_state_block_idx, :, 1] = torch.transpose(x[:, x_offset-2:x_end:mamba_block_size], 1, 0) + conv_state[conv_state_block_idx, :, 2] = torch.transpose(x[:, x_offset-1:x_end:mamba_block_size], 1, 0) + # initial state: state_indices_tensor_p[, last_computed_idx_p[]] + # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + if cache_strategy == "all": + # Iterate over all sequences to need prefill + for seq_idx in range(state_indices_tensor_p.shape[0]): + number_full_blocks = seq_lens_pending[seq_idx] // mamba_block_size + second_last_block_idx = number_full_blocks if seq_lens_pending[seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 + if number_full_blocks > 0 and seq_lens_pending[seq_idx] % 
mamba_block_size > 0: + copy_x_to_conv_state(state_indices_tensor_p[seq_idx,current_first_idx_p[seq_idx]:current_first_idx_p[seq_idx] + second_last_block_idx], mamba_block_size, mamba_block_size * second_last_block_idx) + elif cache_strategy == "last": + # i.e. keep two states: either + # a) states at the last two block boundaries or + # b) state at the last block boundary and last state of the sequence, + # which might not be at a block boundary + # Iterate over all sequences to need prefill + for seq_idx in range(state_indices_tensor_p.shape[0]): + # Only store the additional second state if there are is at least one full block and a remainder. + # Otherwise, there is only one state to store + if number_full_blocks > 0 and seq_lens_pending[seq_idx] % mamba_block_size > 0: + second_last_block_idx = number_full_blocks if seq_lens_pending[seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 + copy_x_to_conv_state(state_indices_tensor_p[seq_idx,current_last_idx_p[seq_idx]-1:current_last_idx_p[seq_idx]], mamba_block_size * second_last_block_idx, mamba_block_size * second_last_block_idx) + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b8d4bbc37105..18878652bf09 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -415,7 +415,9 @@ def causal_conv1d_fn( activation = "silu" args = None - out = torch.empty_like(x) + #out = torch.empty_like(x) + #TODO: Noticed strange behavior, maybe due to use of uninitialzed values? + out = torch.zeros_like(x) if metadata is not None: cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict From 224c9e1c5269f07742ab8d550f00a9e6d6796f03 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 4 Sep 2025 04:25:32 -0400 Subject: [PATCH 003/105] Get things working with latest code Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 159 ++++++++++++------ .../layers/mamba/ops/mamba_ssm.py | 2 +- vllm/model_executor/models/config.py | 5 - vllm/v1/attention/backends/mamba2_attn.py | 18 +- vllm/v1/core/single_type_kv_cache_manager.py | 65 +------ 5 files changed, 130 insertions(+), 119 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index a6edf01fb02a..c3a4c010c5f2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -482,7 +482,7 @@ def forward_cuda( has_initial_states_p = attn_metadata.has_initial_states_p prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx + seq_idx_p = attn_metadata.seq_idx_p chunk_indices_p = attn_metadata.chunk_indices_p chunk_offsets_p = attn_metadata.chunk_offsets_p mamba_block_size = attn_metadata.cache_spec.block_size @@ -570,16 +570,28 @@ def forward_cuda( ) # Note: Eventually this will be moved to mamba2 metadata builder: - seq_lens_pending = (torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] + seq_lens_pending = ( + torch.roll(attn_metadata.query_start_loc, -1, -1) - + attn_metadata.query_start_loc)[:-1] seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) - last_computed_token_block_idx = seq_lens_completed // mamba_block_size - 1 # e.g. 
16 blocks computed; 0th based indexing -> state[15] - last_computed_token_block_idx = last_computed_token_block_idx.clamp(min=0) #in case it's non-computed it's -1 and causes later issues with indexing - current_first_token_block_idx = cdiv(seq_lens_completed + 1, mamba_block_size) - 1 - current_last_token_block_idx = cdiv(seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 - - last_computed_idx_d, last_computed_idx_p = torch.split(last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) - current_first_idx_d, current_first_idx_p = torch.split(current_first_token_block_idx, [num_decodes, num_prefills], dim=0) - current_last_idx_d, current_last_idx_p = torch.split(current_last_token_block_idx, [num_decodes, num_prefills], dim=0) + last_computed_token_block_idx = seq_lens_completed // mamba_block_size - 1 # e.g. 16 blocks computed; 0th based indexing -> state[15] + last_computed_token_block_idx = last_computed_token_block_idx.clamp( + min=0 + ) #in case it's non-computed it's -1 and causes later issues with indexing + current_first_token_block_idx = cdiv(seq_lens_completed + 1, + mamba_block_size) - 1 + current_last_token_block_idx = cdiv( + seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + + last_computed_idx_d, last_computed_idx_p = torch.split( + last_computed_token_block_idx, [num_decodes, num_prefills], + dim=0) + current_first_idx_d, current_first_idx_p = torch.split( + current_first_token_block_idx, [num_decodes, num_prefills], + dim=0) + current_last_idx_d, current_last_idx_p = torch.split( + current_last_token_block_idx, [num_decodes, num_prefills], + dim=0) query_start_loc_p = ( attn_metadata.query_start_loc[-num_prefills - 1:] - @@ -638,15 +650,21 @@ def forward_cuda( if mamba2_metadata.cu_seqlen is None: mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) - + kernel_conv1d_indices = state_indices_tensor_p if cache_enabled: # Kernel expects to have the initial state here and overwrites it -> use final state location - if has_initial_states_p is not None and has_initial_states_p.sum() > 0: - conv_state_idx_input = state_indices_tensor_p.index_select(1, last_computed_idx_p).diag().unsqueeze(1) - conv_state_idx_output = state_indices_tensor_p.index_select(1, current_last_idx_p).diag().unsqueeze(1) - conv_state[conv_state_idx_output[has_initial_states_p]] = conv_state[conv_state_idx_input[has_initial_states_p]] - kernel_conv1d_indices = state_indices_tensor_p.index_select(1, current_last_idx_p).diag() + if has_initial_states_p is not None and has_initial_states_p.sum( + ) > 0: + conv_state_idx_input = state_indices_tensor_p.index_select( + 1, last_computed_idx_p).diag().unsqueeze(1) + conv_state_idx_output = state_indices_tensor_p.index_select( + 1, current_last_idx_p).diag().unsqueeze(1) + conv_state[conv_state_idx_output[ + has_initial_states_p]] = conv_state[ + conv_state_idx_input[has_initial_states_p]] + kernel_conv1d_indices = state_indices_tensor_p.index_select( + 1, current_last_idx_p).diag() hidden_states_B_C_p = causal_conv1d_fn( x, @@ -659,35 +677,55 @@ def forward_cuda( metadata=mamba2_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] - + if cache_enabled: - def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end): - conv_state[conv_state_block_idx, :, 0] = torch.transpose(x[:, x_offset-3:x_end:mamba_block_size], 1, 0) - conv_state[conv_state_block_idx, :, 1] = torch.transpose(x[:, x_offset-2:x_end:mamba_block_size], 1, 0) - conv_state[conv_state_block_idx, :, 2] = 
torch.transpose(x[:, x_offset-1:x_end:mamba_block_size], 1, 0) + + def copy_x_to_conv_state(conv_state_block_idx, x_offset, + x_end): + conv_state[conv_state_block_idx, :, 0] = torch.transpose( + x[:, x_offset - 3:x_end:mamba_block_size], 1, 0) + conv_state[conv_state_block_idx, :, 1] = torch.transpose( + x[:, x_offset - 2:x_end:mamba_block_size], 1, 0) + conv_state[conv_state_block_idx, :, 2] = torch.transpose( + x[:, x_offset - 1:x_end:mamba_block_size], 1, 0) # initial state: state_indices_tensor_p[, last_computed_idx_p[]] - # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] if cache_strategy == "all": # Iterate over all sequences to need prefill for seq_idx in range(state_indices_tensor_p.shape[0]): - number_full_blocks = seq_lens_pending[seq_idx] // mamba_block_size - second_last_block_idx = number_full_blocks if seq_lens_pending[seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 - if number_full_blocks > 0 and seq_lens_pending[seq_idx] % mamba_block_size > 0: - copy_x_to_conv_state(state_indices_tensor_p[seq_idx,current_first_idx_p[seq_idx]:current_first_idx_p[seq_idx] + second_last_block_idx], mamba_block_size, mamba_block_size * second_last_block_idx) - elif cache_strategy == "last": - # i.e. keep two states: either - # a) states at the last two block boundaries or - # b) state at the last block boundary and last state of the sequence, + number_full_blocks = seq_lens_pending[ + seq_idx] // mamba_block_size + second_last_block_idx = number_full_blocks if seq_lens_pending[ + seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 + if number_full_blocks > 0 and seq_lens_pending[ + seq_idx] % mamba_block_size > 0: + copy_x_to_conv_state( + state_indices_tensor_p[ + seq_idx, current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx] + + second_last_block_idx], mamba_block_size, + mamba_block_size * second_last_block_idx) + elif cache_strategy == "last": + # i.e. keep two states: either + # a) states at the last two block boundaries or + # b) state at the last block boundary and last state of the sequence, # which might not be at a block boundary # Iterate over all sequences to need prefill for seq_idx in range(state_indices_tensor_p.shape[0]): # Only store the additional second state if there are is at least one full block and a remainder. 
# Otherwise, there is only one state to store - if number_full_blocks > 0 and seq_lens_pending[seq_idx] % mamba_block_size > 0: - second_last_block_idx = number_full_blocks if seq_lens_pending[seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 - copy_x_to_conv_state(state_indices_tensor_p[seq_idx,current_last_idx_p[seq_idx]-1:current_last_idx_p[seq_idx]], mamba_block_size * second_last_block_idx, mamba_block_size * second_last_block_idx) - + if number_full_blocks > 0 and seq_lens_pending[ + seq_idx] % mamba_block_size > 0: + second_last_block_idx = number_full_blocks if seq_lens_pending[ + seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 + copy_x_to_conv_state( + state_indices_tensor_p[ + seq_idx, current_last_idx_p[seq_idx] - + 1:current_last_idx_p[seq_idx]], + mamba_block_size * second_last_block_idx, + mamba_block_size * second_last_block_idx) + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -744,20 +782,31 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end): # - varlen state at FINAL chunk is a (num_prefills, nheads, headdim, dstate) tensor # states (num_prefills, states_at_INTERMEDIATE_chunks, nheads, headdim, dstate) tensor # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: - all_states = torch.concat([states[:,1:], varlen_state.unsqueeze(1)], 1) # first from returned states is zero + all_states = torch.concat( + [states[:, 1:], varlen_state.unsqueeze(1)], + 1) # first from returned states is zero state_stride = mamba_block_size // chunk_size - # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to + # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to # states at blocks 0,0,1,1,2 (block_size=512). For first blocks, stride(=2). For last block can't strid - + # initial state: state_indices_tensor_p[, last_computed_idx_p[]] - # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] - + # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + # Note: Currently works for 1 request only: - states_at_blocks = torch.concat([all_states[:,state_stride-1:(current_last_idx_p[0]-current_first_idx_p[0])*state_stride:state_stride], varlen_state.unsqueeze(1)], 1) - if cache_strategy == "all": - ssm_state[state_indices_tensor_p[:,current_first_idx_p[0]:current_last_idx_p[0]+1]] = states_at_blocks + states_at_blocks = torch.concat([ + all_states[:, state_stride - 1:(current_last_idx_p[0] - + current_first_idx_p[0]) * + state_stride:state_stride], + varlen_state.unsqueeze(1) + ], 1) + if cache_strategy == "all": + ssm_state[state_indices_tensor_p[:, current_first_idx_p[0]: + current_last_idx_p[0] + + 1]] = states_at_blocks elif cache_strategy == "last": - ssm_state[state_indices_tensor_p[:,current_last_idx_p[0]-1:]] = states_at_blocks[:,-2:] + ssm_state[ + state_indices_tensor_p[:, current_last_idx_p[0] - + 1:]] = states_at_blocks[:, -2:] else: varlen_state = mamba_outputs @@ -767,19 +816,29 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end): # Process decode requests if has_decode: - + if cache_enabled: # if at_block_boundary, load states from previous blocks: at_block_boundary = mamba2_metadata.seq_lens % mamba_block_size == 0 - finished_blocks = attn_metadata.seq_lens[0] // mamba_block_size #e.g. 1024 -> 2 blocks ; 1025 -> 2 blocks - input_block = cdiv(attn_metadata.seq_lens[0], mamba_block_size) #e.g. 1024 -> 2nd block, 1025 -> 3rd block - output_block = cdiv(attn_metadata.seq_lens[0]+1, mamba_block_size) #e.g. 
1023 -> 2nd block, 1024 -> 3rd block - state_indices_tensor_d_input = state_indices_tensor_d[:,input_block-1] - state_indices_tensor_d_output = state_indices_tensor_d[:,output_block-1] + finished_blocks = attn_metadata.seq_lens[ + 0] // mamba_block_size #e.g. 1024 -> 2 blocks ; 1025 -> 2 blocks + input_block = cdiv( + attn_metadata.seq_lens[0], mamba_block_size + ) #e.g. 1024 -> 2nd block, 1025 -> 3rd block + output_block = cdiv( + attn_metadata.seq_lens[0] + 1, mamba_block_size + ) #e.g. 1023 -> 2nd block, 1024 -> 3rd block + state_indices_tensor_d_input = state_indices_tensor_d[:, + input_block + - 1] + state_indices_tensor_d_output = state_indices_tensor_d[:, + output_block + - 1] # copy initial state to new location, as update kernel works in place if output_block > input_block: - conv_state[state_indices_tensor_d_output] = conv_state[state_indices_tensor_d_input] + conv_state[state_indices_tensor_d_output] = conv_state[ + state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 4411fffa82d5..30d59d45813b 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -214,7 +214,7 @@ def selective_state_update(state, dt_softplus=False, state_batch_indices=None, dst_state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID): + pad_slot_id=PAD_SLOT_ID, out=None): """ Argument: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 49cd68650331..9bed7d833fe7 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -297,11 +297,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config - # TODO(tdoublep): remove once prefix caching is enabled - cache_config.enable_prefix_caching = False - logger.info("Hybrid or mamba-based model detected: disabling prefix " - "caching since it is not yet supported.") - # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 01c7b5c9df7b..75a34748133e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -13,7 +13,7 @@ BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, @@ -85,7 +85,7 @@ class Mamba2AttentionMetadata: cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.tensor] = None token_chunk_offset_ptr: Optional[torch.tensor] = None - + cache_spec: Optional[MambaSpec] = None class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): @@ -114,20 +114,24 @@ def build(self, if self.kv_cache_spec.cache_strategy == "disabled": # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = common_attn_metadata.block_table_tensor[:, + 0] else: # Return a tensor of shape 
(#requests, #blocks for longest request)
             # filled in with cached and newly allocated blocks for each request
             cache_block_size = self.kv_cache_spec.block_size
             seq_lens_cpu = common_attn_metadata.seq_lens_cpu
-            block_table_bounds_cpu = (seq_lens_cpu + cache_block_size - 1) // cache_block_size
-            max_num_blocks = block_table_bounds_cpu.max()
-            paged_kv_indices = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
+            block_table_bounds_cpu = (seq_lens_cpu + cache_block_size -
+                                      1) // cache_block_size
+            max_num_blocks = block_table_bounds_cpu.max()
+            paged_kv_indices = common_attn_metadata.block_table_tensor[:, :
+                                                                       max_num_blocks]
             if self.kv_cache_spec.cache_strategy == "last":
                 # TODO: The "last" strategy is not fully implemented yet
                 # In the "last" strategy, the allocator puts 2 blocks in front
                 # For ease of handling, we move them to be the last two in the list
-                paged_kv_indices = torch.roll(paged_kv_indices, max_num_blocks.item()-2, -1)
+                paged_kv_indices = torch.roll(paged_kv_indices,
+                                              max_num_blocks.item() - 2, -1)
             state_indices_tensor = paged_kv_indices
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(common_attn_metadata,
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 9bb01f040ea0..2f4ae3995f34 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -532,7 +532,7 @@ def find_longest_cache_hit(
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [] for _ in range(len(kv_cache_group_ids)))
         if kv_cache_spec.cache_strategy == "disabled":
-            return computed_blocks #return empty list if cache is disabled
+            return computed_blocks  #return empty list if cache is disabled
 
         max_num_blocks = max_length // kv_cache_spec.block_size
         # Search from right to left and early stop when a match is found.
@@ -540,15 +540,15 @@ def find_longest_cache_hit( if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids): for computed, cached in zip(computed_blocks, cached_block): - # the hit length logic later assumes: - # hit_length = len(hit_blocks_other_attn[0]) + # the hit length logic later assumes: + # hit_length = len(hit_blocks_other_attn[0]) # * self.other_block_size # so we insert dummy blocks at the beginning: if i > 0: computed.extend([block_pool.null_block] * i) computed.append(cached) - break # we just need the last match - early stopping - + break # we just need the last match - early stopping + return computed_blocks def remove_skipped_blocks(self, request_id: str, @@ -561,7 +561,7 @@ def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: if self.kv_cache_spec.cache_strategy == "disabled": return 0 - + # Same as full attention logic: blocks = self.req_to_blocks[request_id] num_common_blocks = 0 @@ -573,7 +573,7 @@ def get_num_common_prefix_blocks(self, request_id: str, return num_common_blocks def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + num_tokens: int) -> list[KVCacheBlock]: if self.kv_cache_spec.cache_strategy == "disabled": new_blocks = super().allocate_new_blocks(request_id, num_tokens) assert len(self.req_to_blocks[request_id]) == 1, ( @@ -587,14 +587,14 @@ def allocate_new_blocks(self, request_id: str, return [] else: if num_new_blocks > 2 and self.kv_cache_spec.cache_strategy == "last": - # for the last strategy only - allocate 2 blocks: + # for the last strategy only - allocate 2 blocks: # one for block_size aligned state # and one for the last temporary state new_blocks = self.block_pool.get_new_blocks(2) else: new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - return new_blocks + return new_blocks def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, @@ -625,53 +625,6 @@ def get_num_blocks_to_allocate( for blk in new_computed_blocks) return num_new_blocks + num_evictable_computed_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_tokens: int) -> None: - """ - Cache the blocks for the request. - - Args: - request: The request. - block_hashes: The block hashes of the request. - num_tokens: The total number of tokens that need to be cached - (including tokens that are already cached). - """ - - #TODO: Just copied parent class implementation here to verify logic. - num_cached_blocks = self.num_cached_block[request.request_id] - num_full_blocks = num_tokens // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=self.req_to_blocks[request.request_id], - block_hashes=block_hashes, - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks, - block_size=self.block_size, - kv_cache_group_id=self.kv_cache_group_id, - hash_fn=self.caching_hash_fn, - ) - - self.num_cached_block[request.request_id] = num_full_blocks - - def free(self, request_id: str) -> None: - """ - Free the blocks for the request. - - Args: - request_id: The request ID. - """ - #TODO: Just copied parent class implementation here to verify logic. - - # Default to [] in case a request is freed (aborted) before alloc. - req_blocks = self.req_to_blocks.pop(request_id, []) - - # Free blocks in reverse order so that the tail blocks are - # freed first. 
- ordered_blocks = reversed(req_blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request_id, None) class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" From dddb650c10f6f2015cf84bd2983ff2a5b12a417a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 05:45:32 -0400 Subject: [PATCH 004/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 83 +++++++++++++++++-- .../layers/mamba/ops/ssd_combined.py | 10 ++- .../layers/mamba/ops/ssd_state_passing.py | 23 +++-- vllm/v1/attention/backends/mamba2_attn.py | 4 + vllm/v1/core/sched/scheduler.py | 2 +- 5 files changed, 106 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ad58a9918f03..a9b07660589a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -14,6 +14,22 @@ from .mamba_ssm import softplus +@triton.jit +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 @triton.autotune( configs=[ @@ -35,6 +51,7 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + cu_seqlens_ptr, # Matrix dimension batch, seqlen, @@ -42,6 +59,7 @@ def _chunk_cumsum_fwd_kernel( chunk_size, dt_min, dt_max, + num_seqs, # Strides stride_dt_batch, stride_dt_seqlen, @@ -68,7 +86,23 @@ def _chunk_cumsum_fwd_kernel( # https://github.com/triton-lang/triton/issues/1058 pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + + + seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) + + chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx + + chunk_local_idx = pid_c - chunk_start_idx + + cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) + cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx + + # skip any unncessary work + if chunk_local_idx * chunk_size >= cur_batch_query_len: + return + + dt_ptr += pid_b * stride_dt_batch + (cur_batch_in_all_start_idx + chunk_local_idx * chunk_size) * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -81,7 +115,7 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & @@ -102,13 +136,13 @@ def _chunk_cumsum_fwd_kernel( 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) A = tl.load(A_ptrs, mask=offs_h < nheads, 
other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) @triton.autotune( @@ -196,6 +230,7 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, + cu_seqlens_ptr, seq_idx_ptr, # Matrix dimensions hdim, @@ -204,6 +239,7 @@ def _chunk_state_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, + num_seqs, # Strides stride_x_batch, stride_x_seqlen, @@ -241,13 +277,26 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + + seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) + chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx + chunk_local_idx = pid_c - chunk_start_idx + cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) + cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx + + # skip any unncessary work + if chunk_local_idx * chunk_size >= cur_batch_query_len: + return + + seqlen_offset = cur_batch_in_all_start_idx + chunk_local_idx*chunk_size + b_ptr += pid_b * stride_b_batch + seqlen_offset * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + x_ptr += pid_b * stride_x_batch + seqlen_offset * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_b * stride_seq_idx_batch + seqlen_offset * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -263,7 +312,8 @@ def _chunk_state_fwd_kernel( if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) + if HAS_SEQ_IDX: seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) @@ -556,26 +606,35 @@ def _chunk_state_varlen_kernel( def _chunk_cumsum_fwd(dt, A, chunk_size, + cu_seqlens, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - nchunks = math.ceil(seqlen / chunk_size) + num_seqs = len(cu_seqlens)-1 + nchunks = seqlen // chunk_size + num_seqs + print("dt.shape: ", dt.shape) + print("A.shape: ", A.shape) + print("nchunks: ", nchunks) + print("type(cu_seqlens): ", type(cu_seqlens)) dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + print("dt_out.shape: ", dt_out.shape) dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + print("dA_cumsum.shape: ", dA_cumsum.shape) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): @@ -585,12 
+644,14 @@ def _chunk_cumsum_fwd(dt, dt_bias, dt_out, dA_cumsum, + cu_seqlens, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], + num_seqs, dt.stride(0), dt.stride(1), dt.stride(2), @@ -615,6 +676,7 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, + cu_seqlens, seq_idx=None, states=None, states_in_fp32=True): @@ -634,6 +696,7 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + print("[_chunk_state_fwd] states.shape: ", states.shape) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) @@ -644,6 +707,7 @@ def _chunk_state_fwd(B, states, dt, dA_cumsum, + cu_seqlens, seq_idx, headdim, dstate, @@ -651,6 +715,7 @@ def _chunk_state_fwd(B, batch, seqlen, nheads // ngroups, + len(cu_seqlens)-1, x.stride(0), x.stride(1), x.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fcc5c905bf77..c11f1a4c4b24 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -85,21 +85,27 @@ def _mamba_chunk_scan_combined_fwd(x, # - see the blog and paper for a visualization of the submatrices # which we refer to in the comments below + num_seqs = len(cu_seqlens) - 1 # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, + cu_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + print("dA_cumsum.shape: ", dA_cumsum.shape) + print("dt.shape: ", dt.shape) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, x, dt, dA_cumsum, + cu_seqlens, seq_idx=seq_idx, states_in_fp32=True) @@ -117,6 +123,7 @@ def _mamba_chunk_scan_combined_fwd(x, states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, + cu_seqlens, initial_states=rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None, seq_idx=seq_idx, @@ -213,7 +220,8 @@ def mamba_chunk_scan_combined(x, out: Preallocated output tensor state_dtype: The data type of the ssm state """ - + print("-------------------------") + print("cu_seqlens: ", cu_seqlens) if not return_varlen_states: cu_seqlens = None else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index d61c3a8cdbe9..56c61ad3cdd8 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -10,6 +10,7 @@ from vllm.triton_utils import tl, triton +from .ssd_chunk_state import find_seq_idx @triton.autotune( configs=[ @@ -33,11 +34,13 @@ def _state_passing_fwd_kernel( seq_idx_ptr, chunk_offsets_ptr, chunk_meta_num, + cu_seqlens_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, + num_seqs, # Strides stride_states_batch, stride_states_chunk, @@ -102,16 +105,21 @@ def _state_passing_fwd_kernel( prev_seq_idx_chunk_end = 0 logical_chunk_idx = 0 for c in range(nchunks): + + # now a chunk can only contain one sequence + seq_idx_chunk = find_seq_idx(cu_seqlens_ptr, c, num_seqs, chunk_size, True) + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale_mask = True if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right # boundary. # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. - seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( - (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + seq_idx_chunk_end = seq_idx_chunk + if HAS_INITSTATES: if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq @@ -130,9 +138,8 @@ def _state_passing_fwd_kernel( # - and subtract the cumsum just before that position from the total cumsum # - first, update the logical chunk index (add the number of sequences in the current physical chunk): # sequence index at the start of the current chunk - seq_idx_chunk_start = tl.load(seq_idx_ptr + - min(c * chunk_size, seqlen) * - stride_seq_idx_seqlen) + seq_idx_chunk_start = seq_idx_chunk + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start # - load the chunk offset: c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, @@ -168,6 +175,7 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, + cu_seqlens, initial_states=None, seq_idx=None, chunk_size=None, @@ -207,6 +215,9 @@ def _state_passing_fwd( device=states.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + + print("[_state_passing_fwd] seq_idx.shape: ", seq_idx.shape) + print("[_state_passing_fwd] chunk_offsets: ", chunk_offsets) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, @@ -217,10 +228,12 @@ def _state_passing_fwd( seq_idx, chunk_offsets, len(chunk_offsets) if chunk_offsets is not None else 0, + cu_seqlens, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size, + len(cu_seqlens)-1, states.stride(0), states.stride(1), states.stride(2), diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 359bad1ea9de..971f32848460 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -127,6 +127,10 @@ class Mamba2AttentionMetadata: chunk_indices_p: 
Optional[torch.Tensor] chunk_offsets_p: Optional[torch.Tensor] + # tpa + chunk_seqlen_start_p: Optional[torch.Tensor] + chunk_seqlen_end_p: Optional[torch.Tensor] + state_indices_tensor: torch.Tensor # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ed7c16dc520f..d66fa8967978 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - + print(scheduler_output) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. Wrap up all the KV cache load / save ops into an opaque object From 2a7b2166c223f61e819f8a8302a004ed2a961c65 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 11:34:10 -0400 Subject: [PATCH 005/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 + .../layers/mamba/ops/ssd_bmm.py | 25 ++-- .../layers/mamba/ops/ssd_chunk_scan.py | 123 +++++------------- .../layers/mamba/ops/ssd_chunk_state.py | 114 ++++------------ .../layers/mamba/ops/ssd_combined.py | 37 +++--- .../layers/mamba/ops/ssd_state_passing.py | 59 ++------- vllm/v1/attention/backends/mamba2_attn.py | 65 ++++++++- 7 files changed, 168 insertions(+), 259 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 04ebdbca85e5..222f89f2c35b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -487,6 +487,8 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p chunk_indices_p = attn_metadata.chunk_indices_p chunk_offsets_p = attn_metadata.chunk_offsets_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_p = attn_metadata.last_chunk_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -671,6 +673,8 @@ def forward_cuda( chunk_indices=chunk_indices_p, chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk=last_chunk_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 11ca1255ebfb..e7980af51a74 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -97,6 +97,7 @@ def _bmm_chunk_fwd_kernel( b_ptr, out_ptr, seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions seqlen, chunk_size, @@ -135,8 +136,12 @@ def _bmm_chunk_fwd_kernel( if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * chunk_size * stride_b_seqlen + pid_h * stride_b_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -147,7 +152,7 @@ def 
_bmm_chunk_fwd_kernel( offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = chunk_seqlen_end acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): @@ -165,15 +170,7 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - mask=offs_n < chunk_size_limit, - other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head @@ -188,6 +185,7 @@ def _bmm_chunk_fwd_kernel( def _bmm_chunk_fwd(a, b, chunk_size, + cu_chunk_seqlens, seq_idx=None, causal=False, output_dtype=None): @@ -214,7 +212,7 @@ def _bmm_chunk_fwd(a, a = a.contiguous() if b.stride(-1) != 1 and b.stride(1) != 1: b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) + nchunks = len(cu_chunk_seqlens)-1 # Allocates output. out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty( @@ -236,6 +234,7 @@ def _bmm_chunk_fwd(a, b, out, seq_idx, + cu_chunk_seqlens, seqlen, chunk_size, k, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index fb8350e191c9..9f23b4103a97 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -125,6 +125,7 @@ def _chunk_scan_fwd_kernel( chunk_indices_ptr, chunk_offsets_ptr, chunk_meta_num, + cu_chunk_seqlens_ptr, # Matrix dimensions chunk_size, hdim, @@ -190,12 +191,11 @@ def _chunk_scan_fwd_kernel( pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - if not HAS_INITSTATES: - c_idx = pid_c - c_off = 0 - else: - c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) - c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + # logical chunks = physical chunks + # always start from beginning + c_idx = pid_c + c_off = 0 pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) @@ -203,10 +203,14 @@ def _chunk_scan_fwd_kernel( pid_n = tl.program_id(axis=0) % num_pid_n cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + C_ptr += pid_b * stride_C_batch + chunk_seqlen_start * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states @@ -216,94 +220,30 @@ def _chunk_scan_fwd_kernel( 
prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate - chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + chunk_size_limit = chunk_seqlen_end + if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen - # - we only need seq_idx_prev to be aligned to chunk boundary + # current sequence index + seq_idx = tl.load(seq_idx_ptr) + + # previous sequence index seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0) - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr + - (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ((c_off == 0) and - (seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): - - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # - handle chunk state limit - if HAS_INITSTATES: - - # have to split this if otherwise compilation will have problems - dA_cs_m_boundary = 0.0 - - # get the c_idx for the next (logica) chunk - c_idx_n = tl.load( - chunk_indices_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1 # to trigger different chunk - ) - - # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct - # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next - # (logical) chunk. - # An equivalent check for B is c_idx == c_idx_n, where there is repetition in - # (logical) chunk indices. - - if (c_idx == c_idx_n) or c_off > 0: - - # get the next offset - c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size) - - # in this case, adjust down the chunk_size_limit - if c_idx == c_idx_n: - chunk_size_limit = min(c_off_n, chunk_size_limit) - - # get the cs at the offset boundary - # - c_off == 0 is a passthrough - # - We need dA_cs at the boundary, defined by c_off - no need - # to increase pointer by pid_m (it is a constant offset, - # i.e. 
the same for all blocks) - dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), - other=0.0).to(tl.float32) - - if HAS_SEQ_IDX: - # - handle seq idx when HAS_INITSTATES==False - if not HAS_INITSTATES: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get @@ -323,12 +263,12 @@ def _chunk_scan_fwd_kernel( if not HAS_INITSTATES: # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) else: # - if there is initstates, we will rely on prev_states, no zeroing # required. - scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + scale_m = tl.exp(dA_cs_m) else: scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: @@ -416,7 +356,7 @@ def _chunk_scan_fwd_kernel( acc += x_residual * D if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) tl.store(out_x_ptrs, @@ -424,7 +364,7 @@ def _chunk_scan_fwd_kernel( mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptr += pid_b * stride_z_batch + chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, @@ -433,7 +373,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, @@ -449,6 +389,7 @@ def _chunk_scan_fwd( dA_cumsum, C, states, + cu_chunk_seqlens, D=None, z=None, seq_idx=None, @@ -495,8 +436,7 @@ def _chunk_scan_fwd( grid = lambda META: ( triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), batch * nchunks - if chunk_offsets is None else len(chunk_offsets), nheads) + headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( @@ -515,6 +455,7 @@ def _chunk_scan_fwd( chunk_indices, chunk_offsets, len(chunk_indices) if chunk_indices is not None else 0, + cu_chunk_seqlens, chunk_size, headdim, dstate, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3379054c17b9..f68c8ca7d5e2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -14,22 +14,6 @@ from .mamba_ssm import softplus -@triton.jit -def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, - BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): - left: tl.int32 = 0 - right = num_seqs - while left < right: - mid = (left + 
right) // 2 - val = tl.load(query_start_len_ptr + mid) - mid_val = val // BLOCK_Q + mid if use_q_block_mode else val - - if mid_val <= target_idx: - left = mid + 1 - else: - right = mid - - return left - 1 @triton.autotune( configs=[ @@ -51,7 +35,7 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, - cu_seqlens_ptr, + cu_chunk_seqlens_ptr, # Matrix dimension batch, seqlen, @@ -59,7 +43,6 @@ def _chunk_cumsum_fwd_kernel( chunk_size, dt_min, dt_max, - num_seqs, # Strides stride_dt_batch, stride_dt_seqlen, @@ -87,22 +70,10 @@ def _chunk_cumsum_fwd_kernel( pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) - - chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx - - chunk_local_idx = pid_c - chunk_start_idx - - cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) - cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx - - # skip any unncessary work - if chunk_local_idx * chunk_size >= cur_batch_query_len: - return - - dt_ptr += pid_b * stride_dt_batch + (cur_batch_in_all_start_idx + chunk_local_idx * chunk_size) * stride_dt_seqlen + dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -115,11 +86,10 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_size_limit), + (offs_c[None, :] < chunk_seqlen_end), other=0.0).to(tl.float32) if HAS_DT_BIAS: dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, @@ -132,17 +102,17 @@ def _chunk_cumsum_fwd_kernel( # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_seqlen_end), dt, 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( @@ -230,8 +200,8 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - cu_seqlens_ptr, seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions hdim, dstate, @@ -239,7 +209,6 @@ def _chunk_state_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, - num_seqs, # Strides stride_x_batch, stride_x_seqlen, @@ -278,25 +247,15 @@ def _chunk_state_fwd_kernel( pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) - chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx - 
chunk_local_idx = pid_c - chunk_start_idx - cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) - cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx - # skip any unncessary work - if chunk_local_idx * chunk_size >= cur_batch_query_len: - return + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - seqlen_offset = cur_batch_in_all_start_idx + chunk_local_idx*chunk_size - b_ptr += pid_b * stride_b_batch + seqlen_offset * stride_b_seqlen + ( + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + seqlen_offset * stride_x_seqlen + pid_h * stride_x_head + x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + seqlen_offset * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -306,17 +265,13 @@ def _chunk_state_fwd_kernel( b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + - (chunk_size_limit - 1) * stride_seq_idx_seqlen) + chunk_size_limit = chunk_seqlen_end acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -331,17 +286,11 @@ def _chunk_state_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, - mask=offs_k < chunk_size_limit - k, - other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - else: - scale = tl.where(seq_idx_k == seq_idx_last, - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -349,8 +298,7 @@ def _chunk_state_fwd_kernel( b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head @@ -606,35 +554,27 @@ def _chunk_state_varlen_kernel( def _chunk_cumsum_fwd(dt, A, chunk_size, - cu_seqlens, + cu_chunk_seqlens, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): - batch, seqlen, nheads = dt.shape assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - num_seqs = len(cu_seqlens)-1 - 
nchunks = seqlen // chunk_size + num_seqs - print("dt.shape: ", dt.shape) - print("A.shape: ", A.shape) - print("nchunks: ", nchunks) - print("type(cu_seqlens): ", type(cu_seqlens)) + nchunks = cu_chunk_seqlens.shape[0]-1 dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - print("dt_out.shape: ", dt_out.shape) dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - print("dA_cumsum.shape: ", dA_cumsum.shape) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): @@ -644,14 +584,13 @@ def _chunk_cumsum_fwd(dt, dt_bias, dt_out, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], - num_seqs, dt.stride(0), dt.stride(1), dt.stride(2), @@ -676,7 +615,7 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx=None, states=None, states_in_fp32=True): @@ -696,7 +635,9 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + print("[_chunk_state_fwd] states.shape: ", states.shape) + grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) @@ -707,7 +648,7 @@ def _chunk_state_fwd(B, states, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx, headdim, dstate, @@ -715,7 +656,6 @@ def _chunk_state_fwd(B, batch, seqlen, nheads // ngroups, - len(cu_seqlens)-1, x.stride(0), x.stride(1), x.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index c11f1a4c4b24..c3f8db39e7ae 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -39,6 +39,8 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_indices=None, chunk_offsets=None, cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, @@ -85,27 +87,23 @@ def _mamba_chunk_scan_combined_fwd(x, # - see the blog and paper for a visualization of the submatrices # which we refer to in the comments below - num_seqs = len(cu_seqlens) - 1 # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, - cu_seqlens, + cu_chunk_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) - print("dA_cumsum.shape: ", dA_cumsum.shape) - print("dt.shape: ", dt.shape) - # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, x, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx=seq_idx, states_in_fp32=True) @@ -123,7 +121,7 @@ def _mamba_chunk_scan_combined_fwd(x, states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None, seq_idx=seq_idx, @@ -138,6 +136,7 @@ def _mamba_chunk_scan_combined_fwd(x, CB = _bmm_chunk_fwd(C, B, chunk_size, + cu_chunk_seqlens, seq_idx=seq_idx, output_dtype=torch.float32) @@ -158,6 +157,7 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum, C, states, + cu_chunk_seqlens, D=D, z=z, seq_idx=seq_idx, @@ -170,16 +170,11 @@ def _mamba_chunk_scan_combined_fwd(x, return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen( - B.squeeze(0), - x.squeeze(0), - dt.squeeze(0), - dA_cumsum.squeeze(0), - cu_seqlens, - states.squeeze(0), - initial_states=initial_states, - ) - return out_x, dt, dA_cumsum, states, final_states, varlen_states + print("last_chunk: ", last_chunk) + print(states.shape) + varlen_states = states[last_chunk] + print(varlen_states.shape) + return out_x, dt, dA_cumsum, states, final_states, states def mamba_chunk_scan_combined(x, @@ -196,6 +191,8 @@ def mamba_chunk_scan_combined(x, chunk_indices=None, chunk_offsets=None, cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk=None, dt_softplus=False, dt_limit=(0.0, float("inf")), out=None, @@ -216,12 +213,12 @@ def mamba_chunk_scan_combined(x, initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + cu_chunk_seqlens: (num_chunks + 1) dt_softplus: Whether to apply softplus to dt out: Preallocated output tensor state_dtype: The data type of the ssm state """ - print("-------------------------") - print("cu_seqlens: ", cu_seqlens) + if not return_varlen_states: cu_seqlens = None else: @@ -241,6 +238,8 @@ def mamba_chunk_scan_combined(x, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk=last_chunk, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 56c61ad3cdd8..4d2bea947be5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -10,7 +10,6 @@ from vllm.triton_utils import tl, triton -from .ssd_chunk_state import find_seq_idx @triton.autotune( configs=[ @@ -34,13 +33,12 @@ def _state_passing_fwd_kernel( seq_idx_ptr, chunk_offsets_ptr, chunk_meta_num, - cu_seqlens_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, - num_seqs, # Strides stride_states_batch, stride_states_chunk, @@ -102,12 +100,12 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - prev_seq_idx_chunk_end = 0 - logical_chunk_idx = 0 + + prev_seq_idx = 0 for c in range(nchunks): - # now a chunk can only contain one sequence - seq_idx_chunk = find_seq_idx(cu_seqlens_ptr, c, num_seqs, chunk_size, True) + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -115,51 +113,24 @@ def _state_passing_fwd_kernel( scale_mask = True if HAS_SEQ_IDX: - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. 
- seq_idx_chunk_end = seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: + if IS_CONT_BATCHED and prev_seq_idx != seq_idx: # this means in the current chunk the rightmost flushed seq # has changed. # - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - # - we need to consider the cumsum only of the last sequence in the chunk - # - find its starting position (given by c_off of the logical chunk index) - # - and subtract the cumsum just before that position from the total cumsum - # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - # sequence index at the start of the current chunk - seq_idx_chunk_start = seq_idx_chunk - - logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start - # - load the chunk offset: - c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx < chunk_meta_num, - other=0) - # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything - if c_off > 0: - # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset - dA_cs_boundary = tl.load( - dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + - (c_off - 1) * stride_dA_cs_csize, - mask=(c_off - 1) > -1 and c_off < chunk_size, - other=0.0) - dA_cs -= dA_cs_boundary - - # - increment logical chunk index for every physical chunk - logical_chunk_idx += 1 else: - scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end - prev_seq_idx_chunk_end = seq_idx_chunk_end + scale_mask = seq_idx == prev_seq_idx + + prev_seq_idx = seq_idx scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states @@ -175,7 +146,7 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, initial_states=None, seq_idx=None, chunk_size=None, @@ -215,9 +186,6 @@ def _state_passing_fwd( device=states.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) - - print("[_state_passing_fwd] seq_idx.shape: ", seq_idx.shape) - print("[_state_passing_fwd] chunk_offsets: ", chunk_offsets) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, @@ -228,12 +196,11 @@ def _state_passing_fwd( seq_idx, chunk_offsets, len(chunk_offsets) if chunk_offsets is not None else 0, - cu_seqlens, + cu_chunk_seqlens, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size, - len(cu_seqlens)-1, states.stride(0), states.stride(1), states.stride(2), diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 971f32848460..91fb63dc486c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -14,7 +14,7 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec - +from vllm.utils import cdiv def _query_start_loc_to_chunk_indices_offsets( query_start_loc: torch.Tensor, chunk_size: int, @@ -128,8 +128,8 
@@ class Mamba2AttentionMetadata: chunk_offsets_p: Optional[torch.Tensor] # tpa - chunk_seqlen_start_p: Optional[torch.Tensor] - chunk_seqlen_end_p: Optional[torch.Tensor] + cu_chunk_seqlen_p: Optional[torch.Tensor] + last_chunk_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] @@ -165,6 +165,10 @@ def build(self, has_initial_states_p = None prep_initial_states = False + + cu_chunk_seqlen_p = None + last_chunk_p = None + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( @@ -172,6 +176,11 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) + print("num_decodes: ", num_decodes) + print("num_prefills: ", num_prefills) + print("num_decode_tokens: ", num_decode_tokens) + print("num_prefill_tokens: ", num_prefill_tokens) + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -182,9 +191,11 @@ def build(self, has_initial_states_p = has_initial_states_cpu.to( query_start_loc.device) + query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens + seq_idx_p = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, @@ -193,6 +204,52 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) + + num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] + query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ + -num_prefills - 1:] - num_decode_tokens + + print("num_computed_tokens_p: ", num_computed_tokens_p) + print("query_start_loc_p: ", query_start_loc_p) + + cu_chunk_seqlen = [] + last_chunk = [] + seqlen_pos = 0 + for req_idx in range(num_prefills): + this_num_computed = num_computed_tokens_p[req_idx].item() + this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() + print(req_idx, this_num_computed, this_new_tokens) + + # if computed tokens are not chunk-aligned, use the first + # chunk to finish it off + # TODO(tdoublep): I guess we need block size actually? + if this_num_computed % self.chunk_size != 0: + cu_chunk_seqlen.append(seqlen_pos) + # how many tokens to finish the chunk? + chunk_len = cdiv(this_num_computed, self.chunk_size)*self.chunk_size - this_num_computed + # we can only use at most this_new_tokens + chunk_len = min(chunk_len, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + n_chunks = cdiv(this_new_tokens, self.chunk_size) + for chunk in range(n_chunks): + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = min(self.chunk_size, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + assert this_new_tokens == 0 + last_chunk.append(len(cu_chunk_seqlen)-1) + + cu_chunk_seqlen.append(seqlen_pos) + + cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) + last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) + + print("cu_chunk_seqlen: ", cu_chunk_seqlen) + print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) + # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, # they will be ignored inside mamba kernels. 
@@ -224,5 +281,7 @@ def build(self, chunk_indices_p=chunk_indices_p, chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, + cu_chunk_seqlen_p=cu_chunk_seqlen_p, + last_chunk_p=last_chunk_p, ) return attn_metadata From a30bb5e3e8d40df7d63072e86bb09f875a3860a5 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 11 Sep 2025 17:03:53 +0000 Subject: [PATCH 006/105] Initial varlen for APC Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 183 ++++++++++++------ .../layers/mamba/ops/ssd_combined.py | 45 ++++- 2 files changed, 169 insertions(+), 59 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c3a4c010c5f2..e49c417c9ce7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -38,7 +38,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, cdiv +from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -574,10 +574,12 @@ def forward_cuda( torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) - last_computed_token_block_idx = seq_lens_completed // mamba_block_size - 1 # e.g. 16 blocks computed; 0th based indexing -> state[15] + # e.g. 16 blocks computed; 0th based indexing -> state[15] + last_computed_token_block_idx = \ + seq_lens_completed // mamba_block_size - 1 + # -1 in case it's non-computed and causes later issues with indexing last_computed_token_block_idx = last_computed_token_block_idx.clamp( - min=0 - ) #in case it's non-computed it's -1 and causes later issues with indexing + min=0) current_first_token_block_idx = cdiv(seq_lens_completed + 1, mamba_block_size) - 1 current_last_token_block_idx = cdiv( @@ -647,24 +649,25 @@ def forward_cuda( # pointed to by "state_indices_tensor" x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: + if mamba2_metadata.cu_seqlen is None: #TODO: move to MDBuilder? 
mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) kernel_conv1d_indices = state_indices_tensor_p if cache_enabled: - # Kernel expects to have the initial state here and overwrites it -> use final state location - if has_initial_states_p is not None and has_initial_states_p.sum( - ) > 0: - conv_state_idx_input = state_indices_tensor_p.index_select( - 1, last_computed_idx_p).diag().unsqueeze(1) - conv_state_idx_output = state_indices_tensor_p.index_select( - 1, current_last_idx_p).diag().unsqueeze(1) + # Kernel expects to have the initial state here + # and overwrites it -> use final state location + if has_initial_states_p is not None \ + and has_initial_states_p.sum() > 0: + conv_state_idx_input = state_indices_tensor_p.gather( + 1, last_computed_idx_p.unsqueeze(1)) + conv_state_idx_output = state_indices_tensor_p.gather( + 1, current_last_idx_p.unsqueeze(1)) conv_state[conv_state_idx_output[ has_initial_states_p]] = conv_state[ conv_state_idx_input[has_initial_states_p]] - kernel_conv1d_indices = state_indices_tensor_p.index_select( - 1, current_last_idx_p).diag() + kernel_conv1d_indices = state_indices_tensor_p.gather( + 1, current_last_idx_p.unsqueeze(1)).squeeze(1) hidden_states_B_C_p = causal_conv1d_fn( x, @@ -680,64 +683,89 @@ def forward_cuda( if cache_enabled: - def copy_x_to_conv_state(conv_state_block_idx, x_offset, - x_end): + def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end, + query_start_loc): conv_state[conv_state_block_idx, :, 0] = torch.transpose( - x[:, x_offset - 3:x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - 3:query_start_loc + + x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 1] = torch.transpose( - x[:, x_offset - 2:x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - 2:query_start_loc + + x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 2] = torch.transpose( - x[:, x_offset - 1:x_end:mamba_block_size], 1, 0) - - # initial state: state_indices_tensor_p[, last_computed_idx_p[]] - # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + x[:, query_start_loc + x_offset - 1:query_start_loc + + x_end:mamba_block_size], 1, 0) + + # initial state: + # state_indices_tensor_p[, last_computed_idx_p[]] + # new states: + # state_indices_tensor_p[, current_first_idx_p[]: + # current_last_idx_p[]] if cache_strategy == "all": # Iterate over all sequences to need prefill for seq_idx in range(state_indices_tensor_p.shape[0]): number_full_blocks = seq_lens_pending[ seq_idx] // mamba_block_size - second_last_block_idx = number_full_blocks if seq_lens_pending[ - seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 - if number_full_blocks > 0 and seq_lens_pending[ - seq_idx] % mamba_block_size > 0: + if seq_lens_pending[seq_idx] % mamba_block_size > 0: + second_last_block_idx = number_full_blocks + else: + second_last_block_idx = number_full_blocks - 1 + #TODO: simpler logic via?: + # if (current_last_idx_p - current_first_idx_p) + # [seq_idx] > 0: + if number_full_blocks > 0: # and seq_lens_pending[ + #seq_idx] % mamba_block_size > 0: # unnecessary? copy_x_to_conv_state( state_indices_tensor_p[ seq_idx, current_first_idx_p[seq_idx]: current_first_idx_p[seq_idx] + second_last_block_idx], mamba_block_size, - mamba_block_size * second_last_block_idx) + mamba_block_size * second_last_block_idx, + query_start_loc_p[seq_idx]) elif cache_strategy == "last": # i.e. 
keep two states: either # a) states at the last two block boundaries or - # b) state at the last block boundary and last state of the sequence, - # which might not be at a block boundary + # b) state at the last block boundary and last state of + # the sequence, which might not be at a block boundary # Iterate over all sequences to need prefill for seq_idx in range(state_indices_tensor_p.shape[0]): - # Only store the additional second state if there are is at least one full block and a remainder. + # Only store the additional second state if there are + # is at least one full block and a remainder. # Otherwise, there is only one state to store if number_full_blocks > 0 and seq_lens_pending[ seq_idx] % mamba_block_size > 0: - second_last_block_idx = number_full_blocks if seq_lens_pending[ - seq_idx] % mamba_block_size > 0 else number_full_blocks - 1 + if seq_lens_pending[seq_idx] % mamba_block_size > 0: + second_last_block_idx = number_full_blocks + else: + second_last_block_idx = number_full_blocks - 1 copy_x_to_conv_state( state_indices_tensor_p[ seq_idx, current_last_idx_p[seq_idx] - 1:current_last_idx_p[seq_idx]], mamba_block_size * second_last_block_idx, - mamba_block_size * second_last_block_idx) + mamba_block_size * second_last_block_idx, + query_start_loc_p[seq_idx]) hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) # 3. State Space Model sequence transformation initial_states = None + seq_pad = None if (has_initial_states_p is not None and prep_initial_states): # making a copy of the states if envs.VLLM_USE_V1: kernel_ssm_indices = state_indices_tensor_p if cache_enabled: - kernel_ssm_indices = state_indices_tensor_p. \ - index_select(1, last_computed_idx_p).diag() + #TODO: Move to attn metadata builder + kernel_ssm_indices = state_indices_tensor_p.gather( + 1, last_computed_idx_p.unsqueeze(1)).squeeze(1) + if num_prefills > 1: + # Padding for mamba_chunk_scan_combined + seq_lens_pad = cdiv(seq_lens_pending[num_decodes:], chunk_size) * chunk_size # [6144, 1024, 1024, 256] + seq_offsets_pad = seq_lens_pad.cumsum(0)[:-1] # [6144, 7168, 8192] + seq_pad = seq_lens_pad - seq_lens_pending[num_decodes:] # [ 41, 38, 41, 136] + else: + seq_pad = None initial_states = torch.where( has_initial_states_p[:, None, None, None], ssm_state[kernel_ssm_indices], 0) @@ -773,9 +801,10 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - state_dtype=ssm_state.dtype) + state_dtype=ssm_state.dtype, + seq_pad=seq_pad) - if cache_enabled: + if cache_enabled and num_prefills == 1: states, varlen_state = mamba_outputs # update ssm states @@ -784,15 +813,19 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: all_states = torch.concat( [states[:, 1:], varlen_state.unsqueeze(1)], - 1) # first from returned states is zero + 1) # for num_prefills=1 first returned state is zero state_stride = mamba_block_size // chunk_size # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to - # states at blocks 0,0,1,1,2 (block_size=512). For first blocks, stride(=2). For last block can't strid + # states at blocks 0,0,1,1,2 (block_size=512). + # For first blocks, stride(=2). For last block can't stride. 
- # initial state: state_indices_tensor_p[, last_computed_idx_p[]] - # new states: state_indices_tensor_p[, current_first_idx_p[]:current_last_idx_p[]] + # initial state: + # state_indices_tensor_p[, last_computed_idx_p[]] + # new states: + # state_indices_tensor_p[, current_first_idx_p[]: + # current_last_idx_p[]] - # Note: Currently works for 1 request only: + # Code assuming 1 prefill request: states_at_blocks = torch.concat([ all_states[:, state_stride - 1:(current_last_idx_p[0] - current_first_idx_p[0]) * @@ -807,36 +840,70 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, ssm_state[ state_indices_tensor_p[:, current_last_idx_p[0] - 1:]] = states_at_blocks[:, -2:] + elif cache_enabled and num_prefills > 1: + if self.prefix == 'model.layers.0.mixer' and attn_metadata.num_prefills == 4: + pass + states, varlen_state = mamba_outputs + last_states_indices = cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)-1 + all_states = states + #layout: [full states 1, partial state 1, full states 2, partial state 2, ... ] + # update all partial states with correct varlen_states + all_states[0, last_states_indices] = varlen_state + state_stride = mamba_block_size // chunk_size + + states_indices = torch.cat([torch.zeros(1, dtype=last_states_indices.dtype, device=last_states_indices.device), last_states_indices + 1]) + # seq_till_chunk [0, 24, 28, 32, 33] -> e.g. 32:33 is the last one + # seq_till_chunk = torch.concat([torch.tensor([0]), cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)]) + for seq_idx in range(state_indices_tensor_p.shape[0]): + pass + all_seq_states = all_states[:,states_indices[seq_idx]:states_indices[seq_idx+1]] + states_at_blocks = torch.concat([ + all_seq_states[:, state_stride - 1:(current_last_idx_p[seq_idx] - + current_first_idx_p[seq_idx]) * + state_stride:state_stride], + varlen_state[seq_idx].unsqueeze(0).unsqueeze(0) + ], 1) + if cache_strategy == "all": + ssm_state[state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: + current_last_idx_p[seq_idx] + + 1]] = states_at_blocks + elif cache_strategy == "last": + ssm_state[ + state_indices_tensor_p[:, current_last_idx_p[seq_idx] - + 1:]] = states_at_blocks[:, -2:] else: varlen_state = mamba_outputs # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + # - varlen state is (num_prefills, nheads, headdim, dstate) ssm_state[state_indices_tensor_p] = varlen_state # Process decode requests if has_decode: if cache_enabled: - # if at_block_boundary, load states from previous blocks: - at_block_boundary = mamba2_metadata.seq_lens % mamba_block_size == 0 - finished_blocks = attn_metadata.seq_lens[ - 0] // mamba_block_size #e.g. 1024 -> 2 blocks ; 1025 -> 2 blocks + # # if at_block_boundary, load states from previous blocks: + # at_block_boundary = mamba2_metadata.seq_lens \ + # % mamba_block_size == 0 + # finished_blocks = attn_metadata.seq_lens[ + # 0] // mamba_block_size #e.g. 1024:2 blocks; 1025:2 blocks input_block = cdiv( - attn_metadata.seq_lens[0], mamba_block_size + attn_metadata.seq_lens[:num_decodes], mamba_block_size ) #e.g. 1024 -> 2nd block, 1025 -> 3rd block output_block = cdiv( - attn_metadata.seq_lens[0] + 1, mamba_block_size + attn_metadata.seq_lens[:num_decodes] + 1, mamba_block_size ) #e.g. 
1023 -> 2nd block, 1024 -> 3rd block - state_indices_tensor_d_input = state_indices_tensor_d[:, - input_block - - 1] - state_indices_tensor_d_output = state_indices_tensor_d[:, - output_block - - 1] - - # copy initial state to new location, as update kernel works in place - if output_block > input_block: + + state_indices_tensor_d_input = \ + state_indices_tensor_d.gather(1, + (input_block-1).unsqueeze(1)).squeeze(1) + state_indices_tensor_d_output = \ + state_indices_tensor_d.gather(1, + (output_block-1).unsqueeze(1)).squeeze(1) + + # copy initial state to new location, + # as update kernel works in place + if (output_block > input_block).any(): conv_state[state_indices_tensor_d_output] = conv_state[ state_indices_tensor_d_input] else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a9666a67f0ee..7c5e166f3231 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -195,7 +195,8 @@ def mamba_chunk_scan_combined(x, out=None, return_final_states=False, return_varlen_states=False, - state_dtype=None): + state_dtype=None, + seq_pad=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -219,6 +220,39 @@ def mamba_chunk_scan_combined(x, cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + # Padding logic used by prefix caching: + if seq_pad is not None and seq_pad.sum() > 0: + pass #Padding needed + seq_pad[-1] = 0 #never pad last + def pad(v): + v2 = [] + for idx in torch.arange(len(cu_seqlens)-1): + v2.append(v[:, cu_seqlens[idx]:cu_seqlens[idx+1]]) + v2.append(v[:, cu_seqlens[idx]:cu_seqlens[idx]+1].repeat_interleave(seq_pad[idx], 1)) #last value pad + #v2.append(torch.zeros_like(x[:,0:1]).repeat_interleave(seq_pad[idx], 1)) #zero pad + return torch.cat(v2[:-1], 1) #don't need to pad the last + + x = pad(x) + dt = pad(dt) + B = pad(B) + C = pad(C) + seq_idx = pad(seq_idx) + + # adjust the cu_seqlens + org_cu_seqlens = cu_seqlens + cu_seqlens = cu_seqlens.clone() + cu_seqlens[1:] = cu_seqlens[1:] + seq_pad.cumsum(0) + + # no shared chunks -> just feed an incremental list, no offsets + chunk_indices = torch.arange(-(- x.shape[1] // chunk_size), device=x.device) + chunk_offsets = torch.zeros_like(chunk_indices, device=x.device) + + # Return by value - we need to store the old tensor + org_out = out + # allocate a new larger tensor + out = torch.zeros_like(x) + # and after the kernel write back the results + out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( x, dt, @@ -238,6 +272,15 @@ def mamba_chunk_scan_combined(x, dt_limit=dt_limit, out=out, state_dtype=state_dtype) + + # Padding logic used by prefix caching: + if return_varlen_states and seq_pad is not None and seq_pad.sum() > 0: + #unpad the output and write to the originally passed tensor: + offset = 0 + for idx in torch.arange(len(org_cu_seqlens)-1): + org_out[0, org_cu_seqlens[idx]:org_cu_seqlens[idx+1]] = \ + out[0, offset+org_cu_seqlens[idx]:offset+org_cu_seqlens[idx+1]] + offset += seq_pad[idx] if not return_varlen_states: if not return_final_states: From 664a21a31c56c457bcba2fc785dc4bfa0bc2f9ab Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:06:48 -0400 Subject: [PATCH 007/105] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index e7980af51a74..4debe5375f78 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -141,7 +141,7 @@ def _bmm_chunk_fwd_kernel( chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * chunk_size * stride_b_seqlen + pid_h * stride_b_head + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen From 6c475d62e6f4e93b7f5bcbc091301ba362bce697 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:14:08 -0400 Subject: [PATCH 008/105] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index c3f8db39e7ae..174ba9c0ae26 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -174,7 +174,7 @@ def _mamba_chunk_scan_combined_fwd(x, print(states.shape) varlen_states = states[last_chunk] print(varlen_states.shape) - return out_x, dt, dA_cumsum, states, final_states, states + return out_x, dt, dA_cumsum, states, final_states, varlen_states def mamba_chunk_scan_combined(x, From 0d00c69acd045cd3f5800927841c837d54eaf13d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:19:11 -0400 Subject: [PATCH 009/105] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 174ba9c0ae26..14f4903ac7df 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -172,7 +172,7 @@ def _mamba_chunk_scan_combined_fwd(x, assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) print(states.shape) - varlen_states = states[last_chunk] + varlen_states = states[:, last_chunk, ...] 
print(varlen_states.shape) return out_x, dt, dA_cumsum, states, final_states, varlen_states From b7ae698358db8290f8e04471544bd48579e82d60 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 17:34:37 -0400 Subject: [PATCH 010/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 3 + .../layers/mamba/ops/ssd_bmm.py | 4 +- .../layers/mamba/ops/ssd_chunk_scan.py | 103 +++++++++++------- .../layers/mamba/ops/ssd_chunk_state.py | 32 +++--- .../layers/mamba/ops/ssd_combined.py | 19 +++- .../layers/mamba/ops/ssd_state_passing.py | 70 ++++-------- 6 files changed, 118 insertions(+), 113 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 222f89f2c35b..34c5a9280752 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -684,6 +684,9 @@ def forward_cuda( self.head_dim), state_dtype=ssm_state.dtype) + print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) + print("varlen_state: ", varlen_state[0,0,0,:10]) + # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor ssm_state[state_indices_tensor_p] = varlen_state diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 4debe5375f78..3a245b127f01 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -142,8 +142,6 @@ def _bmm_chunk_fwd_kernel( a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -152,7 +150,7 @@ def _bmm_chunk_fwd_kernel( offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = chunk_seqlen_end + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 9f23b4103a97..65777d4e0789 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,7 +13,7 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') - +''' @triton.autotune( configs=[ triton.Config( @@ -107,6 +107,7 @@ ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) +''' @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices @@ -216,28 +217,19 @@ def _chunk_scan_fwd_kernel( # M-block offsets and prev states # - logic in next block may override these if there is an active offset offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - prev_states_hdim = stride_states_hdim - prev_states_dstate = stride_states_dstate - - chunk_size_limit = chunk_seqlen_end + #prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + #prev_states_hdim = stride_states_hdim + 
#prev_states_dstate = stride_states_dstate - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - # current sequence index - seq_idx = tl.load(seq_idx_ptr) - # previous sequence index - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, - other=0) + other=-1) - if HAS_INITSTATES and (seq_idx != seq_idx_prev): - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, @@ -256,31 +248,39 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * prev_states_hdim + - offs_k_dstate[:, None] * prev_states_dstate) - if HAS_SEQ_IDX: - - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), - 0.0) - else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. - scale_m = tl.exp(dA_cs_m) - else: - scale_m = tl.exp(dA_cs_m) + scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - prev_states = tl.load(prev_states_ptrs, + + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + else: + # Load from previous chunk + states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: @@ -290,11 +290,31 @@ def _chunk_scan_fwd_kernel( (offs_k_dstate[None, :] < dstate - k), other=0.0) # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + else: + # Set to zero 
+ prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) + else: + # Load from previous chunk + states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K @@ -412,6 +432,8 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) + print("out.shape: ", out.shape) + if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) @@ -511,5 +533,8 @@ def _chunk_scan_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=32, ) return out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index f68c8ca7d5e2..49c4678b4a87 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,18 +15,6 @@ from .mamba_ssm import softplus -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), - triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), - ], - key=['chunk_size', 'nheads'], -) @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -73,6 +61,7 @@ def _chunk_cumsum_fwd_kernel( chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -86,10 +75,11 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_seqlen_end), + (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) if HAS_DT_BIAS: dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, @@ -102,17 +92,17 @@ def _chunk_cumsum_fwd_kernel( # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_seqlen_end), dt, + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) @triton.autotune( @@ -200,8 +190,8 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - seq_idx_ptr, cu_chunk_seqlens_ptr, + 
seq_idx_ptr, # Matrix dimensions hdim, dstate, @@ -271,7 +261,7 @@ def _chunk_state_fwd_kernel( dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - chunk_size_limit = chunk_seqlen_end + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -577,6 +567,10 @@ def _chunk_cumsum_fwd(dt, dtype=torch.float32) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + + print("dt_out.shape: ", dt_out.shape) + print("dA_cumsum.shape: ", dA_cumsum.shape) + with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, @@ -606,6 +600,7 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out @@ -637,6 +632,7 @@ def _chunk_state_fwd(B, dtype=states_dtype) print("[_chunk_state_fwd] states.shape: ", states.shape) + print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 14f4903ac7df..2c158f04b8cf 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -107,6 +107,10 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) + print("after chunk_state_fwd: ") + print("states.shape: ", states.shape) + print("states: ", states[0,0,0,0,:10]) + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states @@ -118,7 +122,7 @@ def _mamba_chunk_scan_combined_fwd(x, # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. - states, final_states = _state_passing_fwd( + states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, cu_chunk_seqlens, @@ -129,8 +133,13 @@ def _mamba_chunk_scan_combined_fwd(x, out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) - for t in [states, final_states]) + + print("after state passing: ") + print("states.shape: ", states.shape) + + print("states: ", states[0 ,0, 0,:10]) + + states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -172,8 +181,10 @@ def _mamba_chunk_scan_combined_fwd(x, assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) print(states.shape) - varlen_states = states[:, last_chunk, ...] + varlen_states = states[:, last_chunk, ...].clone() print(varlen_states.shape) + print("varlen_states: ", varlen_states[0,0,0,:10]) + final_states = states[:, -1, ...] 
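        # Sketch of the assumed chunk bookkeeping: `cu_chunk_seqlens` (from the
        # Mamba2 attention metadata builder) marks chunk boundaries across the
        # concatenated prefill tokens and chunks never span two sequences, so
        # `last_chunk[i]` is the index of the final chunk of sequence i. For
        # example (hypothetical sizes), two prefills of 300 and 180 tokens with
        # chunk_size = 256 would give:
        #   cu_chunk_seqlens = [0, 256, 300, 480]
        #   seq 0 -> chunks 0..1, seq 1 -> chunk 2, last_chunk = [1, 2]
        # `final_states` above is then the state after the very last chunk,
        # i.e. the final state of the last sequence in the batch.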
return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 4d2bea947be5..b084ed317ee6 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -27,7 +27,6 @@ def _state_passing_fwd_kernel( # Pointers to matrices states_ptr, out_ptr, - final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, @@ -48,9 +47,6 @@ def _state_passing_fwd_kernel( stride_out_chunk, stride_out_head, stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, @@ -73,11 +69,11 @@ def _state_passing_fwd_kernel( dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: initstates_ptr += pid_h * stride_initstates_head if not IS_CONT_BATCHED: initstates_ptr += pid_b * stride_initstates_batch + initstates_ptr += offs_m * stride_initstates_dim if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch @@ -85,59 +81,40 @@ def _state_passing_fwd_kernel( offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim - # - states will be the past state of the sequence that continues on the current check - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - else: - initstates_ptr += offs_m * stride_initstates_dim - initstates_ptrs = initstates_ptr - # - for cont batches, for the first chunk mean it will be the first batch's - # init state - states = tl.load(initstates_ptrs, mask=offs_m < dim, + if HAS_INITSTATES: + initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch + states = tl.load(initstates_ptrs, + mask=offs_m < dim, other=0.0).to(tl.float32) - - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk + else: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) prev_seq_idx = 0 for c in range(nchunks): - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale_mask = True - if HAS_SEQ_IDX: - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + # we are started a new sequence + if prev_seq_idx != seq_idx: if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx != seq_idx: - # this means in the current chunk the rightmost flushed seq - # has changed. 
- # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch - - # - update state with seq_idx_new's init state - states = tl.load(initstates_ptrs, - mask=offs_m < dim, - other=0.0).to(tl.float32) + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) else: - scale_mask = seq_idx == prev_seq_idx + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - prev_seq_idx = seq_idx + prev_seq_idx = seq_idx - scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) - states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) + states = tl.exp(dA_cs) * states + new_states + tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk dA_cs_ptr += stride_dA_cs_chunk out_ptrs += stride_out_chunk @@ -155,6 +132,7 @@ def _state_passing_fwd( chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape + assert batch == 1 if chunk_size is None: chunk_size = dA_cumsum.shape[-1] else: @@ -182,15 +160,12 @@ def _state_passing_fwd( out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), - device=states.device, - dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, out, - final_states, dA_cumsum, initial_states, seq_idx, @@ -209,9 +184,6 @@ def _state_passing_fwd( out.stride(1), out.stride(2), out.stride(3), - final_states.stride(0), - final_states.stride(1), - final_states.stride(2), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), @@ -225,4 +197,4 @@ def _state_passing_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_CONT_BATCHED=is_cont_batched, ) - return out, final_states + return out From 9b24bce7b1a328aa4e056939f97a5a6efe9942b9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 17:56:57 -0400 Subject: [PATCH 011/105] Fix bugs Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 6 +++--- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 49c4678b4a87..edfa4b12bd1a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -256,13 +256,13 @@ def _chunk_state_fwd_kernel( offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + (chunk_size_limit - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): x = tl.load(x_ptrs, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 2c158f04b8cf..5c62e01f6dd7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py 
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -97,6 +97,9 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) + print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) + print("dt: ", dt[0,0,0,:10]) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, @@ -107,8 +110,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - print("after chunk_state_fwd: ") - print("states.shape: ", states.shape) print("states: ", states[0,0,0,0,:10]) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries From e850661da32c7c8c86f95a4f929e1d5ae3c89ee6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 00:32:10 -0400 Subject: [PATCH 012/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 27 +++++++++++-------- .../layers/mamba/ops/ssd_combined.py | 6 +---- .../layers/mamba/ops/ssd_state_passing.py | 14 +++++----- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 65777d4e0789..0aa7f2b28159 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -233,7 +233,7 @@ def _chunk_scan_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size, + mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -248,7 +248,8 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - scale_m = tl.exp(dA_cs_m) + scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -284,6 +285,7 @@ def _chunk_scan_fwd_kernel( prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: + offset_tpa = 0 for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -296,9 +298,10 @@ def _chunk_scan_fwd_kernel( init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + pid_h * stride_init_states_head \ + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate + + offs_k_dstate[:, None] * stride_init_states_dstate \ + + offset_tpa prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) else: @@ -309,16 +312,18 @@ def _chunk_scan_fwd_kernel( states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + pid_h * stride_states_head \ + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate + + offs_k_dstate[:, None] * stride_states_dstate \ + + offset_tpa prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K + offset_tpa += BLOCK_SIZE_K + acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off @@ -332,16 +337,16 
@@ def _chunk_scan_fwd_kernel( (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size) & - (offs_k[None, :] < chunk_size - k), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size - k, + mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 5c62e01f6dd7..84ac31542506 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -136,8 +136,6 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_offsets=chunk_offsets) print("after state passing: ") - print("states.shape: ", states.shape) - print("states: ", states[0 ,0, 0,:10]) states = rearrange(states, "... (p n) -> ... p n", n=dstate) @@ -181,9 +179,7 @@ def _mamba_chunk_scan_combined_fwd(x, else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) - print(states.shape) - varlen_states = states[:, last_chunk, ...].clone() - print(varlen_states.shape) + varlen_states = states[:, last_chunk, ...].clone().squeeze(0) print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] 
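        # Shape note for the lines above, under the assumptions used elsewhere
        # in this path: `_state_passing_fwd` now returns only the per-chunk
        # states (the separate `final_states` output was dropped), so the final
        # state is recovered as `states[:, -1, ...]`, and `.squeeze(0)` removes
        # the batch dimension that is asserted to be 1, e.g.:
        #   states         : (1, nchunks, nheads, headdim, dstate)
        #   varlen_states  : (num_prefills, nheads, headdim, dstate)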
return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index b084ed317ee6..e7d00a8fdd89 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -69,11 +69,6 @@ def _state_passing_fwd_kernel( dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - if HAS_INITSTATES: - initstates_ptr += pid_h * stride_initstates_head - if not IS_CONT_BATCHED: - initstates_ptr += pid_b * stride_initstates_batch - initstates_ptr += offs_m * stride_initstates_dim if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch @@ -83,7 +78,10 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch + initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -104,7 +102,9 @@ def _state_passing_fwd_kernel( # we are started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) From 0d5c3ae9559716882f44f4c8aa7ad7739df5f1cf Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 01:02:40 -0400 Subject: [PATCH 013/105] revert some changes Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index edfa4b12bd1a..47077872356e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -96,13 +96,13 @@ def _chunk_cumsum_fwd_kernel( 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( From 31e05fae6baef01fda20df7577d581e48aeade9c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 01:03:04 -0400 Subject: [PATCH 014/105] fmt Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 47077872356e..12196806e272 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,7 +6,6 @@ # ruff: noqa: E501 -import math import torch @@ -61,7 +60,6 @@ def _chunk_cumsum_fwd_kernel( chunk_seqlen_start = 
tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -237,7 +235,6 @@ def _chunk_state_fwd_kernel( pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) @@ -259,7 +256,8 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size_limit - 1) * stride_dA_cs_csize).to(tl.float32) + (chunk_size_limit - 1) * stride_dA_cs_csize).to( + tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize @@ -552,7 +550,7 @@ def _chunk_cumsum_fwd(dt, assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - nchunks = cu_chunk_seqlens.shape[0]-1 + nchunks = cu_chunk_seqlens.shape[0] - 1 dt_out = torch.empty(batch, nheads, nchunks, From a8aff97cc0180fd9547b24f9d1036dc0e74b40ce Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:07:00 -0400 Subject: [PATCH 015/105] workign Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 15 +++++---------- .../layers/mamba/ops/ssd_chunk_state.py | 15 +++++++++++++-- .../layers/mamba/ops/ssd_combined.py | 2 ++ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0aa7f2b28159..0637a7cda4dd 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,7 +13,6 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') -''' @triton.autotune( configs=[ triton.Config( @@ -107,7 +106,6 @@ ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) -''' @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices @@ -233,7 +231,7 @@ def _chunk_scan_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size_limit, + mask=offs_m < chunk_size, other=0.0).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -337,16 +335,16 @@ def _chunk_scan_fwd_kernel( (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k[None, :] < chunk_size_limit - k), + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, + mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: @@ -538,8 +536,5 @@ def _chunk_scan_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=64, - BLOCK_SIZE_K=32, ) return out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 12196806e272..7715a5107467 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,7 +13,18 @@ from .mamba_ssm import softplus - +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -255,6 +266,7 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + # should this be limit or not? dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size_limit - 1) * stride_dA_cs_csize).to( tl.float32) @@ -598,7 +610,6 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 84ac31542506..a89bcb14d90b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -148,6 +148,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) + print("CB: ", CB[0,0,0,0,:10]) + # 5. Scan and compute the diagonal blocks, taking into # account past causal states. 
# - if initial states are provided, then states information will be From 67db9b4175b0e861792914e580706f20ecce4cac Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:45:51 -0400 Subject: [PATCH 016/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 5 ++++- .../layers/mamba/ops/ssd_combined.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 34c5a9280752..0ae15cc1aa2a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -682,7 +682,10 @@ def forward_cuda( dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - state_dtype=ssm_state.dtype) + state_dtype=ssm_state.dtype, + layer=self.prefix, + ) + print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) print("varlen_state: ", varlen_state[0,0,0,:10]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a89bcb14d90b..71f43b445e73 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -44,7 +44,8 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, - out=None): + out=None, + layer=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -97,7 +98,16 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) + + print("layer: ", layer) + + + dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) + print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) + print("dA_cumsum_ref: ", dA_cumsum_ref[0,0,0,:10]) + torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) + print("dt: ", dt[0,0,0,:10]) # 2. 
Compute the state for each intra-chunk @@ -208,7 +218,8 @@ def mamba_chunk_scan_combined(x, out=None, return_final_states=False, return_varlen_states=False, - state_dtype=None): + state_dtype=None, + layer=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -253,7 +264,8 @@ def mamba_chunk_scan_combined(x, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, - state_dtype=state_dtype) + state_dtype=state_dtype, + layer=layer) if not return_varlen_states: if not return_final_states: return From d841e827592644152266387da2b52abd79cb2a0a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:58:03 -0400 Subject: [PATCH 017/105] working changes Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 71f43b445e73..73a6507c4cb9 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -103,12 +103,11 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) - - print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) - print("dA_cumsum_ref: ", dA_cumsum_ref[0,0,0,:10]) torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - print("dt: ", dt[0,0,0,:10]) + dt_ref = torch.load("dump/dt_%s_main" % (layer)) + torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) @@ -120,7 +119,9 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - print("states: ", states[0,0,0,0,:10]) + states_ref = torch.load("dump/states_%s_main" % (layer)) + torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) From 908aecb646b084ad71d70f816a1f02d8adbd18ee Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 04:30:25 -0400 Subject: [PATCH 018/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 6 +++-- .../layers/mamba/ops/ssd_combined.py | 23 +++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 7715a5107467..d15147dd5e29 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,6 +13,7 @@ from .mamba_ssm import softplus +''' @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_H': 1}), @@ -25,6 +26,7 @@ ], key=['chunk_size', 'nheads'], ) +''' @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -266,9 +268,8 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - # should this be limit or not? 
dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size_limit - 1) * stride_dA_cs_csize).to( + (chunk_size - 1) * stride_dA_cs_csize).to( tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize @@ -610,6 +611,7 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 73a6507c4cb9..6ef63e07b417 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -100,12 +100,15 @@ def _mamba_chunk_scan_combined_fwd(x, print("layer: ", layer) + has_init = initial_states is not None + print("has_init: ", has_init) - - dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) + dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - dt_ref = torch.load("dump/dt_%s_main" % (layer)) + dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) @@ -119,7 +122,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - states_ref = torch.load("dump/states_%s_main" % (layer)) + states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) @@ -146,11 +150,16 @@ def _mamba_chunk_scan_combined_fwd(x, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - print("after state passing: ") - print("states: ", states[0 ,0, 0,:10]) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) + ''' + print("after state passing: ") + states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) + print("states.shape: ", states.shape) + print("states_ref.shape: ", states_ref.shape) + torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + ''' + # 4. 
Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, From af7a2465d5b6559fe2e67f3a40386f7e68edc353 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 04:57:14 -0400 Subject: [PATCH 019/105] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 84 +------------------ .../layers/mamba/ops/ssd_chunk_state.py | 73 ---------------- .../layers/mamba/ops/ssd_combined.py | 13 ++- .../layers/mamba/ops/ssd_state_passing.py | 5 -- 4 files changed, 13 insertions(+), 162 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0637a7cda4dd..81814b5ce8bc 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -15,86 +15,6 @@ @triton.autotune( configs=[ - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, - num_stages=3, - num_warps=8), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 64 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, @@ -246,9 +166,11 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index d15147dd5e29..e11dab8c5c34 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,20 +13,12 @@ from .mamba_ssm import softplus -''' @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), ], key=['chunk_size', 'nheads'], ) -''' @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -118,70 +110,6 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, - num_stages=3, - 
num_warps=8), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, @@ -611,7 +539,6 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 6ef63e07b417..fcf439b7a496 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -152,13 +152,11 @@ def _mamba_chunk_scan_combined_fwd(x, states = rearrange(states, "... (p n) -> ... p n", n=dstate) - ''' print("after state passing: ") states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) print("states.shape: ", states.shape) print("states_ref.shape: ", states_ref.shape) torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -168,7 +166,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) - print("CB: ", CB[0,0,0,0,:10]) + CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) # 5. Scan and compute the diagonal blocks, taking into # account past causal states. 
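    # The `*_ref` checks in this file assume a reference run has already
    # populated `dump/` with tensors saved from the unmodified kernels, e.g.
    # (hypothetical, the save side is not part of this series):
    #   torch.save(dA_cumsum, "dump/dA_cumsum_%s_main_%d" % (layer, has_init))
    #   torch.save(out_x, "dump/out_x_%s_main_%d" % (layer, has_init))
    # With atol=0.0 and rtol=0.0, each torch.testing.assert_close therefore
    # requires exact agreement with the main branch, layer by layer.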
@@ -196,6 +195,14 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) + + out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) + + out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) + + if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index e7d00a8fdd89..a345fad6795c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -14,11 +14,6 @@ @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 64}), - triton.Config({'BLOCK_SIZE': 128}), - triton.Config({'BLOCK_SIZE': 256}), - triton.Config({'BLOCK_SIZE': 512}), - triton.Config({'BLOCK_SIZE': 1024}), - triton.Config({'BLOCK_SIZE': 2048}), ], key=['dim'], ) From 7ce2b5972568447eddc95110f63ec316f05576d9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 05:42:59 -0400 Subject: [PATCH 020/105] Some test cases working Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 ++-- .../layers/mamba/ops/ssd_chunk_scan.py | 10 +++++++--- .../layers/mamba/ops/ssd_chunk_state.py | 8 ++++---- .../layers/mamba/ops/ssd_combined.py | 18 ++++++++++++------ vllm/v1/attention/backends/mamba2_attn.py | 18 +++++++++--------- vllm/v1/core/sched/scheduler.py | 2 +- 6 files changed, 35 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0ae15cc1aa2a..25ac56b72740 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -687,8 +687,8 @@ def forward_cuda( ) - print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) - print("varlen_state: ", varlen_state[0,0,0,:10]) + #print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) + #print("varlen_state: ", varlen_state[0,0,0,:10]) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 81814b5ce8bc..dc573bd01e68 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -188,9 +188,11 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Load from previous chunk states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ @@ -201,8 +203,8 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: offset_tpa = 0 @@ -224,9 +226,11 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), 
dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Load from previous chunk states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ @@ -238,8 +242,8 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K offset_tpa += BLOCK_SIZE_K @@ -357,7 +361,7 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) - print("out.shape: ", out.shape) + #print("out.shape: ", out.shape) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index e11dab8c5c34..0e029b4de199 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -507,8 +507,8 @@ def _chunk_cumsum_fwd(dt, grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - print("dt_out.shape: ", dt_out.shape) - print("dA_cumsum.shape: ", dA_cumsum.shape) + #print("dt_out.shape: ", dt_out.shape) + #print("dA_cumsum.shape: ", dA_cumsum.shape) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( @@ -569,8 +569,8 @@ def _chunk_state_fwd(B, device=x.device, dtype=states_dtype) - print("[_chunk_state_fwd] states.shape: ", states.shape) - print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) + #print("[_chunk_state_fwd] states.shape: ", states.shape) + #print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fcf439b7a496..71afe952788b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -98,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) - + ''' print("layer: ", layer) has_init = initial_states is not None print("has_init: ", has_init) @@ -110,6 +110,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, has_init)) torch.cuda.synchronize() torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) + ''' # 2. Compute the state for each intra-chunk @@ -122,10 +123,11 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) + ''' states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) torch.cuda.synchronize() torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - + ''' # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) @@ -152,11 +154,13 @@ def _mamba_chunk_scan_combined_fwd(x, states = rearrange(states, "... (p n) -> ... p n", n=dstate) + ''' print("after state passing: ") states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) print("states.shape: ", states.shape) print("states_ref.shape: ", states_ref.shape) torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + ''' # 4. 
Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -166,8 +170,10 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) + ''' CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) + ''' # 5. Scan and compute the diagonal blocks, taking into # account past causal states. @@ -195,21 +201,21 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) - + ''' out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) - + ''' if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - print("last_chunk: ", last_chunk) + #print("last_chunk: ", last_chunk) varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - print("varlen_states: ", varlen_states[0,0,0,:10]) + #print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 91fb63dc486c..dd5961bc0553 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -176,10 +176,10 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) - print("num_decodes: ", num_decodes) - print("num_prefills: ", num_prefills) - print("num_decode_tokens: ", num_decode_tokens) - print("num_prefill_tokens: ", num_prefill_tokens) + #print("num_decodes: ", num_decodes) + #print("num_prefills: ", num_prefills) + #print("num_decode_tokens: ", num_decode_tokens) + #print("num_prefill_tokens: ", num_prefill_tokens) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: @@ -209,8 +209,8 @@ def build(self, query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens - print("num_computed_tokens_p: ", num_computed_tokens_p) - print("query_start_loc_p: ", query_start_loc_p) + #print("num_computed_tokens_p: ", num_computed_tokens_p) + #print("query_start_loc_p: ", query_start_loc_p) cu_chunk_seqlen = [] last_chunk = [] @@ -218,7 +218,7 @@ def build(self, for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() - print(req_idx, this_num_computed, this_new_tokens) + #print(req_idx, this_num_computed, this_new_tokens) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off @@ -247,8 +247,8 @@ def build(self, cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) - print("cu_chunk_seqlen: ", cu_chunk_seqlen) - print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) + #print("cu_chunk_seqlen: ", cu_chunk_seqlen) + #print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. 
If not needed, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 49b6d99e4ab1..101867f5cfc5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - print(scheduler_output) + #print(scheduler_output) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. Wrap up all the KV cache load / save ops into an opaque object From f950f2eb7d4eb849878aa80b0b24a588689fa427 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 07:58:45 -0400 Subject: [PATCH 021/105] Fix IMA Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 141 ++++++++++-------- 1 file changed, 78 insertions(+), 63 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index dc573bd01e68..49daec377d2e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -52,6 +52,7 @@ def _chunk_scan_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, + nchunks, # Strides stride_cb_batch, stride_cb_chunk, @@ -107,6 +108,7 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch @@ -156,84 +158,92 @@ def _chunk_scan_fwd_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. 
- # With Triton 2.2.0, this works - if IS_TRITON_22 or c_idx > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + - offs_k_dstate[None, :] * stride_C_dstate) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + + + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + scale_m = tl.exp(dA_cs_m) - #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + if c_idx > 0: + tl.device_assert(c_idx < nchunks) + # Load from praevious chunk + states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - if BLOCK_SIZE_DSTATE <= 128: + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + offset_tpa = 0 + for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate), + (offs_k_dstate[None, :] < dstate - k), other=0.0) - - + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) if seq_idx != seq_idx_prev: if HAS_INITSTATES: # load from init states init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + pid_h * stride_init_states_head \ + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate + + offs_k_dstate[:, None] * stride_init_states_dstate \ + + offset_tpa prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - # Load from previous chunk 
- states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - offset_tpa = 0 - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate - k), - other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate \ - + offset_tpa - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: + if c_idx > 0: # Load from previous chunk - states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ + pid_h * stride_states_head \ + offs_n[None, :] * stride_states_hdim \ + offs_k_dstate[:, None] * stride_states_dstate \ @@ -243,12 +253,16 @@ def _chunk_scan_fwd_kernel( (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - offset_tpa += BLOCK_SIZE_K + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + offset_tpa += BLOCK_SIZE_K - acc *= scale_m[:, None] + acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + @@ -322,9 +336,11 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) + out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & @@ -361,8 +377,6 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) - #print("out.shape: ", out.shape) - if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) @@ -413,6 +427,7 @@ def _chunk_scan_fwd( batch, seqlen, nheads // ngroups, + nchunks, cb.stride(0), cb.stride(1), cb.stride(2), From 039267d4fe7a6305b1605129639c28a190d7ded7 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 12 Sep 2025 15:14:45 +0000 Subject: [PATCH 022/105] Small cleanup Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index e49c417c9ce7..33358c0689f7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ 
b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -568,8 +568,32 @@ def forward_cuda( [num_decodes, num_prefills], dim=0, ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_B_C_p, hidden_states_B_C_d = torch.split( + hidden_states_B_C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + dt_p, dt_d = torch.split( + dt, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) - # Note: Eventually this will be moved to mamba2 metadata builder: + if envs.VLLM_USE_V1 and cache_enabled: + # Additional variables used by caching logic: seq_lens_pending = ( torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] @@ -595,30 +619,6 @@ def forward_cuda( current_last_token_block_idx, [num_decodes, num_prefills], dim=0) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) - else: - hidden_states_B_C_p, hidden_states_B_C_d = torch.split( - hidden_states_B_C, - [num_prefill_tokens, num_decodes], - dim=0, - ) - dt_p, dt_d = torch.split( - dt, - [num_prefill_tokens, num_decodes], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) - # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( @@ -649,7 +649,7 @@ def forward_cuda( # pointed to by "state_indices_tensor" x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: #TODO: move to MDBuilder? 
+ if mamba2_metadata.cu_seqlen is None: mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) From fe095c4fb87d115050de75850f309d5f29aad620 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 12 Sep 2025 15:15:20 +0000 Subject: [PATCH 023/105] State fix Signed-off-by: Thomas Ortner --- .../layers/mamba/ops/ssd_combined.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 7c5e166f3231..a5ae03be90f1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -42,7 +42,8 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, - out=None): + out=None, + org_cu_seqlens=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -162,7 +163,13 @@ def _mamba_chunk_scan_combined_fwd(x, if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: - assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + if org_cu_seqlens is not None: + # Padding logic: + # The varlen_state, i.e., last state of the request, for the first request is wrong. + # The reason for this is that the varlen state is computed for the full last chunk and not for the partial chunk + # The workaround solution is that the cu_seqlens[1] needs to be corrected to be the original cu_seq_len + cu_seqlens[1] = org_cu_seqlens[1] varlen_states = chunk_state_varlen( B.squeeze(0), x.squeeze(0), @@ -220,6 +227,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + org_cu_seqlens = cu_seqlens # Padding logic used by prefix caching: if seq_pad is not None and seq_pad.sum() > 0: pass #Padding needed @@ -271,7 +279,8 @@ def pad(v): dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, - state_dtype=state_dtype) + state_dtype=state_dtype, + org_cu_seqlens=org_cu_seqlens) # Padding logic used by prefix caching: if return_varlen_states and seq_pad is not None and seq_pad.sum() > 0: From 75e01c87e1871bb3a5e94c8f55268c8b6953c3fc Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 13:56:41 -0400 Subject: [PATCH 024/105] Add back autotune config Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 80 +++++++++++++++++++ .../layers/mamba/ops/ssd_chunk_state.py | 70 ++++++++++++++++ .../layers/mamba/ops/ssd_state_passing.py | 5 ++ 3 files changed, 155 insertions(+) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 49daec377d2e..4e116ae9fc8b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -15,6 +15,86 @@ @triton.autotune( configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + 
triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 0e029b4de199..3f4a5b4e6006 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,7 +15,13 @@ @triton.autotune( configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), ], key=['chunk_size', 'nheads'], ) @@ -110,6 +116,70 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a345fad6795c..e7d00a8fdd89 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -14,6 +14,11 @@ @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), ], key=['dim'], ) From 2698f2eca296767ccab8c6def3428dc4bef6511c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:47:49 -0400 Subject: [PATCH 025/105] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff 
--git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 25ac56b72740..222f89f2c35b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -682,13 +682,7 @@ def forward_cuda( dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - state_dtype=ssm_state.dtype, - layer=self.prefix, - ) - - - #print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) - #print("varlen_state: ", varlen_state[0,0,0,:10]) + state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor From d3f05b7a8682c5e21bc22416f83b4a6e6c84f0e7 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:49:45 -0400 Subject: [PATCH 026/105] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 3a245b127f01..786721733af7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -168,7 +168,6 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head From df635038e9b724c894ac4fdb20fb0bc44cc7fc8f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:55:24 -0400 Subject: [PATCH 027/105] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3f4a5b4e6006..6f710c76f5f3 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,6 +13,7 @@ from .mamba_ssm import softplus + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_H': 1}), @@ -263,15 +264,12 @@ def _chunk_state_fwd_kernel( b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to( - tl.float32) - + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): x = tl.load(x_ptrs, @@ -287,9 +285,7 @@ def _chunk_state_fwd_kernel( other=0.0).to(tl.float32) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -297,7 +293,6 @@ def _chunk_state_fwd_kernel( b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - states = acc.to(states_ptr.dtype.element_ty) states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head @@ -577,9 +572,6 @@ def 
_chunk_cumsum_fwd(dt, grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - #print("dt_out.shape: ", dt_out.shape) - #print("dA_cumsum.shape: ", dA_cumsum.shape) - with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, @@ -638,10 +630,6 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) - - #print("[_chunk_state_fwd] states.shape: ", states.shape) - #print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) - grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) From c5edccdae6fbb740c7c40259ca4ddc26f595cdeb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:00:17 -0400 Subject: [PATCH 028/105] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 2 - .../layers/mamba/ops/ssd_combined.py | 54 ++----------------- 2 files changed, 3 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 6f710c76f5f3..4d4e593b21c4 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -246,10 +246,8 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 71afe952788b..48d4c7e6da09 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -44,8 +44,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, - out=None, - layer=None): + out=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -98,21 +97,6 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) - ''' - print("layer: ", layer) - has_init = initial_states is not None - print("has_init: ", has_init) - - dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main_%d" % (layer, has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - - dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) - ''' - - # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, @@ -123,12 +107,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - ''' - states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' - # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states @@ -151,17 +129,8 @@ def _mamba_chunk_scan_combined_fwd(x, out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - ''' - print("after state passing: ") - states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) - print("states.shape: ", states.shape) - print("states_ref.shape: ", states_ref.shape) - torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' - # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, @@ -170,11 +139,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) - ''' - CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) - ''' - # 5. Scan and compute the diagonal blocks, taking into # account past causal states. # - if initial states are provided, then states information will be @@ -201,21 +165,11 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) - ''' - out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) - - out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) - ''' - if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - #print("last_chunk: ", last_chunk) varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - #print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] 
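        # `last_chunk` holds, for each prefill request, the index of that
        # request's final chunk, so the gather above yields one state per
        # request; `final_states` is simply the state after the last chunk
        # in the batch.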
return out_x, dt, dA_cumsum, states, final_states, varlen_states @@ -241,8 +195,7 @@ def mamba_chunk_scan_combined(x, out=None, return_final_states=False, return_varlen_states=False, - state_dtype=None, - layer=None): + state_dtype=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -287,8 +240,7 @@ def mamba_chunk_scan_combined(x, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, - state_dtype=state_dtype, - layer=layer) + state_dtype=state_dtype) if not return_varlen_states: if not return_final_states: return From 712ced11f6ae00dce9433c76a25db343d0942f0c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:06:08 -0400 Subject: [PATCH 029/105] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_state_passing.py | 11 ++-------- vllm/v1/attention/backends/mamba2_attn.py | 21 +------------------ 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index e7d00a8fdd89..c1207424a9a1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -78,7 +78,7 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch \ + initstates_ptrs = initstates_ptr + stride_initstates_batch \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim @@ -90,16 +90,12 @@ def _state_passing_fwd_kernel( prev_seq_idx = 0 for c in range(nchunks): - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) - - # we are started a new sequence + # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ @@ -112,7 +108,6 @@ def _state_passing_fwd_kernel( states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) prev_seq_idx = seq_idx - states = tl.exp(dA_cs) * states + new_states tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk @@ -132,7 +127,6 @@ def _state_passing_fwd( chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert batch == 1 if chunk_size is None: chunk_size = dA_cumsum.shape[-1] else: @@ -160,7 +154,6 @@ def _state_passing_fwd( out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index dd5961bc0553..010c8f25946c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -126,8 +126,6 @@ class Mamba2AttentionMetadata: seq_idx_p: Optional[torch.Tensor] chunk_indices_p: Optional[torch.Tensor] chunk_offsets_p: Optional[torch.Tensor] - - # tpa cu_chunk_seqlen_p: Optional[torch.Tensor] last_chunk_p: Optional[torch.Tensor] @@ -164,8 +162,6 @@ def build(self, # currently we really only support the FlashAttention backend has_initial_states_p = None prep_initial_states = False - - cu_chunk_seqlen_p = None last_chunk_p = None @@ -176,11 +172,6 @@ def build(self, common_attn_metadata, 
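For intuition, a tiny worked example of the decode-first batch layout this builder relies on (the numbers are made up; only the slicing expression mirrors the code):

    import torch

    # 2 decode requests (1 token each) followed by 2 prefill requests of
    # 5 and 7 new tokens -> cumulative per-request query offsets:
    query_start_loc = torch.tensor([0, 1, 2, 7, 14])
    num_prefills, num_decode_tokens = 2, 2

    # keep only the prefill boundaries and rebase them to start at zero,
    # as done via query_start_loc[-num_prefills - 1:] - num_decode_tokens
    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
    assert query_start_loc_p.tolist() == [0, 5, 12]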
decode_threshold=self.reorder_batch_threshold)) - #print("num_decodes: ", num_decodes) - #print("num_prefills: ", num_prefills) - #print("num_decode_tokens: ", num_decode_tokens) - #print("num_prefill_tokens: ", num_prefill_tokens) - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -191,11 +182,9 @@ def build(self, has_initial_states_p = has_initial_states_cpu.to( query_start_loc.device) - query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens - seq_idx_p = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, @@ -204,25 +193,20 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) - num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens - #print("num_computed_tokens_p: ", num_computed_tokens_p) - #print("query_start_loc_p: ", query_start_loc_p) - + # TODO (tdoublep): Optimize the code cu_chunk_seqlen = [] last_chunk = [] seqlen_pos = 0 for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() - #print(req_idx, this_num_computed, this_new_tokens) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off - # TODO(tdoublep): I guess we need block size actually? if this_num_computed % self.chunk_size != 0: cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? @@ -247,9 +231,6 @@ def build(self, cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) - #print("cu_chunk_seqlen: ", cu_chunk_seqlen) - #print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) - # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, # they will be ignored inside mamba kernels. From dc85f7ea1748be77d1fb599aeb50d9237160224f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:06:47 -0400 Subject: [PATCH 030/105] cleanup Signed-off-by: Thomas Parnell --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 101867f5cfc5..d1a6dd73e85c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - #print(scheduler_output) + # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. 
Wrap up all the KV cache load / save ops into an opaque object From 5e827a635cebf277986a88f5c4555bd68b0ec2f6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:15:14 -0400 Subject: [PATCH 031/105] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_state_passing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index c1207424a9a1..93cb1b485c53 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -78,7 +78,7 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + stride_initstates_batch \ + initstates_ptrs = initstates_ptr \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim From 42e4b27d9e1754808a8f7355d4dba7a75a6b9452 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:36:25 -0400 Subject: [PATCH 032/105] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 138 +++++------------- 1 file changed, 39 insertions(+), 99 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 4e116ae9fc8b..0aa12f7f138b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,6 +13,7 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + @triton.autotune( configs=[ triton.Config( @@ -188,48 +189,43 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - - # logical chunks = physical chunks - # always start from beginning - c_idx = pid_c - c_off = 0 - pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + ( pid_h // nheads_ngroups_ratio) * stride_cb_head - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head C_ptr += pid_b * stride_C_batch + chunk_seqlen_start * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states # - logic in next block may override these if there is an active offset - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - #prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - #prev_states_hdim = stride_states_hdim - #prev_states_dstate = stride_states_dstate - - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, 
BLOCK_SIZE_M) seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen seq_idx = tl.load(seq_idx_ptr) seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=c_idx >= 1, + mask=pid_c >= 1, other=-1) + if HAS_INITSTATES: + prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = states_ptr + (pid_c-1) * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, @@ -241,110 +237,56 @@ def _chunk_scan_fwd_kernel( offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 offs_k_dstate = tl.arange( 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) scale_m = tl.exp(dA_cs_m) - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - if c_idx > 0: - tl.device_assert(c_idx < nchunks) - # Load from praevious chunk - states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: - offset_tpa = 0 + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < 
dstate - k), other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate \ - + offset_tpa - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - if c_idx > 0: - # Load from previous chunk - states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate \ - + offset_tpa - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K - offset_tpa += BLOCK_SIZE_K - + prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] - offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + @@ -381,7 +323,7 @@ def _chunk_scan_fwd_kernel( dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: @@ -416,11 +358,9 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & From d8591820e0e1ea842e3b90688b490829b860c4d3 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:39:59 -0400 Subject: [PATCH 033/105] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0aa12f7f138b..de920f59ec2d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -133,7 +133,6 @@ def _chunk_scan_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, - nchunks, # Strides stride_cb_batch, stride_cb_chunk, @@ -447,7 +446,6 @@ def _chunk_scan_fwd( batch, seqlen, 
nheads // ngroups, - nchunks, cb.stride(0), cb.stride(1), cb.stride(2), From 56b37c22e2608314ef61be789655b5578df81298 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:41:32 -0400 Subject: [PATCH 034/105] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 4d4e593b21c4..eca98ff73e8b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -569,7 +569,6 @@ def _chunk_cumsum_fwd(dt, dtype=torch.float32) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, From e21b4e633e03865b9b6c7469455d7005a9e23af9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:58:20 -0400 Subject: [PATCH 035/105] lint Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_bmm.py | 4 +-- .../layers/mamba/ops/ssd_chunk_scan.py | 32 +++++++++++-------- .../layers/mamba/ops/ssd_chunk_state.py | 1 - .../layers/mamba/ops/ssd_combined.py | 5 ++- .../layers/mamba/ops/ssd_state_passing.py | 9 +++--- vllm/v1/attention/backends/mamba2_attn.py | 23 +++++++++---- 6 files changed, 41 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 786721733af7..260f1e5239af 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -6,8 +6,6 @@ # ruff: noqa: E501,SIM102 -import math - import torch from vllm.triton_utils import tl, triton @@ -209,7 +207,7 @@ def _bmm_chunk_fwd(a, a = a.contiguous() if b.stride(-1) != 1 and b.stride(1) != 1: b = b.contiguous() - nchunks = len(cu_chunk_seqlens)-1 + nchunks = len(cu_chunk_seqlens) - 1 # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index de920f59ec2d..207c440b0ff6 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -212,15 +212,16 @@ def _chunk_scan_fwd_kernel( seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen seq_idx = tl.load(seq_idx_ptr) seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=pid_c >= 1, - other=-1) + mask=pid_c >= 1, + other=-1) if HAS_INITSTATES: prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head prev_states_hdim = stride_init_states_hdim prev_states_dstate = stride_init_states_dstate else: - prev_states_ptr = states_ptr + (pid_c-1) * stride_states_chunk + pid_h * stride_states_head + prev_states_ptr = states_ptr + ( + pid_c - 1) * stride_states_chunk + pid_h * stride_states_head prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate @@ -254,12 +255,13 @@ def _chunk_scan_fwd_kernel( + offs_n[None, :] * prev_states_hdim \ + offs_k_dstate[:, None] * prev_states_dstate prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), + dtype=C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] @@ -273,13 +275,15 @@ def _chunk_scan_fwd_kernel( (offs_k_dstate[None, :] < dstate - k), other=0.0) if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: - prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), + dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K prev_states_ptrs += BLOCK_SIZE_K @@ -418,8 +422,8 @@ def _chunk_scan_fwd( else: out_x = None - grid = lambda META: ( - triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index eca98ff73e8b..448c7970b64b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,7 +6,6 @@ # ruff: noqa: E501 - import torch from vllm.triton_utils import tl, triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 48d4c7e6da09..e04ff3da991d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ 
b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -14,8 +14,7 @@ from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd -from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, - chunk_state_varlen) +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -165,12 +164,12 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) + final_states = states[:, -1, ...] if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - final_states = states[:, -1, ...] return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 93cb1b485c53..3a3de30ba2f5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -82,8 +82,7 @@ def _state_passing_fwd_kernel( + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, - mask=offs_m < dim, + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) @@ -94,15 +93,15 @@ def _state_passing_fwd_kernel( new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + + chunk_seqlen_start * stride_seq_idx_seqlen) # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, - mask=offs_m < dim, + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 010c8f25946c..9a060bff6d1f 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -9,12 +9,13 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.utils import cdiv + def _query_start_loc_to_chunk_indices_offsets( query_start_loc: torch.Tensor, chunk_size: int, @@ -193,7 +194,9 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) - num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] + num_computed_tokens_p = \ + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills:num_reqs] query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens @@ -203,14 +206,16 @@ def build(self, seqlen_pos = 0 for req_idx 
in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() - this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() + this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item( + ) - query_start_loc_p_cpu[req_idx].item() # if computed tokens are not chunk-aligned, use the first # chunk to finish it off if this_num_computed % self.chunk_size != 0: cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? - chunk_len = cdiv(this_num_computed, self.chunk_size)*self.chunk_size - this_num_computed + chunk_len = cdiv(this_num_computed, self.chunk_size + ) * self.chunk_size - this_num_computed # we can only use at most this_new_tokens chunk_len = min(chunk_len, this_new_tokens) seqlen_pos += chunk_len @@ -224,12 +229,16 @@ def build(self, this_new_tokens -= chunk_len assert this_new_tokens == 0 - last_chunk.append(len(cu_chunk_seqlen)-1) + last_chunk.append(len(cu_chunk_seqlen) - 1) cu_chunk_seqlen.append(seqlen_pos) - cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) - last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) + cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, + device=query_start_loc.device, + dtype=torch.int32) + last_chunk_p = torch.as_tensor(last_chunk, + device=query_start_loc.device, + dtype=torch.int32) # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, From 90dd7f50fd17b54f902d1a99bab1b052852fd907 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 12 Sep 2025 23:54:33 +0000 Subject: [PATCH 036/105] Initial conv state fixes Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 85 +++++++++++-------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 93bd5564844d..48c30ad606d9 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -605,7 +605,11 @@ def forward_cuda( seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) # e.g. 
16 blocks computed; 0th based indexing -> state[15] last_computed_token_block_idx = \ - seq_lens_completed // mamba_block_size - 1 + cdiv(seq_lens_completed, mamba_block_size) - 1 + last_computed_token_block_offset = \ + seq_lens_completed % mamba_block_size + if last_computed_token_block_offset.sum() > 0: + pass # block mis-alignment # -1 in case it's non-computed and causes later issues with indexing last_computed_token_block_idx = last_computed_token_block_idx.clamp( min=0) @@ -614,9 +618,15 @@ def forward_cuda( current_last_token_block_idx = cdiv( seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + seq_lens_completed_d, seq_lens_completed_p = torch.split( + seq_lens_completed, [num_decodes, num_prefills], + dim=0) last_computed_idx_d, last_computed_idx_p = torch.split( last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) + last_computed_offset_d, last_computed_offset_p = torch.split( + last_computed_token_block_offset, [num_decodes, num_prefills], + dim=0) current_first_idx_d, current_first_idx_p = torch.split( current_first_token_block_idx, [num_decodes, num_prefills], dim=0) @@ -688,7 +698,7 @@ def forward_cuda( if cache_enabled: - def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end, + def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, query_start_loc): conv_state[conv_state_block_idx, :, 0] = torch.transpose( x[:, query_start_loc + x_offset - 3:query_start_loc + @@ -706,26 +716,32 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end, # state_indices_tensor_p[, current_first_idx_p[]: # current_last_idx_p[]] if cache_strategy == "all": + num_blocks_to_fill = current_last_idx_p - last_computed_idx_p # Iterate over all sequences to need prefill - for seq_idx in range(state_indices_tensor_p.shape[0]): - number_full_blocks = seq_lens_pending[ - seq_idx] // mamba_block_size - if seq_lens_pending[seq_idx] % mamba_block_size > 0: - second_last_block_idx = number_full_blocks - else: - second_last_block_idx = number_full_blocks - 1 - #TODO: simpler logic via?: - # if (current_last_idx_p - current_first_idx_p) - # [seq_idx] > 0: - if number_full_blocks > 0: # and seq_lens_pending[ - #seq_idx] % mamba_block_size > 0: # unnecessary? - copy_x_to_conv_state( - state_indices_tensor_p[ - seq_idx, current_first_idx_p[seq_idx]: - current_first_idx_p[seq_idx] + - second_last_block_idx], mamba_block_size, - mamba_block_size * second_last_block_idx, - query_start_loc_p[seq_idx]) + for seq_idx in range(num_prefills): + if num_blocks_to_fill[seq_idx] == 0: + continue + cache_blocks_to_fill = state_indices_tensor_p[seq_idx, + current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx]+num_blocks_to_fill[seq_idx]] + from_where = x[:,query_start_loc_p[seq_idx]:query_start_loc_p[seq_idx+1]] + # if last computation ended just before the end of block + if last_computed_offset_p[seq_idx] + 3 >= mamba_block_size: + # the current x doesn't have the proper values anymore + # we need to get them from the past state. + # Trick: The indices will go negative: + # e.g. x[:,-3], x[:,-2], x[:,-1] + # so pass x := concat(x, last_state) + # to enable reading from the back + # Note: Maybe always do this and remove "if"? + from_where = torch.concat([from_where, + conv_state[cache_blocks_to_fill[0]]], 1) + copy_to_conv_state(cache_blocks_to_fill, + from_where, + mamba_block_size, + mamba_block_size * num_blocks_to_fill[seq_idx], + mamba_block_size * last_computed_idx_p[seq_idx] + - seq_lens_completed_p[seq_idx]) elif cache_strategy == "last": # i.e. 
keep two states: either # a) states at the last two block boundaries or @@ -736,19 +752,20 @@ def copy_x_to_conv_state(conv_state_block_idx, x_offset, x_end, # Only store the additional second state if there are # is at least one full block and a remainder. # Otherwise, there is only one state to store - if number_full_blocks > 0 and seq_lens_pending[ - seq_idx] % mamba_block_size > 0: - if seq_lens_pending[seq_idx] % mamba_block_size > 0: - second_last_block_idx = number_full_blocks - else: - second_last_block_idx = number_full_blocks - 1 - copy_x_to_conv_state( - state_indices_tensor_p[ - seq_idx, current_last_idx_p[seq_idx] - - 1:current_last_idx_p[seq_idx]], - mamba_block_size * second_last_block_idx, - mamba_block_size * second_last_block_idx, - query_start_loc_p[seq_idx]) + pass + # if number_full_blocks > 0 and seq_lens_pending[ + # seq_idx] % mamba_block_size > 0: + # if seq_lens_pending[seq_idx] % mamba_block_size > 0: + # second_last_block_idx = number_full_blocks + # else: + # second_last_block_idx = number_full_blocks - 1 + # copy_x_to_conv_state( + # state_indices_tensor_p[ + # seq_idx, current_last_idx_p[seq_idx] - + # 1:current_last_idx_p[seq_idx]], + # mamba_block_size * second_last_block_idx, + # mamba_block_size * second_last_block_idx, + # query_start_loc_p[seq_idx]) hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) From d8e00e3e99e34e7da1a9c978ceeed1ddedd397b3 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Mon, 15 Sep 2025 08:08:20 +0000 Subject: [PATCH 037/105] Conv1D fix Signed-off-by: Thomas Ortner --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 48c30ad606d9..525a3cdd0bc1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -701,14 +701,14 @@ def forward_cuda( def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, query_start_loc): conv_state[conv_state_block_idx, :, 0] = torch.transpose( - x[:, query_start_loc + x_offset - 3:query_start_loc + - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - 3: + x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 1] = torch.transpose( - x[:, query_start_loc + x_offset - 2:query_start_loc + - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - 2: + x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 2] = torch.transpose( - x[:, query_start_loc + x_offset - 1:query_start_loc + - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - 1: + x_end:mamba_block_size], 1, 0) # initial state: # state_indices_tensor_p[, last_computed_idx_p[]] From 5785a8507ab759717fa6f5277456beebe6b392f1 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 15 Sep 2025 12:34:01 +0000 Subject: [PATCH 038/105] Conv and SSD state storing fixes Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 189 ++++++++++-------- 1 file changed, 107 insertions(+), 82 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 525a3cdd0bc1..e6355a78cf99 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -621,7 +621,7 @@ def forward_cuda( seq_lens_completed_d, seq_lens_completed_p = torch.split( seq_lens_completed, [num_decodes, num_prefills], 
dim=0) - last_computed_idx_d, last_computed_idx_p = torch.split( + last_state_idx_d, last_state_idx_p = torch.split( last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) last_computed_offset_d, last_computed_offset_p = torch.split( @@ -675,7 +675,7 @@ def forward_cuda( if has_initial_states_p is not None \ and has_initial_states_p.sum() > 0: conv_state_idx_input = state_indices_tensor_p.gather( - 1, last_computed_idx_p.unsqueeze(1)) + 1, last_state_idx_p.unsqueeze(1)) conv_state_idx_output = state_indices_tensor_p.gather( 1, current_last_idx_p.unsqueeze(1)) conv_state[conv_state_idx_output[ @@ -716,7 +716,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # state_indices_tensor_p[, current_first_idx_p[]: # current_last_idx_p[]] if cache_strategy == "all": - num_blocks_to_fill = current_last_idx_p - last_computed_idx_p + num_blocks_to_fill = current_last_idx_p - current_first_idx_p # Iterate over all sequences to need prefill for seq_idx in range(num_prefills): if num_blocks_to_fill[seq_idx] == 0: @@ -728,7 +728,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # if last computation ended just before the end of block if last_computed_offset_p[seq_idx] + 3 >= mamba_block_size: # the current x doesn't have the proper values anymore - # we need to get them from the past state. + # we need to get them from the past PARTIAL state. # Trick: The indices will go negative: # e.g. x[:,-3], x[:,-2], x[:,-1] # so pass x := concat(x, last_state) @@ -740,7 +740,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, from_where, mamba_block_size, mamba_block_size * num_blocks_to_fill[seq_idx], - mamba_block_size * last_computed_idx_p[seq_idx] + mamba_block_size * current_first_idx_p[seq_idx] - seq_lens_completed_p[seq_idx]) elif cache_strategy == "last": # i.e. keep two states: either @@ -772,22 +772,14 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # 3. 
State Space Model sequence transformation initial_states = None - seq_pad = None + if (has_initial_states_p is not None and prep_initial_states): # making a copy of the states if envs.VLLM_USE_V1: kernel_ssm_indices = state_indices_tensor_p if cache_enabled: - #TODO: Move to attn metadata builder kernel_ssm_indices = state_indices_tensor_p.gather( - 1, last_computed_idx_p.unsqueeze(1)).squeeze(1) - if num_prefills > 1: - # Padding for mamba_chunk_scan_combined - seq_lens_pad = cdiv(seq_lens_pending[num_decodes:], chunk_size) * chunk_size # [6144, 1024, 1024, 256] - seq_offsets_pad = seq_lens_pad.cumsum(0)[:-1] # [6144, 7168, 8192] - seq_pad = seq_lens_pad - seq_lens_pending[num_decodes:] # [ 41, 38, 41, 136] - else: - seq_pad = None + 1, last_state_idx_p.unsqueeze(1)).squeeze(1) initial_states = torch.where( has_initial_states_p[:, None, None, None], ssm_state[kernel_ssm_indices], 0) @@ -827,73 +819,106 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, self.head_dim), state_dtype=ssm_state.dtype) - if cache_enabled and num_prefills == 1: - states, varlen_state = mamba_outputs - - # update ssm states - # - varlen state at FINAL chunk is a (num_prefills, nheads, headdim, dstate) tensor - # states (num_prefills, states_at_INTERMEDIATE_chunks, nheads, headdim, dstate) tensor - # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: - all_states = torch.concat( - [states[:, 1:], varlen_state.unsqueeze(1)], - 1) # for num_prefills=1 first returned state is zero - state_stride = mamba_block_size // chunk_size - # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to - # states at blocks 0,0,1,1,2 (block_size=512). - # For first blocks, stride(=2). For last block can't stride. - - # initial state: - # state_indices_tensor_p[, last_computed_idx_p[]] - # new states: - # state_indices_tensor_p[, current_first_idx_p[]: - # current_last_idx_p[]] - - # Code assuming 1 prefill request: - states_at_blocks = torch.concat([ - all_states[:, state_stride - 1:(current_last_idx_p[0] - - current_first_idx_p[0]) * - state_stride:state_stride], - varlen_state.unsqueeze(1) - ], 1) - if cache_strategy == "all": - ssm_state[state_indices_tensor_p[:, current_first_idx_p[0]: - current_last_idx_p[0] + - 1]] = states_at_blocks - elif cache_strategy == "last": - ssm_state[ - state_indices_tensor_p[:, current_last_idx_p[0] - - 1:]] = states_at_blocks[:, -2:] - elif cache_enabled and num_prefills > 1: - if self.prefix == 'model.layers.0.mixer' and attn_metadata.num_prefills == 4: - pass - states, varlen_state = mamba_outputs - last_states_indices = cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)-1 - all_states = states - #layout: [full states 1, partial state 1, full states 2, partial state 2, ... ] - # update all partial states with correct varlen_states - all_states[0, last_states_indices] = varlen_state - state_stride = mamba_block_size // chunk_size - - states_indices = torch.cat([torch.zeros(1, dtype=last_states_indices.dtype, device=last_states_indices.device), last_states_indices + 1]) - # seq_till_chunk [0, 24, 28, 32, 33] -> e.g. 
32:33 is the last one - # seq_till_chunk = torch.concat([torch.tensor([0]), cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)]) - for seq_idx in range(state_indices_tensor_p.shape[0]): - pass - all_seq_states = all_states[:,states_indices[seq_idx]:states_indices[seq_idx+1]] - states_at_blocks = torch.concat([ - all_seq_states[:, state_stride - 1:(current_last_idx_p[seq_idx] - - current_first_idx_p[seq_idx]) * - state_stride:state_stride], - varlen_state[seq_idx].unsqueeze(0).unsqueeze(0) - ], 1) - if cache_strategy == "all": - ssm_state[state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: - current_last_idx_p[seq_idx] + - 1]] = states_at_blocks - elif cache_strategy == "last": - ssm_state[ - state_indices_tensor_p[:, current_last_idx_p[seq_idx] - - 1:]] = states_at_blocks[:, -2:] + if cache_enabled: + states, _ = mamba_outputs + n_blocks_to_fill = current_last_idx_p - current_first_idx_p + #for seq_idx in range(num_prefills): + # if n_blocks_to_fill[seq_idx] > 0: + #More compact than above: + for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): + cache_blocks_to_fill = state_indices_tensor_p[seq_idx, + current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx]+n_blocks_to_fill[seq_idx]] + # chunks = [0 1 2 3 4 5 6 ...] + # First aligned chunk would typically be: + # mamba_block_size = 1024, chunk_size = 256 + # 1024 // 256 - 1 --> chunks[3] + # But when last chunk wasn't block aligned: + # - last_computed_token_block_offset[seq_idx] // chunk_size + # e.g. 1000 // 256 -> 3 completed --> store chunk[0] + # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1) + # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2) + # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3) + chunk_stride = mamba_block_size // chunk_size + first_aligned_chunk = chunk_stride - 1 \ + - last_computed_token_block_offset[seq_idx] // chunk_size + from_where = states[0, first_aligned_chunk: + first_aligned_chunk + + n_blocks_to_fill[seq_idx]*chunk_stride:chunk_stride] + ssm_state[cache_blocks_to_fill] = from_where + + #For all seqs, store last state that might be partial: + ssm_state[state_indices_tensor_p.gather(1, + current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ + states[0, last_chunk_p] + + # if cache_enabled and num_prefills == 1: + # states, varlen_state = mamba_outputs + + # # update ssm states + # # - varlen state at FINAL chunk is a (num_prefills, nheads, headdim, dstate) tensor + # # states (num_prefills, states_at_INTERMEDIATE_chunks, nheads, headdim, dstate) tensor + # # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: + # all_states = torch.concat( + # [states[:, 1:], varlen_state.unsqueeze(1)], + # 1) # for num_prefills=1 first returned state is zero + # state_stride = mamba_block_size // chunk_size + # # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to + # # states at blocks 0,0,1,1,2 (block_size=512). + # # For first blocks, stride(=2). For last block can't stride. 
+ + # # initial state: + # # state_indices_tensor_p[, last_computed_idx_p[]] + # # new states: + # # state_indices_tensor_p[, current_first_idx_p[]: + # # current_last_idx_p[]] + + # # Code assuming 1 prefill request: + # states_at_blocks = torch.concat([ + # all_states[:, state_stride - 1:(current_last_idx_p[0] - + # current_first_idx_p[0]) * + # state_stride:state_stride], + # varlen_state.unsqueeze(1) + # ], 1) + # if cache_strategy == "all": + # ssm_state[state_indices_tensor_p[:, current_first_idx_p[0]: + # current_last_idx_p[0] + + # 1]] = states_at_blocks + # elif cache_strategy == "last": + # ssm_state[ + # state_indices_tensor_p[:, current_last_idx_p[0] - + # 1:]] = states_at_blocks[:, -2:] + # elif cache_enabled and num_prefills > 1: + # if self.prefix == 'model.layers.0.mixer' and attn_metadata.num_prefills == 4: + # pass + # states, varlen_state = mamba_outputs + # last_states_indices = cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)-1 + # all_states = states + # #layout: [full states 1, partial state 1, full states 2, partial state 2, ... ] + # # update all partial states with correct varlen_states + # all_states[0, last_states_indices] = varlen_state + # state_stride = mamba_block_size // chunk_size + + # states_indices = torch.cat([torch.zeros(1, dtype=last_states_indices.dtype, device=last_states_indices.device), last_states_indices + 1]) + # # seq_till_chunk [0, 24, 28, 32, 33] -> e.g. 32:33 is the last one + # # seq_till_chunk = torch.concat([torch.tensor([0]), cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)]) + # for seq_idx in range(state_indices_tensor_p.shape[0]): + # pass + # all_seq_states = all_states[:,states_indices[seq_idx]:states_indices[seq_idx+1]] + # states_at_blocks = torch.concat([ + # all_seq_states[:, state_stride - 1:(current_last_idx_p[seq_idx] - + # current_first_idx_p[seq_idx]) * + # state_stride:state_stride], + # varlen_state[seq_idx].unsqueeze(0).unsqueeze(0) + # ], 1) + # if cache_strategy == "all": + # ssm_state[state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: + # current_last_idx_p[seq_idx] + + # 1]] = states_at_blocks + # elif cache_strategy == "last": + # ssm_state[ + # state_indices_tensor_p[:, current_last_idx_p[seq_idx] - + # 1:]] = states_at_blocks[:, -2:] else: varlen_state = mamba_outputs # update ssm states From d5cab4c14f8c2ec05c881d4042a3043d1c327580 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 15 Sep 2025 14:09:48 +0000 Subject: [PATCH 039/105] Corrected decode. APC should work OK. 
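The decode path now reuses the same block-index bookkeeping as prefill:
the input conv/SSM state is gathered from last_state_idx_d and the new
state is written at current_last_idx_d, instead of recomputing block
numbers from seq_lens. A minimal, self-contained sketch of that
arithmetic (the block size and token counts below are illustrative
values, not taken from a real run):

    import torch

    def cdiv(a, b):
        # ceil division for positive ints, same role as vllm.utils.cdiv
        return (a + b - 1) // b

    mamba_block_size = 1024                 # assumed block size
    completed = torch.tensor([1000, 1024])  # tokens already computed
    pending = torch.tensor([300, 1])        # tokens in this step

    current_last = cdiv(completed + pending, mamba_block_size) - 1
    current_first = cdiv(completed + 1, mamba_block_size) - 1
    last_computed = (cdiv(completed, mamba_block_size) - 1).clamp(min=0)

    # last_computed <= current_first <= current_last always holds;
    # last_computed == current_first while the last state is partial,
    # current_first > last_computed right after a block boundary.
    print(last_computed.tolist())   # [0, 0]
    print(current_first.tolist())   # [0, 1]
    print(current_last.tolist())    # [1, 1]
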
Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 178 ++++-------------- 1 file changed, 38 insertions(+), 140 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index e6355a78cf99..812dcd34f953 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -491,8 +491,8 @@ def forward_cuda( mamba_block_size = attn_metadata.cache_spec.block_size cache_strategy = attn_metadata.cache_spec.cache_strategy cache_enabled = (cache_strategy != 'disabled') - cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p #TODO: from TPA - last_chunk_p = attn_metadata.last_chunk_p #TODO: from TPA + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_p = attn_metadata.last_chunk_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -603,21 +603,27 @@ def forward_cuda( torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) - # e.g. 16 blocks computed; 0th based indexing -> state[15] - last_computed_token_block_idx = \ - cdiv(seq_lens_completed, mamba_block_size) - 1 last_computed_token_block_offset = \ seq_lens_completed % mamba_block_size - if last_computed_token_block_offset.sum() > 0: - pass # block mis-alignment - # -1 in case it's non-computed and causes later issues with indexing - last_computed_token_block_idx = last_computed_token_block_idx.clamp( - min=0) - current_first_token_block_idx = cdiv(seq_lens_completed + 1, - mamba_block_size) - 1 + + # Indices: last_computed <= current_first <= current_last + # Cases: + # last_computed == current_first if last state was partially + # computed and needs to be updated + # current_first == current_last if no block crossing occurs, and + # only one state will be stored + # 0th based indexing leads to "-1" -> e.g. 
16 computed -> state[15]: current_last_token_block_idx = cdiv( seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + current_first_token_block_idx = cdiv( + seq_lens_completed + 1, mamba_block_size) - 1 + last_computed_token_block_idx = cdiv( + seq_lens_completed, mamba_block_size) - 1 + # -1 in case it's non-computed and causes later issues with indexing + last_computed_token_block_idx = \ + last_computed_token_block_idx.clamp(min=0) + # Split decodes and prefills: seq_lens_completed_d, seq_lens_completed_p = torch.split( seq_lens_completed, [num_decodes, num_prefills], dim=0) @@ -710,25 +716,21 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, x[:, query_start_loc + x_offset - 1: x_end:mamba_block_size], 1, 0) - # initial state: - # state_indices_tensor_p[, last_computed_idx_p[]] - # new states: - # state_indices_tensor_p[, current_first_idx_p[]: - # current_last_idx_p[]] if cache_strategy == "all": - num_blocks_to_fill = current_last_idx_p - current_first_idx_p - # Iterate over all sequences to need prefill - for seq_idx in range(num_prefills): - if num_blocks_to_fill[seq_idx] == 0: - continue + n_blocks_to_fill = current_last_idx_p - current_first_idx_p + # Iterate over sequences that require state storing: + for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): cache_blocks_to_fill = state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: - current_first_idx_p[seq_idx]+num_blocks_to_fill[seq_idx]] - from_where = x[:,query_start_loc_p[seq_idx]:query_start_loc_p[seq_idx+1]] + current_first_idx_p[seq_idx]+ + n_blocks_to_fill[seq_idx]] + from_where = x[:,query_start_loc_p[seq_idx]: + query_start_loc_p[seq_idx+1]] # if last computation ended just before the end of block - if last_computed_offset_p[seq_idx] + 3 >= mamba_block_size: - # the current x doesn't have the proper values anymore - # we need to get them from the past PARTIAL state. + if last_computed_offset_p[seq_idx] + 3 >= \ + mamba_block_size: + # the current x doesn't have the proper values + # We need to get them from the past state. # Trick: The indices will go negative: # e.g. x[:,-3], x[:,-2], x[:,-1] # so pass x := concat(x, last_state) @@ -739,33 +741,9 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, copy_to_conv_state(cache_blocks_to_fill, from_where, mamba_block_size, - mamba_block_size * num_blocks_to_fill[seq_idx], + mamba_block_size * n_blocks_to_fill[seq_idx], mamba_block_size * current_first_idx_p[seq_idx] - seq_lens_completed_p[seq_idx]) - elif cache_strategy == "last": - # i.e. keep two states: either - # a) states at the last two block boundaries or - # b) state at the last block boundary and last state of - # the sequence, which might not be at a block boundary - # Iterate over all sequences to need prefill - for seq_idx in range(state_indices_tensor_p.shape[0]): - # Only store the additional second state if there are - # is at least one full block and a remainder. 
- # Otherwise, there is only one state to store - pass - # if number_full_blocks > 0 and seq_lens_pending[ - # seq_idx] % mamba_block_size > 0: - # if seq_lens_pending[seq_idx] % mamba_block_size > 0: - # second_last_block_idx = number_full_blocks - # else: - # second_last_block_idx = number_full_blocks - 1 - # copy_x_to_conv_state( - # state_indices_tensor_p[ - # seq_idx, current_last_idx_p[seq_idx] - - # 1:current_last_idx_p[seq_idx]], - # mamba_block_size * second_last_block_idx, - # mamba_block_size * second_last_block_idx, - # query_start_loc_p[seq_idx]) hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -822,9 +800,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, if cache_enabled: states, _ = mamba_outputs n_blocks_to_fill = current_last_idx_p - current_first_idx_p - #for seq_idx in range(num_prefills): - # if n_blocks_to_fill[seq_idx] > 0: - #More compact than above: + # Save states for sequences with more than just the final state: for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): cache_blocks_to_fill = state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: @@ -847,78 +823,10 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, n_blocks_to_fill[seq_idx]*chunk_stride:chunk_stride] ssm_state[cache_blocks_to_fill] = from_where - #For all seqs, store last state that might be partial: + #For all seqs, store the last state (Note: might be partial): ssm_state[state_indices_tensor_p.gather(1, current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ states[0, last_chunk_p] - - # if cache_enabled and num_prefills == 1: - # states, varlen_state = mamba_outputs - - # # update ssm states - # # - varlen state at FINAL chunk is a (num_prefills, nheads, headdim, dstate) tensor - # # states (num_prefills, states_at_INTERMEDIATE_chunks, nheads, headdim, dstate) tensor - # # Combine to have all_states (num_prefills, ALL_states, nheads, headdim, dstate) tensor: - # all_states = torch.concat( - # [states[:, 1:], varlen_state.unsqueeze(1)], - # 1) # for num_prefills=1 first returned state is zero - # state_stride = mamba_block_size // chunk_size - # # states for chunks 0,1,2,3,4 (chunk_size=256) correspond to - # # states at blocks 0,0,1,1,2 (block_size=512). - # # For first blocks, stride(=2). For last block can't stride. - - # # initial state: - # # state_indices_tensor_p[, last_computed_idx_p[]] - # # new states: - # # state_indices_tensor_p[, current_first_idx_p[]: - # # current_last_idx_p[]] - - # # Code assuming 1 prefill request: - # states_at_blocks = torch.concat([ - # all_states[:, state_stride - 1:(current_last_idx_p[0] - - # current_first_idx_p[0]) * - # state_stride:state_stride], - # varlen_state.unsqueeze(1) - # ], 1) - # if cache_strategy == "all": - # ssm_state[state_indices_tensor_p[:, current_first_idx_p[0]: - # current_last_idx_p[0] + - # 1]] = states_at_blocks - # elif cache_strategy == "last": - # ssm_state[ - # state_indices_tensor_p[:, current_last_idx_p[0] - - # 1:]] = states_at_blocks[:, -2:] - # elif cache_enabled and num_prefills > 1: - # if self.prefix == 'model.layers.0.mixer' and attn_metadata.num_prefills == 4: - # pass - # states, varlen_state = mamba_outputs - # last_states_indices = cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)-1 - # all_states = states - # #layout: [full states 1, partial state 1, full states 2, partial state 2, ... 
] - # # update all partial states with correct varlen_states - # all_states[0, last_states_indices] = varlen_state - # state_stride = mamba_block_size // chunk_size - - # states_indices = torch.cat([torch.zeros(1, dtype=last_states_indices.dtype, device=last_states_indices.device), last_states_indices + 1]) - # # seq_till_chunk [0, 24, 28, 32, 33] -> e.g. 32:33 is the last one - # # seq_till_chunk = torch.concat([torch.tensor([0]), cdiv(seq_lens_pending[num_decodes:], chunk_size).cumsum(0)]) - # for seq_idx in range(state_indices_tensor_p.shape[0]): - # pass - # all_seq_states = all_states[:,states_indices[seq_idx]:states_indices[seq_idx+1]] - # states_at_blocks = torch.concat([ - # all_seq_states[:, state_stride - 1:(current_last_idx_p[seq_idx] - - # current_first_idx_p[seq_idx]) * - # state_stride:state_stride], - # varlen_state[seq_idx].unsqueeze(0).unsqueeze(0) - # ], 1) - # if cache_strategy == "all": - # ssm_state[state_indices_tensor_p[seq_idx, current_first_idx_p[seq_idx]: - # current_last_idx_p[seq_idx] + - # 1]] = states_at_blocks - # elif cache_strategy == "last": - # ssm_state[ - # state_indices_tensor_p[:, current_last_idx_p[seq_idx] - - # 1:]] = states_at_blocks[:, -2:] else: varlen_state = mamba_outputs # update ssm states @@ -927,30 +835,20 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # Process decode requests if has_decode: - if cache_enabled: - # # if at_block_boundary, load states from previous blocks: - # at_block_boundary = mamba2_metadata.seq_lens \ - # % mamba_block_size == 0 - # finished_blocks = attn_metadata.seq_lens[ - # 0] // mamba_block_size #e.g. 1024:2 blocks; 1025:2 blocks - input_block = cdiv( - attn_metadata.seq_lens[:num_decodes], mamba_block_size - ) #e.g. 1024 -> 2nd block, 1025 -> 3rd block - output_block = cdiv( - attn_metadata.seq_lens[:num_decodes] + 1, mamba_block_size - ) #e.g. 1023 -> 2nd block, 1024 -> 3rd block - state_indices_tensor_d_input = \ state_indices_tensor_d.gather(1, - (input_block-1).unsqueeze(1)).squeeze(1) + last_state_idx_d.unsqueeze(1)).squeeze(1) state_indices_tensor_d_output = \ state_indices_tensor_d.gather(1, - (output_block-1).unsqueeze(1)).squeeze(1) + current_last_idx_d.unsqueeze(1)).squeeze(1) + #Note: + # for decode always: current_first_idx_d == current_last_idx_d + # at block boundaries: current_first_idx_d > last_state_idx_d # copy initial state to new location, # as update kernel works in place - if (output_block > input_block).any(): + if (current_last_idx_d > last_state_idx_d).any(): conv_state[state_indices_tensor_d_output] = conv_state[ state_indices_tensor_d_input] else: From e711fc06d15874233dd41ef0a80c3f7a8d1b0678 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 16 Sep 2025 11:28:28 +0000 Subject: [PATCH 040/105] Cleanup. 
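Besides the typing cleanup (MambaCacheStrategy is now a Literal), this
drops the ad-hoc max_model_len rounding and keeps only the block-size
alignment: under prefix caching, the attention block size is the number
of attention tokens that fits one mamba state page, rounded up to a
multiple of the SSD kernel chunk size. A small sketch of that
computation, using the illustrative sizes from the in-code comment
(788 kB mamba page, 2 kB per attention token):

    def cdiv(a: int, b: int) -> int:
        return (a + b - 1) // b

    chunk_size = 256                   # mamba SSD kernel chunk size
    mamba_page_size = 788 * 1024       # bytes per mamba state (example)
    attn_page_size_1_token = 2 * 1024  # bytes per attention token (example)

    tokens_per_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 394
    attn_block_size = chunk_size * cdiv(tokens_per_state, chunk_size)  # 512
    print(attn_block_size)  # mamba_block_size is set to the same value
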
Signed-off-by: Stanislaw Wozniak --- vllm/config/cache.py | 5 ++-- vllm/model_executor/models/config.py | 28 +++++++++----------- vllm/v1/core/single_type_kv_cache_manager.py | 3 ++- vllm/v1/kv_cache_interface.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 7 ++--- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index d52698ee54e4..20488cae0cda 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] +MambaCacheStrategy = Literal["disabled", "all", "last"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -89,9 +90,9 @@ class CacheConfig: models to ensure exact alignment with attention page size.""" mamba_block_size: Optional[int] = None """Size of a contiguous cache block in number of tokens for mamba cache.""" - mamba_cache_strategy: str = "all" + mamba_cache_strategy: MambaCacheStrategy = "all" """Logic for mamba cache: - * disabled - turn of prefix caching + * disabled - turn off prefix caching * all - keep states for all prefixes * last - keep the states of the last full blocks after each request """ diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 8789eabdbe65..e509f6fa6980 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -373,23 +373,21 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # With prefix caching, select attention block size to # optimize for mamba kernel performance - # mamba SSD kernel uses a chunk_size, e.g. 256. Align the block to the kernel: - # use lowest multiple of 256 attention tokens that would fit mamba_page_size - # e.g. mamba page size of 788kB ; attn_1_token 2kB -> fits ~394 tokens - # then round up to a mulitple of 256 -> 512 tokens - # attn_block_size = 512 - # mamba_block_size = 512 (aligned to a multiple of kernel chunk_size) + # mamba SSD kernel uses a chunk_size, e.g. 256 + # Align the block to the kernel: use lowest multiple of chunk_size + # of attention tokens that would fit mamba_page_size: + # e.g. 
for mamba page size = 788kB + # attn_1_token = 2kB -> fits ~394 tokens + # then round up to a mulitple of 256 -> 512 tokens + # End result: + # attn_block_size = 512 + # mamba_block_size = 512 (aligned to a multiple of chunk_size) chunk_size = model_config.get_mamba_chunk_size() - attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) - attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + attn_tokens_per_mamba_state = \ + cdiv(mamba_page_size, attn_page_size_1_token) + attn_block_size = chunk_size * \ + cdiv(attn_tokens_per_mamba_state, chunk_size) cache_config.mamba_block_size = attn_block_size - - # This below might be redundant now: - if model_config.max_model_len % attn_block_size != 0: - # Currently HybridCacheManager uses max_model_len for Mamba block - # and requires it to be a multiple of attention block - model_config.max_model_len -= model_config.max_model_len % attn_block_size - print("Adjusting max_model_len to", model_config.max_model_len) else: # Without prefix caching, select minimum valid attention block size # to minimize mamba state padding diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 463780a6fe2e..a13b956d72f1 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -640,7 +640,8 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - if num_new_blocks > 2 and self.kv_cache_spec.cache_strategy == "last": + if num_new_blocks > 2 and \ + self.kv_cache_spec.cache_strategy == "last": # for the last strategy only - allocate 2 blocks: # one for block_size aligned state # and one for the last temporary state diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 57b8932be45a..668d8db2072b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,6 +10,7 @@ from typing_extensions import Self from vllm.config import VllmConfig +from vllm.config.cache import MambaCacheStrategy from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size @@ -194,7 +195,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" - cache_strategy: str = "disabled" + cache_strategy: MambaCacheStrategy = "disabled" num_speculative_blocks: int = 0 @property diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc71ceb55fd4..b520bf08c063 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3703,10 +3703,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: - mamba_block_size = self.vllm_config.cache_config.mamba_block_size + mamba_block_size = \ + self.vllm_config.cache_config.mamba_block_size else: - # Set block_size to max_model_len, so that mamba model will always - # have only one block + # Set block_size to max_model_len, so that mamba model + # will always have only one block mamba_block_size = self.vllm_config.model_config.max_model_len self.vllm_config.cache_config.mamba_cache_strategy = "disabled" From 120fbb787a8524f121e3cb8fa81aa13317090861 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 16 Sep 2025 17:15:24 +0000 Subject: [PATCH 041/105] CUDA graphs fixes. 
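Two changes make the decode path capturable: the metadata builder now
preallocates a persistent state_indices_tensor sized for the maximum
decode batch, and the data-dependent guard around the conv-state copy
is dropped, since a host-side `.any()` check forces a sync and a branch
that cannot be baked into a CUDA graph. A toy sketch of the resulting
unconditional copy (shapes and indices below are made up for
illustration):

    import torch

    conv_state = torch.randn(8, 4, 3)                 # toy cache blocks
    state_indices_d = torch.tensor([[0, 1], [2, 3]])  # per-request block tables
    last_state_idx_d = torch.tensor([0, 0])           # block with the last state
    current_last_idx_d = torch.tensor([0, 1])         # block the new state lands in

    src = state_indices_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1)
    dst = state_indices_d.gather(1, current_last_idx_d.unsqueeze(1)).squeeze(1)
    # Request 0 stays inside its block (src == dst, a harmless self-copy);
    # request 1 just crossed a block boundary, so its state moves to the
    # new block.  No branch, so every decode step runs the same kernels.
    conv_state[dst] = conv_state[src]
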
Signed-off-by: Stanislaw Wozniak --- .../models/language/generation/test_hybrid.py | 1 + .../layers/mamba/mamba_mixer2.py | 7 ++--- vllm/model_executor/models/config.py | 7 +++++ vllm/v1/attention/backends/mamba2_attn.py | 26 +++++++------------ 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index d0e42062099e..c3c59d8f766d 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -53,6 +53,7 @@ V0_UNSUPPORTED_MODELS = [ "LiquidAI/LFM2-1.2B", + "ibm-granite/granite-4.0-tiny-preview", ] FP32_STATE_MODELS = [ diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 812dcd34f953..6243fbc06096 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -847,9 +847,10 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # at block boundaries: current_first_idx_d > last_state_idx_d # copy initial state to new location, - # as update kernel works in place - if (current_last_idx_d > last_state_idx_d).any(): - conv_state[state_indices_tensor_d_output] = conv_state[ + # as update kernel works in place + #if (current_last_idx_d > last_state_idx_d).any(): + # (skip IF as it breaks CUDA graphs) + conv_state[state_indices_tensor_d_output] = conv_state[ state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index e509f6fa6980..13be7b1356aa 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -369,6 +369,13 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes + # Cascade attn doesn't work with Mamba: + # * enable_prefix_caching = True -> fails + # * enable_prefix_caching = False -> cascade attention is triggered, + # but always terminates early, not raising any exception + # Thus, it's more effective to disable the cascade attention logic: + model_config.disable_cascade_attn = True + if cache_config.enable_prefix_caching: # With prefix caching, select attention block size to # optimize for mamba kernel performance diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 8d73e9b88c56..f4c109fd3d9e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -148,6 +148,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") + if kv_cache_spec.cache_strategy == "all": + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, + cdiv(vllm_config.model_config.max_model_len, + kv_cache_spec.block_size)), + dtype=torch.int32, + device=device, + ) def build(self, common_prefix_len: int, @@ -171,22 +179,8 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] else: - # Return a tensor of shape (#requests, #blocks for longest request) - # filled in with cached and newly allocated blocks for each request - cache_block_size = self.kv_cache_spec.block_size - seq_lens_cpu = common_attn_metadata.seq_lens_cpu - block_table_bounds_cpu 
= (seq_lens_cpu + cache_block_size - - 1) // cache_block_size - max_num_blocks = block_table_bounds_cpu.max() - paged_kv_indices = common_attn_metadata.block_table_tensor[:, : - max_num_blocks] - if self.kv_cache_spec.cache_strategy == "last": - # TODO: The "last" strategy is not fully implemented yet - # In the "last" strategy, the allocator puts 2 block in front - # For easiness of handling, we move them to be two last in list - paged_kv_indices = torch.roll(paged_kv_indices, - max_num_blocks.item() - 2, -1) - state_indices_tensor = paged_kv_indices + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( From 22c47f9eef6d485a89ca6d2b279e08b475419228 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 17 Sep 2025 12:24:28 +0000 Subject: [PATCH 042/105] Precommit fixes. Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 78 +++++++++---------- .../layers/mamba/ops/mamba_ssm.py | 4 +- vllm/model_executor/models/config.py | 6 +- vllm/v1/attention/backends/mamba2_attn.py | 7 +- vllm/v1/kv_cache_interface.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 7 +- 6 files changed, 53 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2498c0b64c0e..796aebdd8c78 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -647,18 +647,17 @@ def forward_cuda( # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: current_last_token_block_idx = cdiv( seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 - current_first_token_block_idx = cdiv( - seq_lens_completed + 1, mamba_block_size) - 1 - last_computed_token_block_idx = cdiv( - seq_lens_completed, mamba_block_size) - 1 + current_first_token_block_idx = cdiv(seq_lens_completed + 1, + mamba_block_size) - 1 + last_computed_token_block_idx = cdiv(seq_lens_completed, + mamba_block_size) - 1 # -1 in case it's non-computed and causes later issues with indexing last_computed_token_block_idx = \ last_computed_token_block_idx.clamp(min=0) # Split decodes and prefills: seq_lens_completed_d, seq_lens_completed_p = torch.split( - seq_lens_completed, [num_decodes, num_prefills], - dim=0) + seq_lens_completed, [num_decodes, num_prefills], dim=0) last_state_idx_d, last_state_idx_p = torch.split( last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) @@ -736,28 +735,28 @@ def forward_cuda( if cache_enabled: - def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, - query_start_loc): + def copy_to_conv_state(conv_state_block_idx, x, x_offset, + x_end, query_start_loc): conv_state[conv_state_block_idx, :, 0] = torch.transpose( - x[:, query_start_loc + x_offset - 3: - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - + 3:x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 1] = torch.transpose( - x[:, query_start_loc + x_offset - 2: - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - + 2:x_end:mamba_block_size], 1, 0) conv_state[conv_state_block_idx, :, 2] = torch.transpose( - x[:, query_start_loc + x_offset - 1: - x_end:mamba_block_size], 1, 0) + x[:, query_start_loc + x_offset - + 1:x_end:mamba_block_size], 1, 0) if cache_strategy == "all": n_blocks_to_fill = current_last_idx_p - current_first_idx_p # Iterate over sequences that require 
state storing: for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): - cache_blocks_to_fill = state_indices_tensor_p[seq_idx, - current_first_idx_p[seq_idx]: - current_first_idx_p[seq_idx]+ - n_blocks_to_fill[seq_idx]] - from_where = x[:,query_start_loc_p[seq_idx]: - query_start_loc_p[seq_idx+1]] + cache_blocks_to_fill = state_indices_tensor_p[ + seq_idx, current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx] + + n_blocks_to_fill[seq_idx]] + from_where = x[:, query_start_loc_p[seq_idx]: + query_start_loc_p[seq_idx + 1]] # if last computation ended just before the end of block if last_computed_offset_p[seq_idx] + 3 >= \ mamba_block_size: @@ -768,14 +767,14 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, # so pass x := concat(x, last_state) # to enable reading from the back # Note: Maybe always do this and remove "if"? - from_where = torch.concat([from_where, - conv_state[cache_blocks_to_fill[0]]], 1) - copy_to_conv_state(cache_blocks_to_fill, - from_where, - mamba_block_size, + from_where = torch.concat([ + from_where, conv_state[cache_blocks_to_fill[0]] + ], 1) + copy_to_conv_state( + cache_blocks_to_fill, from_where, mamba_block_size, mamba_block_size * n_blocks_to_fill[seq_idx], - mamba_block_size * current_first_idx_p[seq_idx] - - seq_lens_completed_p[seq_idx]) + mamba_block_size * current_first_idx_p[seq_idx] - + seq_lens_completed_p[seq_idx]) hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -834,9 +833,10 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, n_blocks_to_fill = current_last_idx_p - current_first_idx_p # Save states for sequences with more than just the final state: for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): - cache_blocks_to_fill = state_indices_tensor_p[seq_idx, - current_first_idx_p[seq_idx]: - current_first_idx_p[seq_idx]+n_blocks_to_fill[seq_idx]] + cache_blocks_to_fill = state_indices_tensor_p[ + seq_idx, current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx] + + n_blocks_to_fill[seq_idx]] # chunks = [0 1 2 3 4 5 6 ...] 
# First aligned chunk would typically be: # mamba_block_size = 1024, chunk_size = 256 @@ -850,13 +850,13 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, chunk_stride = mamba_block_size // chunk_size first_aligned_chunk = chunk_stride - 1 \ - last_computed_token_block_offset[seq_idx] // chunk_size - from_where = states[0, first_aligned_chunk: - first_aligned_chunk + - n_blocks_to_fill[seq_idx]*chunk_stride:chunk_stride] + from_where = states[ + 0, first_aligned_chunk:first_aligned_chunk + + n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride] ssm_state[cache_blocks_to_fill] = from_where #For all seqs, store the last state (Note: might be partial): - ssm_state[state_indices_tensor_p.gather(1, + ssm_state[state_indices_tensor_p.gather(1, current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ states[0, last_chunk_p] else: @@ -869,21 +869,21 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, x_end, if has_decode: if cache_enabled: state_indices_tensor_d_input = \ - state_indices_tensor_d.gather(1, + state_indices_tensor_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1) state_indices_tensor_d_output = \ - state_indices_tensor_d.gather(1, + state_indices_tensor_d.gather(1, current_last_idx_d.unsqueeze(1)).squeeze(1) - #Note: + #Note: # for decode always: current_first_idx_d == current_last_idx_d # at block boundaries: current_first_idx_d > last_state_idx_d # copy initial state to new location, - # as update kernel works in place + # as update kernel works in place #if (current_last_idx_d > last_state_idx_d).any(): # (skip IF as it breaks CUDA graphs) conv_state[state_indices_tensor_d_output] = conv_state[ - state_indices_tensor_d_input] + state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 30d59d45813b..585ae72fe565 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -111,7 +111,7 @@ def _selective_scan_update_kernel( dst_state_batch_indices_ptr += pid_b dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch + - pid_h * stride_state_head) + pid_h * stride_state_head) state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += (state_batch_idx * stride_state_batch + @@ -138,7 +138,7 @@ def _selective_scan_update_kernel( state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim + - offs_n[None, :] * stride_state_dstate) + offs_n[None, :] * stride_state_dstate) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 13be7b1356aa..586ffac91b0b 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -371,13 +371,13 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # Cascade attn doesn't work with Mamba: # * enable_prefix_caching = True -> fails - # * enable_prefix_caching = False -> cascade attention is triggered, + # * enable_prefix_caching = False -> cascade attention is triggered, # but always terminates early, not raising any exception # 
Thus, it's more effective to disable the cascade attention logic: model_config.disable_cascade_attn = True if cache_config.enable_prefix_caching: - # With prefix caching, select attention block size to + # With prefix caching, select attention block size to # optimize for mamba kernel performance # mamba SSD kernel uses a chunk_size, e.g. 256 @@ -404,7 +404,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # that would work (note: FA is currently not compatible # with mamba layers, use FlashInfer instead). attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + 16 * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index f4c109fd3d9e..3a57bafc8fc0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -from vllm.utils import cdiv + def _query_start_loc_to_chunk_indices_offsets( query_start_loc: torch.Tensor, chunk_size: int, @@ -139,6 +139,7 @@ class Mamba2AttentionMetadata: token_chunk_offset_ptr: Optional[torch.tensor] = None cache_spec: Optional[MambaSpec] = None + class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): @@ -150,8 +151,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], "chunk_size needs to be set in the model config for Mamba2 models") if kv_cache_spec.cache_strategy == "all": self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, - cdiv(vllm_config.model_config.max_model_len, + (self.decode_cudagraph_max_bs, + cdiv(vllm_config.model_config.max_model_len, kv_cache_spec.block_size)), dtype=torch.int32, device=device, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a90fb65c3498..f9d0ebb81c99 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -207,7 +207,7 @@ def page_size_bytes(self) -> int: return self.page_size_padded return page_size - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: if self.cache_strategy == "last": # Keeps the last full block and one non-full block state: return 2 * self.page_size_bytes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f8afb3e4031..c678b65b75d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3920,8 +3920,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if self.vllm_config.cache_config.enable_prefix_caching: mamba_block_size = \ self.vllm_config.cache_config.mamba_block_size - else: - # Set block_size to max_model_len, so that mamba model + else: + # Set block_size to max_model_len, so that mamba model # will always have only one block mamba_block_size = self.vllm_config.model_config.max_model_len self.vllm_config.cache_config.mamba_cache_strategy = "disabled" @@ -3934,7 +3934,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), block_size=mamba_block_size, - cache_strategy=self.vllm_config.cache_config.mamba_cache_strategy, + cache_strategy=self.vllm_config.cache_config. 
+ mamba_cache_strategy, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( From f56fe5c82ccdf376967c95a05984a04f37a90817 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 17 Sep 2025 14:10:09 +0000 Subject: [PATCH 043/105] Precommit fixes. Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 3 ++- vllm/model_executor/models/config.py | 1 - vllm/model_executor/models/granitemoehybrid.py | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 796aebdd8c78..a70f6f5d3fbd 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -634,7 +634,7 @@ def forward_cuda( seq_lens_pending = ( torch.roll(attn_metadata.query_start_loc, -1, -1) - attn_metadata.query_start_loc)[:-1] - seq_lens_completed = (mamba2_metadata.seq_lens - seq_lens_pending) + seq_lens_completed = (attn_metadata.seq_lens - seq_lens_pending) last_computed_token_block_offset = \ seq_lens_completed % mamba_block_size diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 585ae72fe565..367c0173e374 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -117,7 +117,8 @@ def _selective_scan_update_kernel( state_ptr += (state_batch_idx * stride_state_batch + pid_h * stride_state_head) else: - dst_state_ptr = state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head + dst_state_ptr = state_ptr + pid_b * stride_state_batch + \ + pid_h * stride_state_head state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 586ffac91b0b..be76758bf2d4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -302,7 +302,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: return model_config = vllm_config.model_config - cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config # TODO(tdoublep): remove as full cuda graph support is added diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 4c7c8fa7c86d..9cd968fc0725 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -587,7 +587,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config From 15bb92104317873108532f2752faa26363cd949a Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 17 Sep 2025 14:40:18 +0000 Subject: [PATCH 044/105] Precommit fixes. 
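The builder reads Mamba-only fields such as cache_strategy, which the
generic AttentionSpec annotation does not declare, so the parameter is
annotated as MambaSpec here; keeping the base annotation and narrowing
with an isinstance check works as well (a later pre-commit fix in this
series takes that route). A simplified sketch of the narrowing pattern
(the toy classes stand in for the real spec hierarchy and are not the
actual vLLM definitions):

    from dataclasses import dataclass

    @dataclass
    class KVCacheSpec:
        block_size: int

    @dataclass
    class MambaSpec(KVCacheSpec):
        cache_strategy: str = "disabled"

    def builder_init(spec: KVCacheSpec) -> None:
        # Narrow before touching Mamba-only fields; type checkers accept
        # the attribute access after the isinstance assert.
        assert isinstance(spec, MambaSpec)
        print(spec.block_size, spec.cache_strategy)

    builder_init(MambaSpec(block_size=512, cache_strategy="all"))
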
Signed-off-by: Stanislaw Wozniak --- vllm/v1/attention/backends/mamba2_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 3a57bafc8fc0..9d2fd1a58876 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -143,7 +143,7 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: MambaSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() From 58906779ed5af23ae4ac7a9d30321ae95a163a66 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 17 Sep 2025 11:57:07 -0400 Subject: [PATCH 045/105] Fix CUDA graph issue Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 44 +++------ vllm/v1/attention/backends/mamba2_attn.py | 90 ++++++++++++++++++- 2 files changed, 102 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index a70f6f5d3fbd..dab9019a38c4 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -39,7 +39,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -630,45 +630,25 @@ def forward_cuda( if has_prefill else None) if envs.VLLM_USE_V1 and cache_enabled: - # Additional variables used by caching logic: - seq_lens_pending = ( - torch.roll(attn_metadata.query_start_loc, -1, -1) - - attn_metadata.query_start_loc)[:-1] - seq_lens_completed = (attn_metadata.seq_lens - seq_lens_pending) - last_computed_token_block_offset = \ - seq_lens_completed % mamba_block_size - - # Indices: last_computed <= current_first <= current_last - # Cases: - # last_computed == current_first if last state was partially - # computed and needs to be updated - # current_first == current_last if no block crossing occurs, and - # only one state will be stored - # 0th based indexing leads to "-1" -> e.g. 
16 computed -> state[15]: - current_last_token_block_idx = cdiv( - seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 - current_first_token_block_idx = cdiv(seq_lens_completed + 1, - mamba_block_size) - 1 - last_computed_token_block_idx = cdiv(seq_lens_completed, - mamba_block_size) - 1 - # -1 in case it's non-computed and causes later issues with indexing - last_computed_token_block_idx = \ - last_computed_token_block_idx.clamp(min=0) - # Split decodes and prefills: seq_lens_completed_d, seq_lens_completed_p = torch.split( - seq_lens_completed, [num_decodes, num_prefills], dim=0) + attn_metadata.seq_lens_completed, [num_decodes, num_prefills], + dim=0) last_state_idx_d, last_state_idx_p = torch.split( - last_computed_token_block_idx, [num_decodes, num_prefills], + attn_metadata.last_computed_token_block_idx, + [num_decodes, num_prefills], dim=0) last_computed_offset_d, last_computed_offset_p = torch.split( - last_computed_token_block_offset, [num_decodes, num_prefills], + attn_metadata.last_computed_token_block_offset, + [num_decodes, num_prefills], dim=0) current_first_idx_d, current_first_idx_p = torch.split( - current_first_token_block_idx, [num_decodes, num_prefills], + attn_metadata.current_first_token_block_idx, + [num_decodes, num_prefills], dim=0) current_last_idx_d, current_last_idx_p = torch.split( - current_last_token_block_idx, [num_decodes, num_prefills], + attn_metadata.current_last_token_block_idx, + [num_decodes, num_prefills], dim=0) # Preallocate output tensor to avoid memcpy cost for merging prefill @@ -848,6 +828,8 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2) # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3) chunk_stride = mamba_block_size // chunk_size + last_computed_token_block_offset = \ + attn_metadata.last_computed_token_block_offset first_aligned_chunk = chunk_stride - 1 \ - last_computed_token_block_offset[seq_idx] // chunk_size from_where = states[ diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 9d2fd1a58876..307c22a15bdb 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -14,7 +14,7 @@ BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import MambaSpec def _query_start_loc_to_chunk_indices_offsets( @@ -131,6 +131,11 @@ class Mamba2AttentionMetadata: last_chunk_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] + current_last_token_block_idx: torch.Tensor + current_first_token_block_idx: torch.Tensor + last_computed_token_block_idx: torch.Tensor + seq_lens_completed: torch.Tensor + last_computed_token_block_offset: torch.Tensor # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None @@ -157,6 +162,31 @@ def __init__(self, kv_cache_spec: MambaSpec, layer_names: list[str], dtype=torch.int32, device=device, ) + self.current_last_token_block_idx = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.current_first_token_block_idx = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.last_computed_token_block_idx = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + 
self.seq_lens_completed = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.last_computed_token_block_offset = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) def build(self, common_prefix_len: int, @@ -188,6 +218,30 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) + mamba_block_size = self.kv_cache_spec.block_size + seq_lens_pending = ( + torch.roll(common_attn_metadata.query_start_loc, -1, -1) - + common_attn_metadata.query_start_loc)[:-1] + seq_lens_completed = (common_attn_metadata.seq_lens - seq_lens_pending) + last_computed_token_block_offset = \ + seq_lens_completed % mamba_block_size + # Indices: last_computed <= current_first <= current_last + # Cases: + # last_computed == current_first if last state was partially + # computed and needs to be updated + # current_first == current_last if no block crossing occurs, and + # only one state will be stored + # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: + current_last_token_block_idx = cdiv( + seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + current_first_token_block_idx = cdiv(seq_lens_completed + 1, + mamba_block_size) - 1 + last_computed_token_block_idx = cdiv(seq_lens_completed, + mamba_block_size) - 1 + # -1 in case it's non-computed and causes later issues with indexing + last_computed_token_block_idx = \ + last_computed_token_block_idx.clamp(min=0) + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -272,6 +326,35 @@ def build(self, state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + self.current_last_token_block_idx[:num_decodes].copy_( + current_last_token_block_idx, non_blocking=True) + current_last_token_block_idx = \ + self.current_last_token_block_idx[:num_input_tokens] + current_last_token_block_idx[num_decodes:] = 0 + + self.current_first_token_block_idx[:num_decodes].copy_( + current_first_token_block_idx, non_blocking=True) + current_first_token_block_idx = \ + self.current_first_token_block_idx[:num_input_tokens] + current_first_token_block_idx[num_decodes:] = 0 + + self.last_computed_token_block_idx[:num_decodes].copy_( + last_computed_token_block_idx, non_blocking=True) + last_computed_token_block_idx = \ + self.last_computed_token_block_idx[:num_input_tokens] + last_computed_token_block_idx[num_decodes:] = 0 + + self.seq_lens_completed[:num_decodes].copy_(seq_lens_completed, + non_blocking=True) + seq_lens_completed = self.seq_lens_completed[:num_input_tokens] + seq_lens_completed[num_decodes:] = 0 + + self.last_computed_token_block_offset[:num_decodes].copy_( + last_computed_token_block_offset, non_blocking=True) + last_computed_token_block_offset = \ + self.last_computed_token_block_offset[:num_input_tokens] + last_computed_token_block_offset[num_decodes:] = 0 + attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, @@ -289,5 +372,10 @@ def build(self, state_indices_tensor=state_indices_tensor, cu_chunk_seqlen_p=cu_chunk_seqlen_p, last_chunk_p=last_chunk_p, + current_last_token_block_idx=current_last_token_block_idx, + current_first_token_block_idx=current_first_token_block_idx, + last_computed_token_block_idx=last_computed_token_block_idx, + seq_lens_completed=seq_lens_completed, + last_computed_token_block_offset=last_computed_token_block_offset, ) return attn_metadata From 
df23ee4454e23b731eff431cb3aa887c946a4f78 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 17 Sep 2025 12:07:08 -0400 Subject: [PATCH 046/105] pre-commit Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/mamba2_attn.py | 7 ++++--- vllm/v1/core/single_type_kv_cache_manager.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 307c22a15bdb..2404a1cdda3e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -14,7 +14,7 @@ BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec def _query_start_loc_to_chunk_indices_offsets( @@ -148,12 +148,13 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, kv_cache_spec: MambaSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") + assert isinstance(kv_cache_spec, MambaSpec) if kv_cache_spec.cache_strategy == "all": self.state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, @@ -204,7 +205,7 @@ def build(self, prep_initial_states = False cu_chunk_seqlen_p = None last_chunk_p = None - + assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.cache_strategy == "disabled": # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index a13b956d72f1..e7ad70245292 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -573,6 +573,7 @@ def remove_skipped_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: + assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.cache_strategy == "disabled": return 0 From 4ef40233835687c880016d3fd02a889825435675 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Thu, 18 Sep 2025 06:25:28 +0000 Subject: [PATCH 047/105] Tests Signed-off-by: Thomas Ortner --- tests/v1/core/test_APC_hybrid_models.py | 388 ++++++++++++++++++++++++ 1 file changed, 388 insertions(+) create mode 100644 tests/v1/core/test_APC_hybrid_models.py diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py new file mode 100644 index 000000000000..56b6946ac7f8 --- /dev/null +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests automated prefix caching (APC). APC can be enabled by +enable_prefix_caching=True. + +Run `pytest tests/basic_correctness/test_APC_hybrid_models.py`. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import time +import torch + +from vllm.platforms import current_platform +from vllm.utils import STR_BACKEND_ENV_VAR + +from tests.models.utils import check_logprobs_close, check_outputs_equal + +if TYPE_CHECKING: + from ...conftest import HfRunner, VllmRunner + +MODELS = [ + "ibm-granite/granite-4.0-tiny-preview", +] + +def _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size): + return {'model_name': model, + 'mamba_ssm_cache_dtype': mamba_ssm_cache_dtype, + 'enable_prefix_caching': False, + 'enforce_eager': enforce_eager, + 'max_model_len': max_model_len, + 'dtype': dtype, + 'tensor_parallel_size': tensor_parallel_size, + 'disable_cascade_attn': True, ## not verified yet + 'disable_log_stats': False, ## collect APC stats + 'gpu_memory_utilization': 0.4} + +def _get_vLLM_output_logprobs(vllm_runner, kwargs, prompts, max_tokens, num_logprobs, num_repetitions=1): + outs = [] + with vllm_runner( + **kwargs + ) as vllm_model: + for _ in range(num_repetitions): + outs.append(vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs)) + + return outs + +def _get_vLLM_output(vllm_runner, kwargs, prompts, max_tokens, num_repetitions=1): + outs = [] + with vllm_runner( + **kwargs + ) as vllm_model: + for _ in range(num_repetitions): + outs.append(vllm_model.generate_greedy(prompts, max_tokens)) + + return outs + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_single_prompt( + hf_runner: HfRunner, + vllm_runner: VllmRunner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + MULTIPLE = 120 + + # Common prefix. + prefix = MULTIPLE * example_prompts[0] + + # Sample prompts. 
+ generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + + with monkeypatch.context() as m: + # Ensure that the testcase is using V1 + m.setenv("VLLM_USE_V1", "1") + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs)[0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_single_prompt_mamba_size_alignment( + hf_runner: HfRunner, + vllm_runner: VllmRunner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + MULTIPLE = 120 + + # Common prefix. + prefix = MULTIPLE * example_prompts[0] + + # Sample prompts. 
+ generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + + with monkeypatch.context() as m: + # Ensure that the testcase is using V1 + m.setenv("VLLM_USE_V1", "1") + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs)[0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner( + **vllm_runner_kwargs + ) as vllm_model: + # Retrieve mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size + + vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) + + for multiple in [1, 3, 7]: + for offsets in [3, + mamba_block_size//4 - 3, + mamba_block_size//4, + mamba_block_size//4 + 3, + mamba_block_size//2 - 3, + mamba_block_size//2, + mamba_block_size//2 + 3, + mamba_block_size - 3 + ]: + + vllm_runner_kwargs['max_num_batched_tokens'] = multiple * mamba_block_size - offsets + vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) + + # Check whether the output logits of the model is the same using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multiple_prompts_all_cached_output_logprobs( + hf_runner: HfRunner, + vllm_runner: VllmRunner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + MULTIPLE = 120 + + # Common prefix. + prefix = MULTIPLE * example_prompts[0] + + # Sample prompts. 
+ generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + + with monkeypatch.context() as m: + # Ensure that the testcase is using V1 + m.setenv("VLLM_USE_V1", "1") + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multiple_prompts_partial_cached_output_logprobs( + hf_runner: HfRunner, + vllm_runner: VllmRunner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + MULTIPLE = 120 + + # Common prefix. + prefix = MULTIPLE * example_prompts[0] + + # Sample prompts. 
+ generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + + with monkeypatch.context() as m: + # Ensure that the testcase is using V1 + m.setenv("VLLM_USE_V1", "1") + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_partial_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs)[0] + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache[:3], + outputs_1_lst=vllm_outputs_partial_cache, + name_0="vllm_no_cache", + name_1=f"vllm_partial_cache", + ) + + vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. 
+@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_specific_prompts_output_logprobs( + hf_runner: HfRunner, + vllm_runner: VllmRunner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + + generated_prompts = [ + "Hello, my name is John Smith and I work at " * 100, + "The president of the United States is " * 200, + "The capital of France is something like" * 200, + "The future of AI is " * 300, + ] + + with monkeypatch.context() as m: + # Ensure that the testcase is using V1 + m.setenv("VLLM_USE_V1", "1") + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) + vllm_outputs_logprobs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] + # vllm_outputs_no_cache = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens)[0] + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_logprobs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) + # vllm_outputs_cache_rep = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) + + # for r_idx, (vllm_outputs_logprobs_cache_itn, vllm_outputs_cache_itn) in enumerate(zip(vllm_outputs_logprobs_cache_rep, vllm_outputs_cache_rep)): + for r_idx, vllm_outputs_logprobs_cache_itn in enumerate(vllm_outputs_logprobs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_logprobs_no_cache, + outputs_1_lst=vllm_outputs_logprobs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + # check_outputs_equal( + # outputs_0_lst=vllm_outputs_no_cache, + # outputs_1_lst=vllm_outputs_cache_itn, + # name_0="vllm_no_cache", + # name_1=f"vllm_cache_it_{r_idx + 1}", + # ) \ No newline at end of file From ff493431bfc3bf811978dc717421dc08919b13ad Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 18 Sep 2025 06:38:29 +0000 Subject: [PATCH 048/105] Precommit fixes. 
Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_APC_hybrid_models.py | 297 +++++++++++++++--------- 1 file changed, 182 insertions(+), 115 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index 56b6946ac7f8..03f89a6262cf 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -13,13 +13,8 @@ from typing import TYPE_CHECKING import pytest -import time -import torch -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from tests.models.utils import check_logprobs_close, check_outputs_equal +from tests.models.utils import check_logprobs_close if TYPE_CHECKING: from ...conftest import HfRunner, VllmRunner @@ -28,38 +23,52 @@ "ibm-granite/granite-4.0-tiny-preview", ] -def _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size): - return {'model_name': model, - 'mamba_ssm_cache_dtype': mamba_ssm_cache_dtype, - 'enable_prefix_caching': False, - 'enforce_eager': enforce_eager, - 'max_model_len': max_model_len, - 'dtype': dtype, - 'tensor_parallel_size': tensor_parallel_size, - 'disable_cascade_attn': True, ## not verified yet - 'disable_log_stats': False, ## collect APC stats - 'gpu_memory_utilization': 0.4} - -def _get_vLLM_output_logprobs(vllm_runner, kwargs, prompts, max_tokens, num_logprobs, num_repetitions=1): + +def _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, + max_model_len, dtype, tensor_parallel_size): + return { + 'model_name': model, + 'mamba_ssm_cache_dtype': mamba_ssm_cache_dtype, + 'enable_prefix_caching': False, + 'enforce_eager': enforce_eager, + 'max_model_len': max_model_len, + 'dtype': dtype, + 'tensor_parallel_size': tensor_parallel_size, + 'disable_cascade_attn': True, ## not verified yet + 'disable_log_stats': False, ## collect APC stats + 'gpu_memory_utilization': 0.4 + } + + +def _get_vLLM_output_logprobs(vllm_runner, + kwargs, + prompts, + max_tokens, + num_logprobs, + num_repetitions=1): outs = [] - with vllm_runner( - **kwargs - ) as vllm_model: + with vllm_runner(**kwargs) as vllm_model: for _ in range(num_repetitions): - outs.append(vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs)) - + outs.append( + vllm_model.generate_greedy_logprobs(prompts, max_tokens, + num_logprobs)) + return outs -def _get_vLLM_output(vllm_runner, kwargs, prompts, max_tokens, num_repetitions=1): + +def _get_vLLM_output(vllm_runner, + kwargs, + prompts, + max_tokens, + num_repetitions=1): outs = [] - with vllm_runner( - **kwargs - ) as vllm_model: + with vllm_runner(**kwargs) as vllm_model: for _ in range(num_repetitions): outs.append(vllm_model.generate_greedy(prompts, max_tokens)) - + return outs + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -88,28 +97,37 @@ def test_single_prompt( Checks exact match decode vllm runner with and without prefix caching """ MULTIPLE = 120 - + # Common prefix. prefix = MULTIPLE * example_prompts[0] # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - + with monkeypatch.context() as m: # Ensure that the testcase is using V1 m.setenv("VLLM_USE_V1", "1") - - max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs)[0] - + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs)[0] + vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) - + vllm_outputs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs, n_repetitions) + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - + check_logprobs_close( outputs_0_lst=vllm_outputs_no_cache, outputs_1_lst=vllm_outputs_cache_itn, @@ -117,6 +135,7 @@ def test_single_prompt( name_1=f"vllm_cache_it_{r_idx + 1}", ) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -145,49 +164,59 @@ def test_single_prompt_mamba_size_alignment( Checks exact match decode vllm runner with and without prefix caching """ MULTIPLE = 120 - + # Common prefix. prefix = MULTIPLE * example_prompts[0] # Sample prompts. generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - + with monkeypatch.context() as m: # Ensure that the testcase is using V1 m.setenv("VLLM_USE_V1", "1") - - max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs)[0] - + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs)[0] + vllm_runner_kwargs['enable_prefix_caching'] = True - with vllm_runner( - **vllm_runner_kwargs - ) as vllm_model: + with vllm_runner(**vllm_runner_kwargs) as vllm_model: # Retrieve mamba state block size - mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size - - vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) - + mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ + mamba_block_size + + vllm_outputs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs, n_repetitions) + for multiple in [1, 3, 7]: - for offsets in [3, - mamba_block_size//4 - 3, - mamba_block_size//4, - mamba_block_size//4 + 3, - mamba_block_size//2 - 3, - mamba_block_size//2, - mamba_block_size//2 + 3, - mamba_block_size - 3 - ]: - - vllm_runner_kwargs['max_num_batched_tokens'] = multiple * mamba_block_size - offsets - vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, num_logprobs, n_repetitions) - - # Check whether the output logits of the model is the same using APC - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + for offsets in [ + 3, mamba_block_size // 4 - 3, mamba_block_size // 4, + mamba_block_size // 4 + 3, mamba_block_size // 2 - 3, + mamba_block_size // 2, mamba_block_size // 2 + 3, + mamba_block_size - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = multiple * mamba_block_size - \ + offsets + vllm_outputs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs, n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate( + vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - + check_logprobs_close( outputs_0_lst=vllm_outputs_no_cache, outputs_1_lst=vllm_outputs_cache_itn, @@ -195,6 +224,7 @@ def test_single_prompt_mamba_size_alignment( name_1=f"vllm_cache_it_{r_idx + 1}", ) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -223,35 +253,45 @@ def test_multiple_prompts_all_cached_output_logprobs( Checks exact match decode vllm runner with and without prefix caching """ MULTIPLE = 120 - + # Common prefix. prefix = MULTIPLE * example_prompts[0] # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - + with monkeypatch.context() as m: # Ensure that the testcase is using V1 m.setenv("VLLM_USE_V1", "1") - - max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] - + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0] + vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) - + vllm_outputs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - + check_logprobs_close( outputs_0_lst=vllm_outputs_no_cache, outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -280,45 +320,57 @@ def test_multiple_prompts_partial_cached_output_logprobs( Checks exact match decode vllm runner with and without prefix caching """ MULTIPLE = 120 - + # Common prefix. prefix = MULTIPLE * example_prompts[0] # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - + with monkeypatch.context() as m: # Ensure that the testcase is using V1 m.setenv("VLLM_USE_V1", "1") - - max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] - + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0] + # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_partial_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs)[0] - + vllm_outputs_partial_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, + num_logprobs)[0] + check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache[:3], - outputs_1_lst=vllm_outputs_partial_cache, - name_0="vllm_no_cache", - name_1=f"vllm_partial_cache", - ) - - vllm_outputs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) - + outputs_0_lst=vllm_outputs_no_cache[:3], + outputs_1_lst=vllm_outputs_partial_cache, + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - + check_logprobs_close( outputs_0_lst=vllm_outputs_no_cache, outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -346,43 +398,58 @@ def test_specific_prompts_output_logprobs( """ Checks exact match decode vllm runner with and without prefix caching """ - + generated_prompts = [ "Hello, my name is John Smith and I work at " * 100, "The president of the United States is " * 200, "The capital of France is something like" * 200, "The future of AI is " * 300, ] - + with monkeypatch.context() as m: # Ensure that the testcase is using V1 m.setenv("VLLM_USE_V1", "1") - - max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_logprobs_no_cache = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0] - # vllm_outputs_no_cache = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens)[0] - + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + 
vllm_outputs_logprobs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0] + # vllm_outputs_no_cache = _get_vLLM_output( + # vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens)[0] + # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_logprobs_cache_rep = _get_vLLM_output_logprobs(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) - # vllm_outputs_cache_rep = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) - - # for r_idx, (vllm_outputs_logprobs_cache_itn, vllm_outputs_cache_itn) in enumerate(zip(vllm_outputs_logprobs_cache_rep, vllm_outputs_cache_rep)): - for r_idx, vllm_outputs_logprobs_cache_itn in enumerate(vllm_outputs_logprobs_cache_rep): + vllm_outputs_logprobs_cache_rep = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + # vllm_outputs_cache_rep = _get_vLLM_output( + # vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + # n_repetitions) + + # for r_idx, (vllm_outputs_logprobs_cache_itn, vllm_outputs_cache_itn) + # in enumerate(zip( + # vllm_outputs_logprobs_cache_rep, vllm_outputs_cache_rep)): + for r_idx, vllm_outputs_logprobs_cache_itn in enumerate( + vllm_outputs_logprobs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - + check_logprobs_close( outputs_0_lst=vllm_outputs_logprobs_no_cache, outputs_1_lst=vllm_outputs_logprobs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - + # check_outputs_equal( # outputs_0_lst=vllm_outputs_no_cache, # outputs_1_lst=vllm_outputs_cache_itn, # name_0="vllm_no_cache", # name_1=f"vllm_cache_it_{r_idx + 1}", - # ) \ No newline at end of file + # ) From c4255bbd6c4f60f552b88a3435b6ed3d2426573c Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 18 Sep 2025 06:50:56 +0000 Subject: [PATCH 049/105] Precommit fixes. Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_APC_hybrid_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index 03f89a6262cf..f5cf6ee95278 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -148,7 +148,7 @@ def test_single_prompt( @pytest.mark.parametrize("num_logprobs", [5]) def test_single_prompt_mamba_size_alignment( hf_runner: HfRunner, - vllm_runner: VllmRunner, + vllm_runner, example_prompts, model: str, dtype: str, From 7649489b4f7f19360f23859b23fa004f24461a71 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 18 Sep 2025 05:05:29 -0400 Subject: [PATCH 050/105] Fix bug in scan kernel when to reading previous state. 
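
The chunk scan kernel chooses between reading a stored previous state and
starting from zeros. After this change, zeros are used only when a new
sequence starts and no initial states were provided; in every other case the
stored state is loaded (the initial state for a fresh sequence, or the state
carried over from the previous chunk). A rough sketch of the resulting
selection logic, in plain Python rather than the actual Triton code (names
mirror the kernel; the two callables stand in for tl.load / tl.zeros):

    def previous_states(has_initstates, seq_idx, seq_idx_prev,
                        load_prev_states, zeros):
        if not has_initstates and seq_idx != seq_idx_prev:
            # New sequence with no initial state provided: start from zeros.
            return zeros()
        # Otherwise read the stored state.
        return load_prev_states()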
Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 207c440b0ff6..186f771a1018 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -215,7 +215,7 @@ def _chunk_scan_fwd_kernel( mask=pid_c >= 1, other=-1) - if HAS_INITSTATES: + if HAS_INITSTATES and (seq_idx != seq_idx_prev): prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head prev_states_hdim = stride_init_states_hdim prev_states_dstate = stride_init_states_dstate @@ -250,7 +250,12 @@ def _chunk_scan_fwd_kernel( (offs_k_dstate[None, :] < dstate), other=0.0) - if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), + dtype=C_ptr.dtype.element_ty) + else: + # otherwise read the previous state prev_states_ptrs = prev_states_ptr \ + offs_n[None, :] * prev_states_hdim \ + offs_k_dstate[:, None] * prev_states_dstate @@ -259,9 +264,6 @@ def _chunk_scan_fwd_kernel( (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), - dtype=C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] @@ -274,16 +276,16 @@ def _chunk_scan_fwd_kernel( mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) - if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), + dtype=C_ptr.dtype.element_ty) + else: prev_states = tl.load( prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), - dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K prev_states_ptrs += BLOCK_SIZE_K From 1bb59d7e6dbd5e0413f1cf832526bfd558ab3171 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 18 Sep 2025 14:49:25 -0400 Subject: [PATCH 051/105] Remove BLOCK_H=1 from list of tuneable configurations. 
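
This drops the triton.Config({'BLOCK_SIZE_H': 1}) entry (the BLOCK_H=1
configuration referred to in the subject) from the autotune search space of
the chunk-state kernel in ssd_chunk_state.py, leaving BLOCK_SIZE_H=2 as the
smallest candidate; the remaining configurations are unchanged.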
Co-authored-by: Chih-Chieh-Yang Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 448c7970b64b..4ad3b348658d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,7 +15,6 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), triton.Config({'BLOCK_SIZE_H': 2}), triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), From 25f8a2718e9ff8cb6f81bdd74fca0dc07fe6ecbd Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 19 Sep 2025 09:02:00 +0000 Subject: [PATCH 052/105] Reworked testcase test_multiple_prompts_partial_cached_output_logprobs Signed-off-by: Thomas Ortner --- tests/v1/core/test_APC_hybrid_models.py | 384 +++++++++++------------- 1 file changed, 167 insertions(+), 217 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index f5cf6ee95278..f927e69d5487 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -45,28 +45,17 @@ def _get_vLLM_output_logprobs(vllm_runner, prompts, max_tokens, num_logprobs, - num_repetitions=1): + num_repetitions=1, + vllm_model=None): outs = [] - with vllm_runner(**kwargs) as vllm_model: - for _ in range(num_repetitions): - outs.append( - vllm_model.generate_greedy_logprobs(prompts, max_tokens, - num_logprobs)) + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + outs.append( + vllm_model.generate_greedy_logprobs(prompts, max_tokens, + num_logprobs)) - return outs - - -def _get_vLLM_output(vllm_runner, - kwargs, - prompts, - max_tokens, - num_repetitions=1): - outs = [] - with vllm_runner(**kwargs) as vllm_model: - for _ in range(num_repetitions): - outs.append(vllm_model.generate_greedy(prompts, max_tokens)) - - return outs + return outs, vllm_model @pytest.mark.parametrize("model", MODELS) @@ -104,36 +93,32 @@ def test_single_prompt( # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - with monkeypatch.context() as m: - # Ensure that the testcase is using V1 - m.setenv("VLLM_USE_V1", "1") - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs)[0] - - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs, n_repetitions) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs)[0][0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) @pytest.mark.parametrize("model", MODELS) @@ -171,58 +156,50 @@ def test_single_prompt_mamba_size_alignment( # Sample prompts. generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - with monkeypatch.context() as m: - # Ensure that the testcase is using V1 - m.setenv("VLLM_USE_V1", "1") - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs)[0] - - vllm_runner_kwargs['enable_prefix_caching'] = True - with vllm_runner(**vllm_runner_kwargs) as vllm_model: - # Retrieve mamba state block size - mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ - mamba_block_size - - vllm_outputs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs, n_repetitions) - - for multiple in [1, 3, 7]: - for offsets in [ - 3, mamba_block_size // 4 - 3, mamba_block_size // 4, - mamba_block_size // 4 + 3, mamba_block_size // 2 - 3, - mamba_block_size // 2, mamba_block_size // 2 + 3, - mamba_block_size - 3 - ]: - - vllm_runner_kwargs[ - 'max_num_batched_tokens'] = multiple * mamba_block_size - \ - offsets - vllm_outputs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs, n_repetitions) - - # Check alignment of the output logits when using APC - for r_idx, vllm_outputs_cache_itn in enumerate( - vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs)[0][0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ + mamba_block_size + + for multiple in [1, 3, 7]: + for offsets in [ + 3, mamba_block_size // 4 - 3, mamba_block_size // 4, + mamba_block_size // 4 + 3, mamba_block_size // 2 - 3, + mamba_block_size // 2, mamba_block_size // 2 + 3, + mamba_block_size - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = multiple * mamba_block_size - \ + offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], + max_tokens, num_logprobs, n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate( + vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) @pytest.mark.parametrize("model", MODELS) @@ -260,36 +237,32 @@ def test_multiple_prompts_all_cached_output_logprobs( # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - with monkeypatch.context() as m: - # Ensure that the testcase is using V1 - m.setenv("VLLM_USE_V1", "1") - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0] - - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0][0] + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) @pytest.mark.parametrize("model", MODELS) @@ -327,49 +300,45 @@ def test_multiple_prompts_partial_cached_output_logprobs( # Sample prompts. 
generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] - with monkeypatch.context() as m: - # Ensure that the testcase is using V1 - m.setenv("VLLM_USE_V1", "1") - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0] - - # Cache only part of all the prompts - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_partial_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, - num_logprobs)[0] + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0][0] + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_partial_cache, vllm_model = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, + num_logprobs) + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache[:3], + outputs_1_lst=vllm_outputs_partial_cache[0], + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions, vllm_model=vllm_model) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache[:3], - outputs_1_lst=vllm_outputs_partial_cache, + outputs_0_lst=vllm_outputs_no_cache, + outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", - name_1="vllm_partial_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", ) - vllm_outputs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @@ -406,50 +375,31 @@ def test_specific_prompts_output_logprobs( "The future of AI is " * 300, ] - with monkeypatch.context() as m: - # Ensure that the testcase is using V1 - m.setenv("VLLM_USE_V1", "1") - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_logprobs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0] - # vllm_outputs_no_cache = _get_vLLM_output( - # vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens)[0] - - # Cache only part of all the 
prompts - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_logprobs_cache_rep = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) - # vllm_outputs_cache_rep = _get_vLLM_output( - # vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - # n_repetitions) - - # for r_idx, (vllm_outputs_logprobs_cache_itn, vllm_outputs_cache_itn) - # in enumerate(zip( - # vllm_outputs_logprobs_cache_rep, vllm_outputs_cache_rep)): - for r_idx, vllm_outputs_logprobs_cache_itn in enumerate( - vllm_outputs_logprobs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_logprobs_no_cache, - outputs_1_lst=vllm_outputs_logprobs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - - # check_outputs_equal( - # outputs_0_lst=vllm_outputs_no_cache, - # outputs_1_lst=vllm_outputs_cache_itn, - # name_0="vllm_no_cache", - # name_1=f"vllm_cache_it_{r_idx + 1}", - # ) + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, + mamba_ssm_cache_dtype, + enforce_eager, + max_model_len, dtype, + tensor_parallel_size) + vllm_outputs_logprobs_no_cache, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs)[0][0] + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_logprobs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_logprobs_cache_itn in enumerate( + vllm_outputs_logprobs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_logprobs_no_cache, + outputs_1_lst=vllm_outputs_logprobs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) \ No newline at end of file From 8515ee2088d332600f0ca519771cfb07ea834524 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 19 Sep 2025 10:53:33 +0000 Subject: [PATCH 053/105] Fixed indexing for SSM state storing when bs>1 Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index dab9019a38c4..7495f22e6b85 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -830,8 +830,11 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, chunk_stride = mamba_block_size // chunk_size last_computed_token_block_offset = \ attn_metadata.last_computed_token_block_offset - first_aligned_chunk = chunk_stride - 1 \ - - last_computed_token_block_offset[seq_idx] // chunk_size + first_aligned_chunk = \ + torch.concat([torch.zeros(1, dtype=last_chunk_p.dtype, \ + device=last_chunk_p.device), last_chunk_p])[seq_idx] + 1 \ + + chunk_stride - 1 \ + - last_computed_token_block_offset[seq_idx] // chunk_size from_where = states[ 0, first_aligned_chunk:first_aligned_chunk + n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride] From ebba273aa0e105cde0e58717b06e6b7d73c06edb Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 19 Sep 2025 11:08:23 
+0000 Subject: [PATCH 054/105] Fixed indexing for SSM state storing when bs>1 Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 7495f22e6b85..17ac87d4fdc7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -832,7 +832,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, attn_metadata.last_computed_token_block_offset first_aligned_chunk = \ torch.concat([torch.zeros(1, dtype=last_chunk_p.dtype, \ - device=last_chunk_p.device), last_chunk_p])[seq_idx] + 1 \ + device=last_chunk_p.device), last_chunk_p + 1])[seq_idx] \ + chunk_stride - 1 \ - last_computed_token_block_offset[seq_idx] // chunk_size from_where = states[ From 48faed8b800fb5a7d4bd729870eb3f9b4514d56d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 19 Sep 2025 07:38:45 -0400 Subject: [PATCH 055/105] Fix test_specific_prompts_output_logprobs Signed-off-by: Thomas Parnell --- tests/v1/core/test_APC_hybrid_models.py | 71 ++++++++++++------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index f927e69d5487..5787d0dada50 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -95,19 +95,19 @@ def test_single_prompt( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs)[0][0] + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) + vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, + vllm_runner_kwargs, + [generated_prompts[0]], + max_tokens, + num_logprobs)[0][0] vllm_runner_kwargs['enable_prefix_caching'] = True vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs, n_repetitions) + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, + num_logprobs, n_repetitions) for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled @@ -158,14 +158,12 @@ def test_single_prompt_mamba_size_alignment( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs)[0][0] + vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, + num_logprobs)[0][0] vllm_runner_kwargs['enable_prefix_caching'] = True with vllm_runner(**vllm_runner_kwargs) as vllm_model: @@ -239,11 +237,9 @@ def test_multiple_prompts_all_cached_output_logprobs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - 
vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0][0] @@ -302,11 +298,9 @@ def test_multiple_prompts_partial_cached_output_logprobs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs)[0][0] @@ -325,8 +319,13 @@ def test_multiple_prompts_partial_cached_output_logprobs( ) vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions, vllm_model=vllm_model) + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + vllm_model=vllm_model) for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled @@ -377,14 +376,12 @@ def test_specific_prompts_output_logprobs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, - mamba_ssm_cache_dtype, - enforce_eager, - max_model_len, dtype, - tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) vllm_outputs_logprobs_no_cache, _ = _get_vLLM_output_logprobs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0][0] + num_logprobs) # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True @@ -398,8 +395,8 @@ def test_specific_prompts_output_logprobs( # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_logprobs_no_cache, + outputs_0_lst=vllm_outputs_logprobs_no_cache[0], outputs_1_lst=vllm_outputs_logprobs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", - ) \ No newline at end of file + ) From 0ce539ee3e0c97840335776b080c60c0564c2f92 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 19 Sep 2025 16:28:13 -0400 Subject: [PATCH 056/105] Fix tests Signed-off-by: Thomas Parnell --- tests/v1/core/test_APC_hybrid_models.py | 52 +++++++++---------------- 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index 5787d0dada50..4ca76d877536 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -87,26 +87,21 @@ def test_single_prompt( """ MULTIPLE = 120 - # Common prefix. - prefix = MULTIPLE * example_prompts[0] - # Sample prompts. 
- generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + generated_prompts = [MULTIPLE * example_prompts[0]] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache = _get_vLLM_output_logprobs(vllm_runner, - vllm_runner_kwargs, - [generated_prompts[0]], - max_tokens, - num_logprobs)[0][0] + vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs) vllm_runner_kwargs['enable_prefix_caching'] = True vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): @@ -114,7 +109,7 @@ def test_single_prompt( # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, + outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", @@ -150,11 +145,8 @@ def test_single_prompt_mamba_size_alignment( """ MULTIPLE = 120 - # Common prefix. - prefix = MULTIPLE * example_prompts[0] - # Sample prompts. - generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + generated_prompts = [MULTIPLE * example_prompts[0]] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) @@ -162,8 +154,8 @@ def test_single_prompt_mamba_size_alignment( enforce_eager, max_model_len, dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], max_tokens, - num_logprobs)[0][0] + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs) vllm_runner_kwargs['enable_prefix_caching'] = True with vllm_runner(**vllm_runner_kwargs) as vllm_model: @@ -183,8 +175,8 @@ def test_single_prompt_mamba_size_alignment( 'max_num_batched_tokens'] = multiple * mamba_block_size - \ offsets vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, [generated_prompts[0]], - max_tokens, num_logprobs, n_repetitions) + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) # Check alignment of the output logits when using APC for r_idx, vllm_outputs_cache_itn in enumerate( @@ -193,7 +185,7 @@ def test_single_prompt_mamba_size_alignment( # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, + outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", @@ -229,11 +221,8 @@ def test_multiple_prompts_all_cached_output_logprobs( """ MULTIPLE = 120 - # Common prefix. - prefix = MULTIPLE * example_prompts[0] - # Sample prompts. 
- generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) @@ -242,7 +231,7 @@ def test_multiple_prompts_all_cached_output_logprobs( dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0][0] + num_logprobs) vllm_runner_kwargs['enable_prefix_caching'] = True vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( @@ -254,7 +243,7 @@ def test_multiple_prompts_all_cached_output_logprobs( # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, + outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", @@ -290,11 +279,8 @@ def test_multiple_prompts_partial_cached_output_logprobs( """ MULTIPLE = 120 - # Common prefix. - prefix = MULTIPLE * example_prompts[0] - # Sample prompts. - generated_prompts = [prefix + prompt for prompt in example_prompts[1:]] + generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) @@ -303,7 +289,7 @@ def test_multiple_prompts_partial_cached_output_logprobs( dtype, tensor_parallel_size) vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs)[0][0] + num_logprobs) # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True @@ -312,7 +298,7 @@ def test_multiple_prompts_partial_cached_output_logprobs( num_logprobs) check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache[:3], + outputs_0_lst=vllm_outputs_no_cache[0][:3], outputs_1_lst=vllm_outputs_partial_cache[0], name_0="vllm_no_cache", name_1="vllm_partial_cache", @@ -332,7 +318,7 @@ def test_multiple_prompts_partial_cached_output_logprobs( # In the second repetition, these caches are reused check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache, + outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", From c06246b15508a573617c54fafce68f64a04aa960 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Mon, 22 Sep 2025 05:17:53 -0400 Subject: [PATCH 057/105] Fused causal_conv1d. 
Signed-off-by: Thomas Ortner Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 92 ++----- .../layers/mamba/ops/causal_conv1d.py | 234 ++++++++++++++++-- 2 files changed, 236 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 17ac87d4fdc7..256f184c0f39 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -685,21 +685,17 @@ def forward_cuda( mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) - kernel_conv1d_indices = state_indices_tensor_p + BLOCK_M = list(mamba2_metadata.nums_dict.keys())[0] if cache_enabled: - # Kernel expects to have the initial state here - # and overwrites it -> use final state location - if has_initial_states_p is not None \ - and has_initial_states_p.sum() > 0: - conv_state_idx_input = state_indices_tensor_p.gather( - 1, last_state_idx_p.unsqueeze(1)) - conv_state_idx_output = state_indices_tensor_p.gather( - 1, current_last_idx_p.unsqueeze(1)) - conv_state[conv_state_idx_output[ - has_initial_states_p]] = conv_state[ - conv_state_idx_input[has_initial_states_p]] - kernel_conv1d_indices = state_indices_tensor_p.gather( - 1, current_last_idx_p.unsqueeze(1)).squeeze(1) + n_blocks_to_fill = current_last_idx_p - current_first_idx_p + stride_state_indices = state_indices_tensor_p.shape[-1] + else: + current_first_idx_p = None + current_last_idx_p = None + seq_lens_completed_p = None + last_state_idx_p = None + n_blocks_to_fill = None + stride_state_indices = 1 hidden_states_B_C_p = causal_conv1d_fn( x, @@ -708,54 +704,18 @@ def forward_cuda( activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_states_p, - cache_indices=kernel_conv1d_indices, + cache_indices=state_indices_tensor_p, + n_blocks_to_fill=n_blocks_to_fill, + current_first_idx=current_first_idx_p, + current_last_idx=current_last_idx_p, + last_state_idx=last_state_idx_p, + seq_lens_completed=seq_lens_completed_p, + stride_cache_chunk=mamba_block_size // BLOCK_M, + stride_state_indices=stride_state_indices, metadata=mamba2_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] - if cache_enabled: - - def copy_to_conv_state(conv_state_block_idx, x, x_offset, - x_end, query_start_loc): - conv_state[conv_state_block_idx, :, 0] = torch.transpose( - x[:, query_start_loc + x_offset - - 3:x_end:mamba_block_size], 1, 0) - conv_state[conv_state_block_idx, :, 1] = torch.transpose( - x[:, query_start_loc + x_offset - - 2:x_end:mamba_block_size], 1, 0) - conv_state[conv_state_block_idx, :, 2] = torch.transpose( - x[:, query_start_loc + x_offset - - 1:x_end:mamba_block_size], 1, 0) - - if cache_strategy == "all": - n_blocks_to_fill = current_last_idx_p - current_first_idx_p - # Iterate over sequences that require state storing: - for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): - cache_blocks_to_fill = state_indices_tensor_p[ - seq_idx, current_first_idx_p[seq_idx]: - current_first_idx_p[seq_idx] + - n_blocks_to_fill[seq_idx]] - from_where = x[:, query_start_loc_p[seq_idx]: - query_start_loc_p[seq_idx + 1]] - # if last computation ended just before the end of block - if last_computed_offset_p[seq_idx] + 3 >= \ - mamba_block_size: - # the current x doesn't have the proper values - # We need to get them from the past state. - # Trick: The indices will go negative: - # e.g. 
x[:,-3], x[:,-2], x[:,-1] - # so pass x := concat(x, last_state) - # to enable reading from the back - # Note: Maybe always do this and remove "if"? - from_where = torch.concat([ - from_where, conv_state[cache_blocks_to_fill[0]] - ], 1) - copy_to_conv_state( - cache_blocks_to_fill, from_where, mamba_block_size, - mamba_block_size * n_blocks_to_fill[seq_idx], - mamba_block_size * current_first_idx_p[seq_idx] - - seq_lens_completed_p[seq_idx]) - hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -842,7 +802,7 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, #For all seqs, store the last state (Note: might be partial): ssm_state[state_indices_tensor_p.gather(1, - current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ + current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ states[0, last_chunk_p] else: varlen_state = mamba_outputs @@ -862,17 +822,12 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, #Note: # for decode always: current_first_idx_d == current_last_idx_d # at block boundaries: current_first_idx_d > last_state_idx_d - - # copy initial state to new location, - # as update kernel works in place - #if (current_last_idx_d > last_state_idx_d).any(): - # (skip IF as it breaks CUDA graphs) - conv_state[state_indices_tensor_d_output] = conv_state[ - state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + current_last_idx_d = None + last_state_idx_d = None # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( @@ -881,7 +836,10 @@ def copy_to_conv_state(conv_state_block_idx, x, x_offset, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d_output) + conv_state_indices=state_indices_tensor_d, + current_last_idx=current_last_idx_d, + last_state_idx=last_state_idx_d, + ) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( hidden_states_B_C_d) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 635fb2e9fd07..8604c1114542 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -20,11 +20,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr + cache_indices_ptr, # (dim, cu_seqlen) has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, + n_blocks_to_fill_ptr, # (dim,) + current_first_idx, # (dim,) + current_last_idx, # (dim,) + last_state_idx, # (dim,) + seq_lens_completed, # (dim,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions batch: tl.int32, # actually padded_batch @@ -44,6 +49,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_cache_chunk: tl.constexpr, # others pad_slot_id: tl.constexpr, # Meta-parameters @@ -53,6 +60,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching HAS_INITIAL_STATES: tl.constexpr, HAS_CACHE: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_CACHE_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -83,19 +91,61 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual 
sequence length seqlen = sequence_end_index - sequence_start_index + if IS_CACHE_ENABLED: + # Get the completed sequence length for far and compute the offset. + current_first_index = tl.load(current_first_idx + idx_seq) + current_last_index = tl.load(current_last_idx + idx_seq) + sequence_completed_index = tl.load(seq_lens_completed + idx_seq) + + # Compute the offset where the first stride_cache_chunk-aligned first full block is + sequence_offset_index = stride_cache_chunk * BLOCK_M * current_first_index - sequence_completed_index + + # Compute the last full cache block for the sequence + last_full_cache_index = sequence_end_index - ( + (seqlen - sequence_offset_index) % (stride_cache_chunk * BLOCK_M)) + # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one + if ((seqlen - sequence_offset_index) % + (stride_cache_chunk * BLOCK_M)) == 0: + last_full_cache_index = last_full_cache_index - stride_cache_chunk * BLOCK_M + + # Get the number of blocks for the current sequence + n_block_to_fill = tl.load(n_blocks_to_fill_ptr + idx_seq) + + if HAS_INITIAL_STATES: + # Get the state from the last_state_idx + conv_state_init = tl.load(last_state_idx + idx_seq) + else: + conv_state_init = 0 + else: + n_block_to_fill = 0 + current_last_index = 0 + conv_state_init = 0 + sequence_offset_index = 0 + last_full_cache_index = 0 + token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + if HAS_INITIAL_STATES: + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = conv_state_init else: - # cache_idx - conv_state_batch_coord = idx_seq + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence @@ -175,8 +225,34 @@ def _causal_conv1d_fwd_kernel( # continuous batching loaded_x = tl.load(x_ptrs, mask_x, 0.0) new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] + + if HAS_INITIAL_STATES: + # Get the state from the last_state_idx + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * + stride_state_indices + + current_last_index).to( + tl.int64) + else: + # cache_idx + conv_states_offset = current_last_index + else: + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq + + current_last_index).to( + tl.int64) + else: + # cache_idx + conv_states_offset = idx_seq + current_last_index + conv_states_ptrs_target = ( + conv_states_ptr + (conv_states_offset * stride_conv_state_seq) + + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok)[:, None] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] @@ -278,6 
+354,61 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + # Store intermediate states aligned with stride_cache_chunk + if chunk_offset > 0 and (chunk_offset - 1) < n_block_to_fill: + # Old approach. Store the states at chunk boundaries from the start of the sequence + # idx_tokens_last = (chunk_offset * stride_cache_chunk * BLOCK_M + sequence_offset_index - state_len) + tl.arange( + # 0, NP2_STATELEN) # [BLOCK_M] + # x_ptrs = x_ptr + ( + # (sequence_start_index + idx_tokens_last) * + # stride_x_token)[:, None] + ( + # idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + + # New approach, store the states at the chunk boundaries from the start of the sequence + idx_tokens_last = ( + last_full_cache_index - + (n_block_to_fill - chunk_offset) * stride_cache_chunk * BLOCK_M - + state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = x_ptr + ((idx_tokens_last) * stride_x_token)[:, None] + ( + idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + + mask_x = ( + (idx_tokens_last >= 0)[:, None] & + (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + if HAS_INITIAL_STATES: + # Get the state from the last_state_idx + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init + + (chunk_offset - 1)).to(tl.int64) + else: + # cache_idx + conv_states_offset = conv_state_init + (chunk_offset - 1) + else: + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq + + (chunk_offset - 1)).to(tl.int64) + else: + # cache_idx + conv_states_offset = idx_seq + (chunk_offset - 1) + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_offset * stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats + < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim @@ -366,6 +497,13 @@ def causal_conv1d_fn( has_initial_state: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", pad_slot_id: int = PAD_SLOT_ID, + n_blocks_to_fill: Optional[torch.Tensor] = None, + current_first_idx: Optional[torch.Tensor] = None, + current_last_idx: Optional[torch.Tensor] = None, + last_state_idx: Optional[torch.Tensor] = None, + seq_lens_completed: Optional[torch.Tensor] = None, + stride_state_indices: Optional[int] = 0, + stride_cache_chunk: Optional[int] = 0, metadata=None, validate_data=False, ): @@ -408,16 +546,27 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - + n_blocks_to_fill: (batch) int32 + The number of full cache blocks to be filled + current_first_idx: (batch) int32 + The first cache block to be filled. This tensor indexes into cache_indices + current_last_idx: (batch) int32 + The last cache block to be filled. 
This tensor indexes into cache_indices + last_state_idx: (batch) int32 + The cache block for the init values. This tensor indexes into cache_indices + seq_lens_completed: (batch) int32 + The number of tokens already completed for each sequence + stride_state_indices: int + The stride of conv_states. I.e., dim=1 of conv_states. + stride_cache_chunk: int + The token offsets ratio. I.e., BLOCK_M out: same shape as `x` """ if isinstance(activation, bool) and activation: activation = "silu" args = None - #out = torch.empty_like(x) - #TODO: Noticed strange behavior, maybe due to use of uninitialzed values? - out = torch.zeros_like(x) + out = torch.empty_like(x) if metadata is not None: cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict @@ -581,6 +730,11 @@ def grid(META): query_start_loc, batch_ptr, token_chunk_offset_ptr, + n_blocks_to_fill, + current_first_idx, + current_last_idx, + last_state_idx, + seq_lens_completed, out, # Matrix dimensions padded_batch, @@ -599,6 +753,8 @@ def grid(META): stride_o_seq, stride_o_dim, stride_o_token, + stride_state_indices, + stride_cache_chunk, # others pad_slot_id, # META @@ -608,10 +764,11 @@ def grid(META): HAS_INITIAL_STATES=has_initial_state is not None, HAS_CACHE=conv_states is not None, IS_CONTINUOUS_BATCHING=cache_indices is not None, + IS_CACHE_ENABLED=current_last_idx is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, #launch_cooperative_grid=True - BLOCK_M=8, + BLOCK_M=8, # TODO this should be metadata.nums_dict.keys()[0] BLOCK_N=256, num_stages=2, ) @@ -628,6 +785,8 @@ def _causal_conv1d_update_kernel( cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, + current_last_idx, + last_state_idx, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -655,6 +814,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_CACHE_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -668,13 +828,23 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + if IS_CACHE_ENABLED: + # Get the state from the last_state_idx + conv_state_init = tl.load(last_state_idx + idx_seq) + current_last_index = tl.load(current_last_idx + idx_seq) + else: + conv_state_init = 0 + current_last_index = 0 + if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch + # cache_idx conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) else: - conv_state_batch_coord = idx_seq + # cache_idx + conv_state_batch_coord = conv_state_init + if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence @@ -751,11 +921,20 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + # Get the state from the last_state_idx + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + current_last_index).to(tl.int64) + else: + # cache_idx + 
conv_states_offset = current_last_index + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok)[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -853,6 +1032,8 @@ def causal_conv1d_update( conv_state_indices: Optional[torch.Tensor] = None, num_accepted_tokens: Optional[torch.Tensor] = None, pad_slot_id: int = PAD_SLOT_ID, + current_last_idx: Optional[torch.Tensor] = None, + last_state_idx: Optional[torch.Tensor] = None, metadata=None, validate_data=False, ): @@ -872,6 +1053,10 @@ def causal_conv1d_update( If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + current_last_idx: (batch) int32 + The last cache block to be filled. This tensor indexes into cache_indices + last_state_idx: (batch) int32 + The cache block for the init values. This tensor indexes into cache_indices pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, @@ -947,6 +1132,8 @@ def grid(META): cache_seqlens, conv_state_indices, num_accepted_tokens, + current_last_idx, + last_state_idx, out, # Matrix dimensions batch, @@ -974,6 +1161,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_CACHE_ENABLED=current_last_idx is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, From 4bb28c07f2d8b4ab0e6de47429b63a5086a87001 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 23 Sep 2025 12:11:16 +0000 Subject: [PATCH 058/105] Precommit Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 256f184c0f39..9c27486489e8 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -684,7 +684,7 @@ def forward_cuda( if mamba2_metadata.cu_seqlen is None: mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) - + assert isinstance(mamba2_metadata.nums_dict, dict) BLOCK_M = list(mamba2_metadata.nums_dict.keys())[0] if cache_enabled: n_blocks_to_fill = current_last_idx_p - current_first_idx_p From 1c7e94717f373652af1447cfaa93f93260458a2a Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 23 Sep 2025 13:51:24 +0000 Subject: [PATCH 059/105] Support for disabling prefix caching Signed-off-by: Stanislaw Wozniak --- vllm/v1/attention/backends/mamba2_attn.py | 63 +++++++++++++---------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 12b7a8e819a4..ee5918b4e1e7 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -327,34 +327,41 @@ def build(self, state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - self.current_last_token_block_idx[:num_decodes].copy_( - current_last_token_block_idx, non_blocking=True) - current_last_token_block_idx = \ 
- self.current_last_token_block_idx[:num_input_tokens] - current_last_token_block_idx[num_decodes:] = 0 - - self.current_first_token_block_idx[:num_decodes].copy_( - current_first_token_block_idx, non_blocking=True) - current_first_token_block_idx = \ - self.current_first_token_block_idx[:num_input_tokens] - current_first_token_block_idx[num_decodes:] = 0 - - self.last_computed_token_block_idx[:num_decodes].copy_( - last_computed_token_block_idx, non_blocking=True) - last_computed_token_block_idx = \ - self.last_computed_token_block_idx[:num_input_tokens] - last_computed_token_block_idx[num_decodes:] = 0 - - self.seq_lens_completed[:num_decodes].copy_(seq_lens_completed, - non_blocking=True) - seq_lens_completed = self.seq_lens_completed[:num_input_tokens] - seq_lens_completed[num_decodes:] = 0 - - self.last_computed_token_block_offset[:num_decodes].copy_( - last_computed_token_block_offset, non_blocking=True) - last_computed_token_block_offset = \ - self.last_computed_token_block_offset[:num_input_tokens] - last_computed_token_block_offset[num_decodes:] = 0 + if self.kv_cache_spec.cache_strategy != 'disabled': + self.current_last_token_block_idx[:num_decodes].copy_( + current_last_token_block_idx, non_blocking=True) + current_last_token_block_idx = \ + self.current_last_token_block_idx[:num_input_tokens] + current_last_token_block_idx[num_decodes:] = 0 + + self.current_first_token_block_idx[:num_decodes].copy_( + current_first_token_block_idx, non_blocking=True) + current_first_token_block_idx = \ + self.current_first_token_block_idx[:num_input_tokens] + current_first_token_block_idx[num_decodes:] = 0 + + self.last_computed_token_block_idx[:num_decodes].copy_( + last_computed_token_block_idx, non_blocking=True) + last_computed_token_block_idx = \ + self.last_computed_token_block_idx[:num_input_tokens] + last_computed_token_block_idx[num_decodes:] = 0 + + self.seq_lens_completed[:num_decodes].copy_(seq_lens_completed, + non_blocking=True) + seq_lens_completed = self.seq_lens_completed[:num_input_tokens] + seq_lens_completed[num_decodes:] = 0 + + self.last_computed_token_block_offset[:num_decodes].copy_( + last_computed_token_block_offset, non_blocking=True) + last_computed_token_block_offset = \ + self.last_computed_token_block_offset[:num_input_tokens] + last_computed_token_block_offset[num_decodes:] = 0 + else: + current_last_token_block_idx = None + current_first_token_block_idx = None + last_computed_token_block_idx = None + last_computed_token_block_offset = None + seq_lens_completed = None attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, From 7cdae6025548381238cccfe0d0331c819f6ee635 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 25 Sep 2025 13:16:25 -0400 Subject: [PATCH 060/105] Metadata optimization for apc=off Signed-off-by: Stanislaw Wozniak --- vllm/v1/attention/backends/mamba2_attn.py | 62 ++++++++++++----------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ee5918b4e1e7..2211d7f02711 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -210,39 +210,47 @@ def build(self, # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # Additional cache-related varaiables: + current_last_token_block_idx = None + current_first_token_block_idx = None + last_computed_token_block_idx = None + last_computed_token_block_offset = None 
+ seq_lens_completed = None else: # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor + # Additional cache-related varaiables: + mamba_block_size = self.kv_cache_spec.block_size + seq_lens_pending = ( + torch.roll(common_attn_metadata.query_start_loc, -1, -1) - + common_attn_metadata.query_start_loc)[:-1] + seq_lens_completed = common_attn_metadata.seq_lens - \ + seq_lens_pending + last_computed_token_block_offset = \ + seq_lens_completed % mamba_block_size + # Indices: last_computed <= current_first <= current_last + # Cases: + # last_computed == current_first if last state was partially + # computed and needs to be updated + # current_first == current_last if no block crossing occurs, and + # only one state will be stored + # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: + current_last_token_block_idx = cdiv( + seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 + current_first_token_block_idx = cdiv(seq_lens_completed + 1, + mamba_block_size) - 1 + last_computed_token_block_idx = cdiv(seq_lens_completed, + mamba_block_size) - 1 + # -1 in case it's non-computed and causes later issues with indexing + last_computed_token_block_idx = \ + last_computed_token_block_idx.clamp(min=0) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) - mamba_block_size = self.kv_cache_spec.block_size - seq_lens_pending = ( - torch.roll(common_attn_metadata.query_start_loc, -1, -1) - - common_attn_metadata.query_start_loc)[:-1] - seq_lens_completed = (common_attn_metadata.seq_lens - seq_lens_pending) - last_computed_token_block_offset = \ - seq_lens_completed % mamba_block_size - # Indices: last_computed <= current_first <= current_last - # Cases: - # last_computed == current_first if last state was partially - # computed and needs to be updated - # current_first == current_last if no block crossing occurs, and - # only one state will be stored - # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: - current_last_token_block_idx = cdiv( - seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 - current_first_token_block_idx = cdiv(seq_lens_completed + 1, - mamba_block_size) - 1 - last_computed_token_block_idx = cdiv(seq_lens_completed, - mamba_block_size) - 1 - # -1 in case it's non-computed and causes later issues with indexing - last_computed_token_block_idx = \ - last_computed_token_block_idx.clamp(min=0) - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -356,12 +364,6 @@ def build(self, last_computed_token_block_offset = \ self.last_computed_token_block_offset[:num_input_tokens] last_computed_token_block_offset[num_decodes:] = 0 - else: - current_last_token_block_idx = None - current_first_token_block_idx = None - last_computed_token_block_idx = None - last_computed_token_block_offset = None - seq_lens_completed = None attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, From f2beb4d492d5a1cf162f82c900e220e4964d1c74 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 29 Sep 2025 19:31:11 -0400 Subject: [PATCH 061/105] Evaluating other models. Lightweight model for testing. 
Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_APC_hybrid_models.py | 3 ++- vllm/model_executor/models/bamba.py | 4 ---- vllm/model_executor/models/config.py | 21 +++++++++++++++++++++ vllm/model_executor/models/falcon_h1.py | 4 ---- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index 4ca76d877536..0c694a2e1af2 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -20,7 +20,8 @@ from ...conftest import HfRunner, VllmRunner MODELS = [ - "ibm-granite/granite-4.0-tiny-preview", + #"ibm-granite/granite-4.0-tiny-preview", + "hmellor/tiny-random-BambaForCausalLM", ] diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 584981ef3ebf..57960a79d5a4 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -489,12 +489,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Bamba currently does not support prefix caching" - self.quant_config = vllm_config.quant_config super().__init__() diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 05e2f582b7ed..bc6d6546b78e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -301,9 +301,30 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if not envs.VLLM_USE_V1: return + cache_config = vllm_config.cache_config model_config = vllm_config.model_config compilation_config = vllm_config.compilation_config + # TODO: find a way to keep this list updated, or redundant + MAMBA2_MODELS = [ + "BambaForCausalLM", + "FalconH1ForCausalLM", + "GraniteMoeHybridForCausalLM", + "Mamba2ForCausalLM", + "NemotronHForCausalLM", + #"Plamo2ForCausalLM", # currently fails + "Zamba2ForCausalLM", + ] + if cache_config.enable_prefix_caching: + if model_config.architecture in MAMBA2_MODELS: + logger.info("Warning: Prefix caching is currently enabled. " + "Its support for Mamba2 layers is experimental. 
" + "Please report any issues you may observe.") + else: + logger.info("Hybrid or mamba-based model detected without " + "support for prefix caching: disabling.") + cache_config.enable_prefix_caching = False + # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ "Lfm2ForCausalLM", diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 83efdd2e433f..fd9d95f7f26f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -577,12 +577,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "FalconH1 currently does not support prefix caching" - self.quant_config = vllm_config.quant_config super().__init__() From 1f4079427e08d61692856c7f69aeefc47de3575c Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Tue, 30 Sep 2025 14:35:04 +0000 Subject: [PATCH 062/105] Cleanup conv1D kernel and stripped APC testcases Signed-off-by: Thomas Ortner --- tests/v1/core/test_APC_hybrid_models.py | 51 ++-- .../layers/mamba/mamba_mixer2.py | 23 +- .../layers/mamba/ops/causal_conv1d.py | 263 +++++++----------- 3 files changed, 126 insertions(+), 211 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index 0c694a2e1af2..a44e59134b3b 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -164,33 +164,30 @@ def test_single_prompt_mamba_size_alignment( mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ mamba_block_size - for multiple in [1, 3, 7]: - for offsets in [ - 3, mamba_block_size // 4 - 3, mamba_block_size // 4, - mamba_block_size // 4 + 3, mamba_block_size // 2 - 3, - mamba_block_size // 2, mamba_block_size // 2 + 3, - mamba_block_size - 3 - ]: - - vllm_runner_kwargs[ - 'max_num_batched_tokens'] = multiple * mamba_block_size - \ - offsets - vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) - - # Check alignment of the output logits when using APC - for r_idx, vllm_outputs_cache_itn in enumerate( - vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - check_logprobs_close( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) + multiple = 2 + for offsets in [ + 3, mamba_block_size // 2 + 3, mamba_block_size - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = multiple * mamba_block_size - \ + offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate( + vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + check_logprobs_close( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) @pytest.mark.parametrize("model", MODELS) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9c27486489e8..f29e38faa094 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -650,6 +650,12 @@ def forward_cuda( attn_metadata.current_last_token_block_idx, [num_decodes, num_prefills], dim=0) + else: + current_first_idx_d, current_first_idx_p = None, None + current_last_idx_d, current_last_idx_p = None, None + last_state_idx_d, last_state_idx_p = None, None + seq_lens_completed_d, seq_lens_completed_p = None, None + # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -685,17 +691,6 @@ def forward_cuda( mamba2_metadata = update_metadata(x, query_start_loc_p, mamba2_metadata) assert isinstance(mamba2_metadata.nums_dict, dict) - BLOCK_M = list(mamba2_metadata.nums_dict.keys())[0] - if cache_enabled: - n_blocks_to_fill = current_last_idx_p - current_first_idx_p - stride_state_indices = state_indices_tensor_p.shape[-1] - else: - current_first_idx_p = None - current_last_idx_p = None - seq_lens_completed_p = None - last_state_idx_p = None - n_blocks_to_fill = None - stride_state_indices = 1 hidden_states_B_C_p = causal_conv1d_fn( x, @@ -705,13 +700,11 @@ def forward_cuda( conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - n_blocks_to_fill=n_blocks_to_fill, current_first_idx=current_first_idx_p, current_last_idx=current_last_idx_p, last_state_idx=last_state_idx_p, seq_lens_completed=seq_lens_completed_p, - stride_cache_chunk=mamba_block_size // BLOCK_M, - stride_state_indices=stride_state_indices, + block_size_to_align=mamba_block_size, metadata=mamba2_metadata, query_start_loc=query_start_loc_p).transpose( 0, 
1)[:num_prefill_tokens] @@ -826,8 +819,6 @@ def forward_cuda( # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - current_last_idx_d = None - last_state_idx_d = None # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 88df9c436827..3b21400094c2 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -25,41 +25,34 @@ def _causal_conv1d_fwd_kernel( # continuous batching query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, - n_blocks_to_fill_ptr, # (dim,) current_first_idx, # (dim,) current_last_idx, # (dim,) last_state_idx, # (dim,) seq_lens_completed, # (dim,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions - batch: tl.int32, # actually padded_batch dim: tl.constexpr, seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl. - constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, - stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, stride_state_indices: tl.constexpr, - stride_cache_chunk: tl.constexpr, + stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, IS_CACHE_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, @@ -92,36 +85,36 @@ def _causal_conv1d_fwd_kernel( # continuous batching seqlen = sequence_end_index - sequence_start_index if IS_CACHE_ENABLED: - # Get the completed sequence length for far and compute the offset. + # Handle the case if prefix caching is enabled + + # Get the length of the completed sequence so far and compute the offset. 
current_first_index = tl.load(current_first_idx + idx_seq) current_last_index = tl.load(current_last_idx + idx_seq) sequence_completed_index = tl.load(seq_lens_completed + idx_seq) - # Compute the offset where the first stride_cache_chunk-aligned first full block is - sequence_offset_index = stride_cache_chunk * BLOCK_M * current_first_index - sequence_completed_index + # Compute the offset where the first stride_block_m-aligned first full block is + # Value in "token-space" + sequence_offset_token_index = sequence_completed_index % (stride_block_m * BLOCK_M) # Compute the last full cache block for the sequence - last_full_cache_index = sequence_end_index - ( - (seqlen - sequence_offset_index) % (stride_cache_chunk * BLOCK_M)) + # Value in "token-space" + last_full_block_token_index = sequence_end_index - sequence_offset_token_index # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one - if ((seqlen - sequence_offset_index) % - (stride_cache_chunk * BLOCK_M)) == 0: - last_full_cache_index = last_full_cache_index - stride_cache_chunk * BLOCK_M + if sequence_offset_token_index == 0: + last_full_block_token_index = last_full_block_token_index - stride_block_m * BLOCK_M - # Get the number of blocks for the current sequence - n_block_to_fill = tl.load(n_blocks_to_fill_ptr + idx_seq) + # Get the number of blocks to be filled for the current sequence + # If n_block_to_fill = 0, then only the state at the sequence end is stored + n_block_to_fill = current_last_index - current_first_index - if HAS_INITIAL_STATES: - # Get the state from the last_state_idx - conv_state_init = tl.load(last_state_idx + idx_seq) - else: - conv_state_init = 0 + # Get the index of the init block + conv_state_init_index = tl.load(last_state_idx + idx_seq) else: n_block_to_fill = 0 current_last_index = 0 - conv_state_init = 0 - sequence_offset_index = 0 - last_full_cache_index = 0 + conv_state_init_index = 0 + sequence_offset_token_index = 0 + last_full_block_token_index = 0 token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) @@ -129,30 +122,18 @@ def _causal_conv1d_fwd_kernel( # continuous batching # base of the sequence x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] - if HAS_INITIAL_STATES: - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) - else: - # cache_idx - conv_state_batch_coord = conv_state_init - else: - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq).to(tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init_index).to(tl.int64) + if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence return conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -161,10 +142,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching # 2. 
update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( + tl.int1) if load_init_state: # load from conv_states prior_tokens = conv_states_base + (state_len - @@ -226,28 +205,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - if HAS_INITIAL_STATES: - # Get the state from the last_state_idx - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * - stride_state_indices + - current_last_index).to( - tl.int64) - else: - # cache_idx - conv_states_offset = current_last_index - else: - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq + - current_last_index).to( - tl.int64) - else: - # cache_idx - conv_states_offset = idx_seq + current_last_index + # Compute the offset where the last block should be written in the conv_states + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * + stride_state_indices + + current_last_index).to( + tl.int64) + conv_states_ptrs_target = ( conv_states_ptr + (conv_states_offset * stride_conv_state_seq) + # Offset from seq @@ -354,60 +318,44 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') - # Store intermediate states aligned with stride_cache_chunk - if chunk_offset > 0 and (chunk_offset - 1) < n_block_to_fill: - # Old approach. 
Store the states at chunk boundaries from the start of the sequence - # idx_tokens_last = (chunk_offset * stride_cache_chunk * BLOCK_M + sequence_offset_index - state_len) + tl.arange( - # 0, NP2_STATELEN) # [BLOCK_M] - # x_ptrs = x_ptr + ( - # (sequence_start_index + idx_tokens_last) * - # stride_x_token)[:, None] + ( - # idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - - # New approach, store the states at the chunk boundaries from the start of the sequence - idx_tokens_last = ( - last_full_cache_index - - (n_block_to_fill - chunk_offset) * stride_cache_chunk * BLOCK_M - - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ((idx_tokens_last) * stride_x_token)[:, None] + ( - idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - - mask_x = ( - (idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - if HAS_INITIAL_STATES: - # Get the state from the last_state_idx - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init + - (chunk_offset - 1)).to(tl.int64) - else: - # cache_idx - conv_states_offset = conv_state_init + (chunk_offset - 1) - else: - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq + - (chunk_offset - 1)).to(tl.int64) - else: - # cache_idx - conv_states_offset = idx_seq + (chunk_offset - 1) - conv_states_ptrs_target = ( - conv_states_ptr + - (conv_states_offset * stride_conv_state_seq) + # Offset from seq - (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] - idx_tokens_conv * stride_conv_state_tok)[:, None] - - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] - tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) + # Store intermediate states aligned with stride_cache_chunk + # The states are cached starting from the last stride_block_m. + # For example: + # If n_block_to_fill = 0, then the state at the sequence is cached. + # If n_block_to_fill > 0, then the states at the sequence and at the n_block_to_fill-last stride_block_m are cached. 
+ if (chunk_offset - 1) < n_block_to_fill: + # Store the states at the chunk boundaries from the start of the sequence + idx_tokens_last = ( + last_full_block_token_index - + (n_block_to_fill - chunk_offset) * stride_block_m * BLOCK_M - + state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = x_ptr + ((idx_tokens_last) * stride_x_token)[:, None] + ( + idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + + mask_x = ( + (idx_tokens_last >= 0)[:, None] & + (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init_index + + (chunk_offset - 1)).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_offset * stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats + < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) if HAS_BIAS: bias = bias_ptr + idx_feats @@ -497,13 +445,11 @@ def causal_conv1d_fn( has_initial_state: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", pad_slot_id: int = PAD_SLOT_ID, - n_blocks_to_fill: Optional[torch.Tensor] = None, current_first_idx: Optional[torch.Tensor] = None, current_last_idx: Optional[torch.Tensor] = None, last_state_idx: Optional[torch.Tensor] = None, seq_lens_completed: Optional[torch.Tensor] = None, - stride_state_indices: Optional[int] = 0, - stride_cache_chunk: Optional[int] = 0, + block_size_to_align: Optional[int] = 0, metadata=None, validate_data=False, ): @@ -514,7 +460,7 @@ def causal_conv1d_fn( sequences are concatenated from left to right for varlen weight: (dim, width) conv_states: (...,dim,width - 1) itype - updated inplace if provided + updated inplace if cache_indices are not provided [it use `cache_indices` to get the index to the cache of conv_state for that sequence conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True @@ -546,8 +492,6 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - n_blocks_to_fill: (batch) int32 - The number of full cache blocks to be filled current_first_idx: (batch) int32 The first cache block to be filled. This tensor indexes into cache_indices current_last_idx: (batch) int32 @@ -556,10 +500,8 @@ def causal_conv1d_fn( The cache block for the init values. This tensor indexes into cache_indices seq_lens_completed: (batch) int32 The number of tokens already completed for each sequence - stride_state_indices: int - The stride of conv_states. I.e., dim=1 of conv_states. - stride_cache_chunk: int - The token offsets ratio. I.e., BLOCK_M + block_size_to_align: int + The block size to align the cached states to out: same shape as `x` """ if isinstance(activation, bool) and activation: @@ -611,6 +553,7 @@ def causal_conv1d_fn( stride_istate_dim = 0 stride_istate_token = 0 num_cache_lines = 0 + BLOCK_M = 8 if conv_states is not None: # extensions to support vLLM: # 1. 
conv_states is used to replaced initial_states @@ -626,11 +569,9 @@ def causal_conv1d_fn( stride_istate_token = conv_states.stride(2) assert stride_istate_dim == 1 if out.dim() == 2: - stride_o_seq = 0 stride_o_dim = out.stride(0) stride_o_token = out.stride(1) else: - stride_o_seq = out.stride(0) stride_o_dim = out.stride(1) stride_o_token = out.stride(2) @@ -651,6 +592,10 @@ def causal_conv1d_fn( assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" + if block_size_to_align > 0: + assert block_size_to_align % BLOCK_M, "The mamba block size needs to be divisible by the BLOCK_M" + else: + block_size_to_align = BLOCK_M if metadata is None: @@ -733,14 +678,12 @@ def grid(META): query_start_loc, batch_ptr, token_chunk_offset_ptr, - n_blocks_to_fill, current_first_idx, current_last_idx, last_state_idx, seq_lens_completed, out, # Matrix dimensions - padded_batch, dim, cu_seqlen, num_cache_lines, @@ -753,25 +696,21 @@ def grid(META): stride_istate_seq, stride_istate_dim, stride_istate_token, - stride_o_seq, stride_o_dim, stride_o_token, - stride_state_indices, - stride_cache_chunk, + cache_indices.stride(0), + block_size_to_align // BLOCK_M, # others pad_slot_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, IS_CACHE_ENABLED=current_last_idx is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, #launch_cooperative_grid=True - BLOCK_M=8, # TODO this should be metadata.nums_dict.keys()[0] + BLOCK_M=BLOCK_M, BLOCK_N=256, num_stages=2, ) @@ -785,7 +724,6 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, - cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) @@ -818,7 +756,6 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, IS_CACHE_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, @@ -841,14 +778,10 @@ def _causal_conv1d_update_kernel( conv_state_init = 0 current_last_index = 0 - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) - else: - # cache_idx - conv_state_batch_coord = conv_state_init + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: @@ -861,7 +794,7 @@ def _causal_conv1d_update_kernel( tl.int64) # revise state_len and seqlen state_len = state_len - (seqlen - - (query_end_index - query_start_index)) + (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * stride_x_token o_offset = query_start_index * stride_o_token @@ -948,14 +881,10 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) # Get the state from the last_state_idx - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - current_last_index).to(tl.int64) - else: - # cache_idx - conv_states_offset = current_last_index + # cache_idx + conv_states_offset = 
tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + current_last_index).to(tl.int64) conv_state_ptrs_target = ( conv_state_ptr + (conv_states_offset * stride_conv_state_seq) + # Offset from seq @@ -1235,7 +1164,6 @@ def grid(META): weight, bias, conv_state, - cache_seqlens, conv_state_indices, num_accepted_tokens, query_start_loc, @@ -1268,7 +1196,6 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, IS_CACHE_ENABLED=current_last_idx is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, From e5d5519b86e1052f735e9b9385434fae6ab65a5e Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 30 Sep 2025 18:13:30 -0400 Subject: [PATCH 063/105] Pre-commit fixes. Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 797f08a5e3d6..692cee0acec8 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -578,31 +578,31 @@ def forward_cuda( if cache_enabled: # Split decodes and prefills: - seq_lens_completed_d, seq_lens_completed_p = torch.split( - attn_metadata.seq_lens_completed, [num_decodes, num_prefills], - dim=0) last_state_idx_d, last_state_idx_p = torch.split( attn_metadata.last_computed_token_block_idx, [num_decodes, num_prefills], dim=0) - last_computed_offset_d, last_computed_offset_p = torch.split( - attn_metadata.last_computed_token_block_offset, + current_last_idx_d, current_last_idx_p = torch.split( + attn_metadata.current_last_token_block_idx, [num_decodes, num_prefills], dim=0) - current_first_idx_d, current_first_idx_p = torch.split( + # Prefill-only variables: + _, current_first_idx_p = torch.split( attn_metadata.current_first_token_block_idx, [num_decodes, num_prefills], + dim=0) + _, seq_lens_completed_p = torch.split( + attn_metadata.seq_lens_completed, [num_decodes, num_prefills], dim=0) - current_last_idx_d, current_last_idx_p = torch.split( - attn_metadata.current_last_token_block_idx, + _, last_computed_offset_p = torch.split( + attn_metadata.last_computed_token_block_offset, [num_decodes, num_prefills], dim=0) else: - current_first_idx_d, current_first_idx_p = None, None - current_last_idx_d, current_last_idx_p = None, None last_state_idx_d, last_state_idx_p = None, None - seq_lens_completed_d, seq_lens_completed_p = None, None - + current_last_idx_d, current_last_idx_p = None, None + _, current_first_idx_p = None, None + _, seq_lens_completed_p = None, None # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -722,7 +722,8 @@ def forward_cuda( varlen_states[last_chunk_indices_p] else: # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + # - varlen state is a (num_prefills, nheads, headdim, dstate) + # tensor ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests From 148ea61c3a980315e8c5532aa043f3d1bcc18e0b Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 1 Oct 2025 15:12:24 -0400 Subject: [PATCH 064/105] Addressing test failures. 
Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 3 +++ vllm/model_executor/models/config.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 367c0173e374..21bc32ddecd4 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -278,6 +278,9 @@ def selective_state_update(state, assert state_batch_indices.shape == (batch, ) if dst_state_batch_indices is not None: assert dst_state_batch_indices.shape == (batch, ) + else: + # revert to the default behavior of in-place state updates + dst_state_batch_indices = state_batch_indices assert out.shape == x.shape grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 2efb7f0cacd9..e13387c92265 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -295,12 +295,12 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO: find a way to keep this list updated, or redundant MAMBA2_MODELS = [ "BambaForCausalLM", - "FalconH1ForCausalLM", + #"FalconH1ForCausalLM", "GraniteMoeHybridForCausalLM", "Mamba2ForCausalLM", - "NemotronHForCausalLM", - #"Plamo2ForCausalLM", # currently fails - "Zamba2ForCausalLM", + #"NemotronHForCausalLM", + #"Plamo2ForCausalLM", + #"Zamba2ForCausalLM", ] if cache_config.enable_prefix_caching: if model_config.architecture in MAMBA2_MODELS: From ce5144b2ce1bc9dd550e6a699bc2df79298342c5 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 1 Oct 2025 20:29:20 +0000 Subject: [PATCH 065/105] Fixed issue with conv1D Signed-off-by: Thomas Ortner --- tests/v1/core/test_APC_hybrid_models.py | 189 +++++++++++++----- .../layers/mamba/mamba_mixer2.py | 13 +- .../layers/mamba/ops/causal_conv1d.py | 97 ++++----- 3 files changed, 187 insertions(+), 112 deletions(-) diff --git a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py index a44e59134b3b..7f25ca87da23 100644 --- a/tests/v1/core/test_APC_hybrid_models.py +++ b/tests/v1/core/test_APC_hybrid_models.py @@ -14,14 +14,14 @@ import pytest -from tests.models.utils import check_logprobs_close +from tests.models.utils import check_logprobs_close, check_outputs_equal if TYPE_CHECKING: from ...conftest import HfRunner, VllmRunner MODELS = [ - #"ibm-granite/granite-4.0-tiny-preview", - "hmellor/tiny-random-BambaForCausalLM", + "ibm-granite/granite-4.0-tiny-preview", + # "hmellor/tiny-random-BambaForCausalLM", ] @@ -41,7 +41,22 @@ def _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, } -def _get_vLLM_output_logprobs(vllm_runner, +def _get_vLLM_outputs(vllm_runner, + kwargs, + prompts, + max_tokens, + num_repetitions=1, + vllm_model=None): + outs = [] + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + outs.append( + vllm_model.generate_greedy(prompts, max_tokens)) + + return outs, vllm_model + +def _get_vLLM_logprobs(vllm_runner, kwargs, prompts, max_tokens, @@ -86,7 +101,7 @@ def test_single_prompt( """ Checks exact match decode vllm runner with and without prefix caching """ - MULTIPLE = 120 + MULTIPLE = 300 # Sample prompts. 
generated_prompts = [MULTIPLE * example_prompts[0]] @@ -96,20 +111,20 @@ def test_single_prompt( vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs) + vllm_outputs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_outputs_cache_rep, _ = _get_vLLM_outputs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) + n_repetitions) for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", @@ -144,19 +159,20 @@ def test_single_prompt_mamba_size_alignment( """ Checks exact match decode vllm runner with and without prefix caching """ - MULTIPLE = 120 + MULTIPLE = 300 # Sample prompts. - generated_prompts = [MULTIPLE * example_prompts[0]] + # generated_prompts = [MULTIPLE * example_prompts[0]] + + generated_prompts = ["The president of the United States is " * MULTIPLE] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs) + vllm_outputs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) vllm_runner_kwargs['enable_prefix_caching'] = True with vllm_runner(**vllm_runner_kwargs) as vllm_model: @@ -164,17 +180,16 @@ def test_single_prompt_mamba_size_alignment( mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ mamba_block_size - multiple = 2 + mamba_block_size_multiplier = 10 for offsets in [ - 3, mamba_block_size // 2 + 3, mamba_block_size - 3 + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 ]: vllm_runner_kwargs[ - 'max_num_batched_tokens'] = multiple * mamba_block_size - \ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ offsets - vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) + vllm_outputs_cache_rep, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) # Check alignment of the output logits when using APC for r_idx, vllm_outputs_cache_itn in enumerate( @@ -182,7 +197,8 @@ def test_single_prompt_mamba_size_alignment( # In the first repetition, the caches are filled # In the second repetition, these caches are reused - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", @@ -200,7 +216,7 @@ def test_single_prompt_mamba_size_alignment( # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_multiple_prompts_all_cached_output_logprobs( +def test_multiple_prompts_all_cached_outputs( hf_runner: HfRunner, vllm_runner: VllmRunner, example_prompts, @@ -217,7 +233,7 @@ def test_multiple_prompts_all_cached_output_logprobs( """ Checks exact match decode vllm runner with and without prefix caching """ - MULTIPLE = 120 + MULTIPLE = 300 # Sample prompts. generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] @@ -227,25 +243,100 @@ def test_multiple_prompts_all_cached_output_logprobs( vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs) + vllm_outputs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_outputs_cache_rep, _ = _get_vLLM_outputs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) + n_repetitions) for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multiple_prompts_mamba_size_alignment( + hf_runner: HfRunner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + mamba_ssm_cache_dtype: str, + tensor_parallel_size: int, + num_logprobs: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Checks exact match decode vllm runner with and without prefix caching + """ + MULTIPLE = 300 + + # Sample prompts. + # generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + prompt_text = "The president of the United States is " + prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] + generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, + enforce_eager, max_model_len, + dtype, tensor_parallel_size) + vllm_outputs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ + mamba_block_size + + mamba_block_size_multiplier = 10 + for offsets in [ + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ + offsets + vllm_outputs_cache_rep, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate( + vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + # check_logprobs_close( + check_outputs_equal( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) @pytest.mark.parametrize("model", MODELS) @@ -258,7 +349,7 @@ def test_multiple_prompts_all_cached_output_logprobs( # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_multiple_prompts_partial_cached_output_logprobs( +def test_multiple_prompts_partial_cached_outputs( hf_runner: HfRunner, vllm_runner: VllmRunner, example_prompts, @@ -275,7 +366,7 @@ def test_multiple_prompts_partial_cached_output_logprobs( """ Checks exact match decode vllm runner with and without prefix caching """ - MULTIPLE = 120 + MULTIPLE = 300 # Sample prompts. generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] @@ -285,24 +376,23 @@ def test_multiple_prompts_partial_cached_output_logprobs( vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs) + vllm_outputs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_partial_cache, vllm_model = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, - num_logprobs) + vllm_outputs_partial_cache, vllm_model = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens) - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_no_cache[0][:3], outputs_1_lst=vllm_outputs_partial_cache[0], name_0="vllm_no_cache", name_1="vllm_partial_cache", ) - vllm_outputs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_outputs_cache_rep, _ = _get_vLLM_outputs( vllm_runner, vllm_runner_kwargs, generated_prompts, @@ -315,7 +405,8 @@ def test_multiple_prompts_partial_cached_output_logprobs( # In the first repetition, the caches are filled # In the second repetition, these caches are reused - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_no_cache[0], outputs_1_lst=vllm_outputs_cache_itn, name_0="vllm_no_cache", @@ -333,7 +424,7 @@ def test_multiple_prompts_partial_cached_output_logprobs( # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_specific_prompts_output_logprobs( +def test_specific_prompts_outputs( hf_runner: HfRunner, vllm_runner: VllmRunner, example_prompts, @@ -363,22 +454,22 @@ def test_specific_prompts_output_logprobs( vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, max_model_len, dtype, tensor_parallel_size) - vllm_outputs_logprobs_no_cache, _ = _get_vLLM_output_logprobs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs) + vllm_outputs_logprobs_no_cache, _ = _get_vLLM_outputs( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) # Cache only part of all the prompts vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_logprobs_cache_rep, _ = _get_vLLM_output_logprobs( + vllm_outputs_logprobs_cache_rep, _ = _get_vLLM_outputs( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - num_logprobs, n_repetitions) + n_repetitions) for r_idx, vllm_outputs_logprobs_cache_itn in enumerate( vllm_outputs_logprobs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused - check_logprobs_close( + # check_logprobs_close( + check_outputs_equal( outputs_0_lst=vllm_outputs_logprobs_no_cache[0], outputs_1_lst=vllm_outputs_logprobs_cache_itn, name_0="vllm_no_cache", diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 692cee0acec8..e125e1ac0134 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -738,6 +738,13 @@ def forward_cuda( #Note: # for decode always: current_first_idx_d == current_last_idx_d # at block boundaries: current_first_idx_d > last_state_idx_d + + # copy initial state to new location, + # as update kernel works in place + #if (current_last_idx_d > last_state_idx_d).any(): + # (skip IF as it breaks CUDA graphs) + conv_state[state_indices_tensor_d_output] = conv_state[ + state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d @@ -750,14 +757,12 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d, - current_last_idx=current_last_idx_d, - last_state_idx=last_state_idx_d, - ) + conv_state_indices=state_indices_tensor_d_output) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( hidden_states_B_C_d) + # 3. 
State Space Model sequence transformation n_groups = self.n_groups // self.tp_size A_d = self.A[:, None, ...][:, :, None].expand( diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index d55083652ea1..4857f5da6808 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -43,9 +43,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_istate_seq: tl.constexpr, stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, + stride_cache_indices: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, - stride_state_indices: tl.constexpr, stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, @@ -70,7 +70,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # rather than mixing sequences - to make updating initial_states across sequences efficiently # single-sequence id - idx_seq = tl.load(batch_ptr + tl.program_id(0)) + idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64) chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) # BLOCK_N elements along the feature-dimension (channel) @@ -94,14 +94,14 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Compute the offset where the first stride_block_m-aligned first full block is # Value in "token-space" - sequence_offset_token_index = sequence_completed_index % (stride_block_m * BLOCK_M) - - # Compute the last full cache block for the sequence - # Value in "token-space" - last_full_block_token_index = sequence_end_index - sequence_offset_token_index + B_size = (stride_block_m * BLOCK_M) + sequence_completed_offset_token = sequence_completed_index % B_size + seq_completed_offset = B_size - sequence_completed_offset_token + seq_end_offset = (seqlen - seq_completed_offset) % B_size + last_full_block_token_index = sequence_end_index - seq_end_offset # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one - if sequence_offset_token_index == 0: - last_full_block_token_index = last_full_block_token_index - stride_block_m * BLOCK_M + if seq_end_offset == 0: + last_full_block_token_index = last_full_block_token_index - B_size # Get the number of blocks to be filled for the current sequence # If n_block_to_fill = 0, then only the state at the sequence end is stored @@ -113,6 +113,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching n_block_to_fill = 0 current_last_index = 0 conv_state_init_index = 0 + current_first_index = 0 sequence_offset_token_index = 0 last_full_block_token_index = 0 @@ -124,7 +125,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # cache_idx conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + + idx_seq * stride_cache_indices + conv_state_init_index).to(tl.int64) if USE_PAD_SLOT: # noqa @@ -208,7 +209,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Compute the offset where the last block should be written in the conv_states conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * - stride_state_indices + + stride_cache_indices + current_last_index).to( tl.int64) @@ -329,22 +330,21 @@ def _causal_conv1d_fwd_kernel( # continuous batching last_full_block_token_index - (n_block_to_fill - chunk_offset) * stride_block_m * BLOCK_M - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ((idx_tokens_last) * 
stride_x_token)[:, None] + ( + x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + ( idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] mask_x = ( (idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 1.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # cache_idx conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init_index + - (chunk_offset - 1)).to(tl.int64) + idx_seq * stride_cache_indices + + current_first_index + (chunk_offset - 1)).to(tl.int64) conv_states_ptrs_target = ( conv_states_ptr + @@ -513,14 +513,12 @@ def causal_conv1d_fn( x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: - #cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict - #x = metadata.x args = nums_dict batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: - seqlens = np.diff(query_start_loc.to('cpu')) + seqlens = query_start_loc.diff().to('cpu') args = seqlens MAX_NUM_PROGRAMS = 1024 @@ -574,6 +572,7 @@ def causal_conv1d_fn( else: stride_o_dim = out.stride(1) stride_o_token = out.stride(2) + stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 @@ -592,7 +591,7 @@ def causal_conv1d_fn( assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" - if block_size_to_align > 0: + if block_size_to_align is not None and block_size_to_align > 0: assert block_size_to_align % BLOCK_M, "The mamba block size needs to be divisible by the BLOCK_M" else: block_size_to_align = BLOCK_M @@ -696,9 +695,9 @@ def grid(META): stride_istate_seq, stride_istate_dim, stride_istate_token, + stride_cache_indices, stride_o_dim, stride_o_token, - cache_indices.stride(0), block_size_to_align // BLOCK_M, # others pad_slot_id, @@ -724,11 +723,10 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, + cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) - current_last_idx, - last_state_idx, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -756,7 +754,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CACHE_ENABLED: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -770,19 +768,13 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CACHE_ENABLED: - # Get the state from the last_state_idx - conv_state_init = tl.load(last_state_idx + idx_seq) - current_last_index = tl.load(current_last_idx + idx_seq) + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) else: - conv_state_init = 0 - current_last_index = 0 - - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) - + conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa if 
conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence @@ -794,7 +786,7 @@ def _causal_conv1d_update_kernel( tl.int64) # revise state_len and seqlen state_len = state_len - (seqlen - - (query_end_index - query_start_index)) + (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * stride_x_token o_offset = query_start_index * stride_o_token @@ -880,16 +872,11 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) - # Get the state from the last_state_idx - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - current_last_index).to(tl.int64) - conv_state_ptrs_target = ( - conv_state_ptr + - (conv_states_offset * stride_conv_state_seq) + # Offset from seq - (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] - idx_tokens * stride_conv_state_tok)[:, None] + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = conv_state_base + ( + idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -1036,9 +1023,6 @@ def causal_conv1d_update( query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, - current_last_idx: Optional[torch.Tensor] = None, - last_state_idx: Optional[torch.Tensor] = None, - metadata=None, validate_data=False, ): """ @@ -1061,10 +1045,6 @@ def causal_conv1d_update( If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - current_last_idx: (batch) int32 - The last cache block to be filled. This tensor indexes into cache_indices - last_state_idx: (batch) int32 - The cache block for the init values. This tensor indexes into cache_indices num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. 
@@ -1164,11 +1144,10 @@ def grid(META): weight, bias, conv_state, + cache_seqlens, conv_state_indices, num_accepted_tokens, query_start_loc, - current_last_idx, - last_state_idx, out, # Matrix dimensions batch, @@ -1196,7 +1175,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CACHE_ENABLED=current_last_idx is not None, + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, @@ -1204,4 +1183,4 @@ def grid(META): ) if unsqueeze: out = out.squeeze(-1) - return out.to(original_x_dtype) + return out.to(original_x_dtype) \ No newline at end of file From e3b8cfb8ec4c1a7f83248974647ee0264a2e55a0 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 1 Oct 2025 20:57:13 +0000 Subject: [PATCH 066/105] Reintegrated conv1D update changes Signed-off-by: Thomas Ortner --- .../layers/mamba/mamba_mixer2.py | 13 ++--- .../layers/mamba/ops/causal_conv1d.py | 53 ++++++++++++------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index e125e1ac0134..692cee0acec8 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -738,13 +738,6 @@ def forward_cuda( #Note: # for decode always: current_first_idx_d == current_last_idx_d # at block boundaries: current_first_idx_d > last_state_idx_d - - # copy initial state to new location, - # as update kernel works in place - #if (current_last_idx_d > last_state_idx_d).any(): - # (skip IF as it breaks CUDA graphs) - conv_state[state_indices_tensor_d_output] = conv_state[ - state_indices_tensor_d_input] else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d @@ -757,12 +750,14 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d_output) + conv_state_indices=state_indices_tensor_d, + current_last_idx=current_last_idx_d, + last_state_idx=last_state_idx_d, + ) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( hidden_states_B_C_d) - # 3. 
State Space Model sequence transformation n_groups = self.n_groups // self.tp_size A_d = self.A[:, None, ...][:, :, None].expand( diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 4857f5da6808..df92cffbc666 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -35,7 +35,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value @@ -687,7 +686,6 @@ def grid(META): cu_seqlen, num_cache_lines, # stride - stride_x_seq, stride_x_dim, stride_x_token, stride_w_dim, @@ -723,10 +721,11 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, - cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) + current_last_idx, + last_state_idx, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -754,7 +753,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_CACHE_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -768,13 +767,19 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + if IS_CACHE_ENABLED: + # Get the state from the last_state_idx + conv_state_init = tl.load(last_state_idx + idx_seq) + current_last_index = tl.load(current_last_idx + idx_seq) else: - conv_state_batch_coord = idx_seq + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) + if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence @@ -786,7 +791,7 @@ def _causal_conv1d_update_kernel( tl.int64) # revise state_len and seqlen state_len = state_len - (seqlen - - (query_end_index - query_start_index)) + (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * stride_x_token o_offset = query_start_index * stride_o_token @@ -872,11 +877,16 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + # Get the state from the last_state_idx + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + current_last_index).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) + # Offset from seq + 
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok)[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -1023,6 +1033,8 @@ def causal_conv1d_update( query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, + current_last_idx: Optional[torch.Tensor] = None, + last_state_idx: Optional[torch.Tensor] = None, validate_data=False, ): """ @@ -1045,6 +1057,10 @@ def causal_conv1d_update( If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + current_last_idx: (batch) int32 + The last cache block to be filled. This tensor indexes into cache_indices + last_state_idx: (batch) int32 + The cache block for the init values. This tensor indexes into cache_indices num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. @@ -1144,10 +1160,11 @@ def grid(META): weight, bias, conv_state, - cache_seqlens, conv_state_indices, num_accepted_tokens, query_start_loc, + current_last_idx, + last_state_idx, out, # Matrix dimensions batch, @@ -1175,7 +1192,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_CACHE_ENABLED=current_last_idx is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, From 0bb519793c746df1bbce30b4445f1d6339216281 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 1 Oct 2025 17:15:35 -0400 Subject: [PATCH 067/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/mamba_mixer2.py | 4 +- .../layers/mamba/ops/ssd_combined.py | 40 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 692cee0acec8..48cc022343e5 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -590,7 +590,7 @@ def forward_cuda( _, current_first_idx_p = torch.split( attn_metadata.current_first_token_block_idx, [num_decodes, num_prefills], - dim=0) + dim=0) _, seq_lens_completed_p = torch.split( attn_metadata.seq_lens_completed, [num_decodes, num_prefills], dim=0) @@ -722,7 +722,7 @@ def forward_cuda( varlen_states[last_chunk_indices_p] else: # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) + # - varlen state is a (num_prefills, nheads, headdim, dstate) # tensor ssm_state[state_indices_tensor_p] = varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index ae839a610659..e9e589115b8a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -151,7 +151,7 @@ def _mamba_chunk_scan_combined_fwd(x, z=z, initial_states=initial_states, ) - + if return_intermediate_states: return states else: @@ -159,25 +159,25 @@ def _mamba_chunk_scan_combined_fwd(x, def mamba_chunk_scan_combined_varlen( - x, - dt, - A, - B, - C, - chunk_size, - cu_seqlens, - cu_chunk_seqlens, - last_chunk_indices, - seq_idx, - out, - D=None, - z=None, - dt_bias=None, - 
initial_states=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_intermediate_states=False, - state_dtype=None, + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, ): """ Argument: From c911c884b68476dc1df87ea77279f5fcbbd3cef0 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 1 Oct 2025 22:10:20 +0000 Subject: [PATCH 068/105] Moved APC tests: test_hybrid.py; pre-commit clean Signed-off-by: Thomas Ortner --- .../models/language/generation/test_hybrid.py | 399 +++++++++++++++++- .../layers/mamba/mamba_mixer2.py | 4 +- .../layers/mamba/ops/causal_conv1d.py | 55 ++- 3 files changed, 427 insertions(+), 31 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index e60a86075b8b..483b4543380f 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -8,7 +8,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_logprobs_close +from ...utils import check_logprobs_close, check_outputs_equal # Mark all tests as hybrid pytestmark = pytest.mark.hybrid_model @@ -332,3 +332,400 @@ def test_fp32_cache_state( name_0="hf", name_1="vllm", ) + + +# Helper functions for the APC tests +def _get_vllm_runner_params(model, enforce_eager, max_model_len): + return { + 'model_name': model, + 'enable_prefix_caching': False, + 'enforce_eager': enforce_eager, + 'max_model_len': max_model_len, + 'disable_cascade_attn': True, ## not verified yet + 'disable_log_stats': False, ## collect APC stats + 'gpu_memory_utilization': 0.4 + } + + +def _get_vLLM_output(vllm_runner, + kwargs, + prompts, + max_tokens, + num_logprobs, + num_repetitions=1, + vllm_model=None): + outs = [] + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + if num_logprobs < 0: + vllm_output = vllm_model.generate_greedy(prompts, max_tokens) + else: + vllm_output = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) + outs.append(vllm_output) + + return outs, vllm_model + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_apc_single_prompt( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + num_logprobs: int, + cache_dtype_param: str, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + + MULTIPLE = 300 + + # Sample prompts. 
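+    # Repeating each prompt MULTIPLE times is meant to make it span several
+    # mamba cache blocks, so the prefix-caching runs below have full blocks
+    # to reuse on the later repetitions.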
+ generated_prompts = [MULTIPLE * example_prompts[0]] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len) + vllm_runner_kwargs[cache_dtype_param] = "float32" + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_apc_single_prompt_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + num_logprobs: int, + cache_dtype_param: str, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + + MULTIPLE = 300 + + # Sample prompts. This custom prompt is used, as it causes the most issues + generated_prompts = ["The president of the United States is " * MULTIPLE] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len) + vllm_runner_kwargs[cache_dtype_param] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ + mamba_block_size + + mamba_block_size_multiplier = 10 + for offsets in [ + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ + offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, num_logprobs, + n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_apc_multiple_prompts_all_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + num_logprobs: int, + cache_dtype_param: str, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len) + vllm_runner_kwargs[cache_dtype_param] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_apc_multiple_prompts_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + num_logprobs: int, + cache_dtype_param: str, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + + MULTIPLE = 300 + + # Sample prompts. 
This custom prompt is used, as it causes the most issues + prompt_text = "The president of the United States is " + prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] + generated_prompts = [ + prompt_text[offset:] * MULTIPLE for offset in prompt_offsets + ] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len) + vllm_runner_kwargs[cache_dtype_param] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +@pytest.mark.parametrize("enforce_eager", [True]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_apc_multiple_prompts_partial_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + enforce_eager: bool, + num_logprobs: int, + cache_dtype_param: str, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len) + vllm_runner_kwargs[cache_dtype_param] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_partial_cache, vllm_model = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, + num_logprobs) + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0][:3], + outputs_1_lst=vllm_outputs_partial_cache[0], + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + vllm_model=vllm_model) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 692cee0acec8..48cc022343e5 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -590,7 +590,7 @@ def forward_cuda( _, current_first_idx_p = torch.split( attn_metadata.current_first_token_block_idx, [num_decodes, num_prefills], - dim=0) + dim=0) _, seq_lens_completed_p = torch.split( attn_metadata.seq_lens_completed, [num_decodes, num_prefills], dim=0) @@ -722,7 +722,7 @@ def forward_cuda( varlen_states[last_chunk_indices_p] else: # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) + # - varlen state is a (num_prefills, nheads, headdim, dstate) # tensor ssm_state[state_indices_tensor_p] = varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index df92cffbc666..176b5c9148e4 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -36,7 +36,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_x_token: tl. 
+ constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, @@ -45,7 +46,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_cache_indices: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, - stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M + stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters @@ -85,7 +86,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching if IS_CACHE_ENABLED: # Handle the case if prefix caching is enabled - + # Get the length of the completed sequence so far and compute the offset. current_first_index = tl.load(current_first_idx + idx_seq) current_last_index = tl.load(current_last_idx + idx_seq) @@ -126,14 +127,14 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index).to(tl.int64) - + if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence return conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -142,8 +143,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states prior_tokens = conv_states_base + (state_len - @@ -207,10 +207,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Compute the offset where the last block should be written in the conv_states conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * - stride_cache_indices + - current_last_index).to( - tl.int64) + idx_seq * stride_cache_indices + + current_last_index).to(tl.int64) conv_states_ptrs_target = ( conv_states_ptr + (conv_states_offset * stride_conv_state_seq) @@ -320,7 +318,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Store intermediate states aligned with stride_cache_chunk # The states are cached starting from the last stride_block_m. - # For example: + # For example: # If n_block_to_fill = 0, then the state at the sequence is cached. # If n_block_to_fill > 0, then the states at the sequence and at the n_block_to_fill-last stride_block_m are cached. 
if (chunk_offset - 1) < n_block_to_fill: @@ -333,8 +331,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] mask_x = ( - (idx_tokens_last >= 0)[:, None] & - (idx_feats < dim)[None, :] + (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) new_conv_state = tl.load(x_ptrs, mask_x, 1.0) @@ -343,16 +340,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching # cache_idx conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_cache_indices + - current_first_index + (chunk_offset - 1)).to(tl.int64) - + current_first_index + + (chunk_offset - 1)).to(tl.int64) + conv_states_ptrs_target = ( - conv_states_ptr + - (conv_states_offset * stride_conv_state_seq) + # Offset from seq + conv_states_ptr + (conv_states_offset * stride_conv_state_seq) + + # Offset from seq (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok)[:, None] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] + < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler tl.store(conv_states_ptrs_target, new_conv_state, mask) @@ -448,7 +446,7 @@ def causal_conv1d_fn( current_last_idx: Optional[torch.Tensor] = None, last_state_idx: Optional[torch.Tensor] = None, seq_lens_completed: Optional[torch.Tensor] = None, - block_size_to_align: Optional[int] = 0, + block_size_to_align=0, metadata=None, validate_data=False, ): @@ -571,7 +569,8 @@ def causal_conv1d_fn( else: stride_o_dim = out.stride(1) stride_o_token = out.stride(2) - stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 + stride_cache_indices = cache_indices.stride( + 0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 @@ -777,8 +776,8 @@ def _causal_conv1d_update_kernel( # cache_idx conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) if USE_PAD_SLOT: # noqa if conv_state_batch_coord == pad_slot_id: @@ -791,7 +790,7 @@ def _causal_conv1d_update_kernel( tl.int64) # revise state_len and seqlen state_len = state_len - (seqlen - - (query_end_index - query_start_index)) + (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * stride_x_token o_offset = query_start_index * stride_o_token @@ -880,8 +879,8 @@ def _causal_conv1d_update_kernel( # Get the state from the last_state_idx # cache_idx conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - current_last_index).to(tl.int64) + idx_seq * stride_state_indices + + current_last_index).to(tl.int64) conv_state_ptrs_target = ( conv_state_ptr + (conv_states_offset * stride_conv_state_seq) + # Offset from seq @@ -1200,4 +1199,4 @@ def grid(META): ) if unsqueeze: out = out.squeeze(-1) - return out.to(original_x_dtype) \ No newline at end of file + return out.to(original_x_dtype) From cf5c4c7cbda519ee7beabcd1487756450524ccaf Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 1 Oct 2025 22:12:49 +0000 Subject: [PATCH 069/105] Merge: Deleted old test file Signed-off-by: Thomas Ortner --- tests/v1/core/test_APC_hybrid_models.py | 477 ------------------------ 1 file changed, 477 deletions(-) delete mode 100644 tests/v1/core/test_APC_hybrid_models.py diff --git 
a/tests/v1/core/test_APC_hybrid_models.py b/tests/v1/core/test_APC_hybrid_models.py deleted file mode 100644 index 7f25ca87da23..000000000000 --- a/tests/v1/core/test_APC_hybrid_models.py +++ /dev/null @@ -1,477 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the outputs of HF and vLLM when using greedy sampling. - -It tests automated prefix caching (APC). APC can be enabled by -enable_prefix_caching=True. - -Run `pytest tests/basic_correctness/test_APC_hybrid_models.py`. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from tests.models.utils import check_logprobs_close, check_outputs_equal - -if TYPE_CHECKING: - from ...conftest import HfRunner, VllmRunner - -MODELS = [ - "ibm-granite/granite-4.0-tiny-preview", - # "hmellor/tiny-random-BambaForCausalLM", -] - - -def _get_vllm_runner_params(model, mamba_ssm_cache_dtype, enforce_eager, - max_model_len, dtype, tensor_parallel_size): - return { - 'model_name': model, - 'mamba_ssm_cache_dtype': mamba_ssm_cache_dtype, - 'enable_prefix_caching': False, - 'enforce_eager': enforce_eager, - 'max_model_len': max_model_len, - 'dtype': dtype, - 'tensor_parallel_size': tensor_parallel_size, - 'disable_cascade_attn': True, ## not verified yet - 'disable_log_stats': False, ## collect APC stats - 'gpu_memory_utilization': 0.4 - } - - -def _get_vLLM_outputs(vllm_runner, - kwargs, - prompts, - max_tokens, - num_repetitions=1, - vllm_model=None): - outs = [] - if vllm_model is None: - vllm_model = vllm_runner(**kwargs) - for _ in range(num_repetitions): - outs.append( - vllm_model.generate_greedy(prompts, max_tokens)) - - return outs, vllm_model - -def _get_vLLM_logprobs(vllm_runner, - kwargs, - prompts, - max_tokens, - num_logprobs, - num_repetitions=1, - vllm_model=None): - outs = [] - if vllm_model is None: - vllm_model = vllm_runner(**kwargs) - for _ in range(num_repetitions): - outs.append( - vllm_model.generate_greedy_logprobs(prompts, max_tokens, - num_logprobs)) - - return outs, vllm_model - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_single_prompt( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - MULTIPLE = 300 - - # Sample prompts. 
- generated_prompts = [MULTIPLE * example_prompts[0]] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - n_repetitions) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_single_prompt_mamba_size_alignment( - hf_runner: HfRunner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - MULTIPLE = 300 - - # Sample prompts. - # generated_prompts = [MULTIPLE * example_prompts[0]] - - generated_prompts = ["The president of the United States is " * MULTIPLE] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - vllm_runner_kwargs['enable_prefix_caching'] = True - with vllm_runner(**vllm_runner_kwargs) as vllm_model: - # Retrieve the default mamba state block size - mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ - mamba_block_size - - mamba_block_size_multiplier = 10 - for offsets in [ - -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 - ]: - - vllm_runner_kwargs[ - 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ - offsets - vllm_outputs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) - - # Check alignment of the output logits when using APC - for r_idx, vllm_outputs_cache_itn in enumerate( - vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_multiple_prompts_all_cached_outputs( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - MULTIPLE = 300 - - # Sample prompts. - generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - n_repetitions) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. 
-@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_multiple_prompts_mamba_size_alignment( - hf_runner: HfRunner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - MULTIPLE = 300 - - # Sample prompts. - # generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] - - prompt_text = "The president of the United States is " - prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] - generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - vllm_runner_kwargs['enable_prefix_caching'] = True - with vllm_runner(**vllm_runner_kwargs) as vllm_model: - # Retrieve the default mamba state block size - mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ - mamba_block_size - - mamba_block_size_multiplier = 10 - for offsets in [ - -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 - ]: - - vllm_runner_kwargs[ - 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ - offsets - vllm_outputs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, n_repetitions) - - # Check alignment of the output logits when using APC - for r_idx, vllm_outputs_cache_itn in enumerate( - vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_multiple_prompts_partial_cached_outputs( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - MULTIPLE = 300 - - # Sample prompts. 
- generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - # Cache only part of all the prompts - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_partial_cache, vllm_model = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens) - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0][:3], - outputs_1_lst=vllm_outputs_partial_cache[0], - name_0="vllm_no_cache", - name_1="vllm_partial_cache", - ) - - vllm_outputs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, - vllm_runner_kwargs, - generated_prompts, - max_tokens, - num_logprobs, - n_repetitions, - vllm_model=vllm_model) - - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("mamba_ssm_cache_dtype", ['auto', 'float32']) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. 
-@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_specific_prompts_outputs( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - n_repetitions: int, - enforce_eager: bool, - mamba_ssm_cache_dtype: str, - tensor_parallel_size: int, - num_logprobs: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode vllm runner with and without prefix caching - """ - - generated_prompts = [ - "Hello, my name is John Smith and I work at " * 100, - "The president of the United States is " * 200, - "The capital of France is something like" * 200, - "The future of AI is " * 300, - ] - - max_model_len = max( - len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, mamba_ssm_cache_dtype, - enforce_eager, max_model_len, - dtype, tensor_parallel_size) - vllm_outputs_logprobs_no_cache, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens) - - # Cache only part of all the prompts - vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_logprobs_cache_rep, _ = _get_vLLM_outputs( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, - n_repetitions) - - for r_idx, vllm_outputs_logprobs_cache_itn in enumerate( - vllm_outputs_logprobs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused - - # check_logprobs_close( - check_outputs_equal( - outputs_0_lst=vllm_outputs_logprobs_no_cache[0], - outputs_1_lst=vllm_outputs_logprobs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) From 618fe530d3e8f26c637266f17c7583bb07982314 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 2 Oct 2025 03:16:13 -0400 Subject: [PATCH 070/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- .../models/language/generation/test_hybrid.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 483b4543380f..b07bdac3d3a2 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -398,7 +398,8 @@ def test_apc_single_prompt( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + compare_operator = check_logprobs_close if num_logprobs > 0 \ + else check_outputs_equal MULTIPLE = 300 @@ -463,7 +464,8 @@ def test_apc_single_prompt_block_align_alignment( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + compare_operator = check_logprobs_close if num_logprobs > 0 \ + else check_outputs_equal MULTIPLE = 300 @@ -493,8 +495,8 @@ def test_apc_single_prompt_block_align_alignment( ]: vllm_runner_kwargs[ - 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ - offsets + 'max_num_batched_tokens'] = mamba_block_size_multiplier * \ + mamba_block_size - offsets vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, @@ -544,7 +546,8 @@ def test_apc_multiple_prompts_all_cached_outputs( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + compare_operator = check_logprobs_close if num_logprobs > 0 \ + else check_outputs_equal MULTIPLE = 300 @@ -610,7 +613,8 @@ 
def test_apc_multiple_prompts_block_align_alignment( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + compare_operator = check_logprobs_close if num_logprobs > 0 \ + else check_outputs_equal MULTIPLE = 300 @@ -680,7 +684,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 else check_outputs_equal + compare_operator = check_logprobs_close if num_logprobs > 0 \ + else check_outputs_equal MULTIPLE = 300 From 9dd6b81debb5fe3fe6a28381230c4a611e631637 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 2 Oct 2025 03:27:25 -0400 Subject: [PATCH 071/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- tests/models/language/generation/test_hybrid.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index b07bdac3d3a2..13a2c5b46574 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable + import pytest from tests.models.registry import HF_EXAMPLE_MODELS @@ -398,7 +400,7 @@ def test_apc_single_prompt( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 \ + compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ else check_outputs_equal MULTIPLE = 300 @@ -464,7 +466,7 @@ def test_apc_single_prompt_block_align_alignment( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 \ + compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ else check_outputs_equal MULTIPLE = 300 @@ -546,7 +548,7 @@ def test_apc_multiple_prompts_all_cached_outputs( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 \ + compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ else check_outputs_equal MULTIPLE = 300 @@ -613,7 +615,7 @@ def test_apc_multiple_prompts_block_align_alignment( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 \ + compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ else check_outputs_equal MULTIPLE = 300 @@ -684,7 +686,7 @@ def test_apc_multiple_prompts_partial_cached_outputs( except ValueError: pass - compare_operator = check_logprobs_close if num_logprobs > 0 \ + compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ else check_outputs_equal MULTIPLE = 300 From 63e9217ad358f57353d18c842cf6f4e994184f45 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 2 Oct 2025 03:45:55 -0400 Subject: [PATCH 072/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- .../models/language/generation/test_hybrid.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 13a2c5b46574..2e01f6f15948 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable +from typing import Any, Callable import pytest @@ -400,8 +400,8 @@ def test_apc_single_prompt( except 
ValueError: pass - compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ - else check_outputs_equal + compare_operator: Callable[..., Any] = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal MULTIPLE = 300 @@ -466,8 +466,8 @@ def test_apc_single_prompt_block_align_alignment( except ValueError: pass - compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ - else check_outputs_equal + compare_operator: Callable[..., Any] = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal MULTIPLE = 300 @@ -548,8 +548,8 @@ def test_apc_multiple_prompts_all_cached_outputs( except ValueError: pass - compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ - else check_outputs_equal + compare_operator: Callable[..., Any] = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal MULTIPLE = 300 @@ -615,8 +615,8 @@ def test_apc_multiple_prompts_block_align_alignment( except ValueError: pass - compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ - else check_outputs_equal + compare_operator: Callable[..., Any] = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal MULTIPLE = 300 @@ -686,8 +686,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( except ValueError: pass - compare_operator: Callable = check_logprobs_close if num_logprobs > 0 \ - else check_outputs_equal + compare_operator: Callable[..., Any] = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal MULTIPLE = 300 From c0eed4af02b431fa8ffed8b6775129bd65e7b542 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 2 Oct 2025 07:00:51 -0400 Subject: [PATCH 073/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- .../models/language/generation/test_hybrid.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2e01f6f15948..0b62421ed730 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable +from typing import Callable import pytest @@ -400,8 +400,8 @@ def test_apc_single_prompt( except ValueError: pass - compare_operator: Callable[..., Any] = check_logprobs_close \ - if num_logprobs > 0 else check_outputs_equal + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore MULTIPLE = 300 @@ -466,8 +466,8 @@ def test_apc_single_prompt_block_align_alignment( except ValueError: pass - compare_operator: Callable[..., Any] = check_logprobs_close \ - if num_logprobs > 0 else check_outputs_equal + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore MULTIPLE = 300 @@ -548,8 +548,8 @@ def test_apc_multiple_prompts_all_cached_outputs( except ValueError: pass - compare_operator: Callable[..., Any] = check_logprobs_close \ - if num_logprobs > 0 else check_outputs_equal + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore MULTIPLE = 300 @@ -615,8 +615,8 @@ def test_apc_multiple_prompts_block_align_alignment( except ValueError: pass - compare_operator: Callable[..., Any] = check_logprobs_close \ - if num_logprobs > 0 else check_outputs_equal + compare_operator: Callable = 
check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore MULTIPLE = 300 @@ -686,8 +686,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( except ValueError: pass - compare_operator: Callable[..., Any] = check_logprobs_close \ - if num_logprobs > 0 else check_outputs_equal + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore MULTIPLE = 300 From 1425b7335a0d7bb6e739f358e392255d1a29c1e1 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Thu, 2 Oct 2025 15:53:38 +0000 Subject: [PATCH 074/105] Precommit fixes and documentation Signed-off-by: Thomas Ortner --- .../models/language/generation/test_hybrid.py | 88 ++++++++++++------- .../layers/mamba/mamba_mixer2.py | 18 +++- .../layers/mamba/ops/causal_conv1d.py | 40 ++++----- 3 files changed, 91 insertions(+), 55 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0b62421ed730..7cf819c751a9 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -337,12 +337,13 @@ def test_fp32_cache_state( # Helper functions for the APC tests -def _get_vllm_runner_params(model, enforce_eager, max_model_len): +def _get_vllm_runner_params(model, enforce_eager, max_model_len, tensor_parallel_size=1): return { 'model_name': model, 'enable_prefix_caching': False, 'enforce_eager': enforce_eager, 'max_model_len': max_model_len, + 'tensor_parallel_size': tensor_parallel_size, 'disable_cascade_attn': True, ## not verified yet 'disable_log_stats': False, ## collect APC stats 'gpu_memory_utilization': 0.4 @@ -378,6 +379,7 @@ def _get_vLLM_output(vllm_runner, # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_single_prompt( @@ -390,6 +392,7 @@ def test_apc_single_prompt( n_repetitions: int, enforce_eager: bool, num_logprobs: int, + tensor_parallel_size: int, cache_dtype_param: str, ) -> None: @@ -411,7 +414,8 @@ def test_apc_single_prompt( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len) + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, @@ -436,7 +440,7 @@ def test_apc_single_prompt( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[4]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) @pytest.mark.parametrize("enforce_eager", [True]) @@ -444,6 +448,7 @@ def test_apc_single_prompt( # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_single_prompt_block_align_alignment( @@ -456,6 +461,7 @@ def test_apc_single_prompt_block_align_alignment( n_repetitions: int, enforce_eager: bool, num_logprobs: int, + tensor_parallel_size: int, cache_dtype_param: str, ) -> None: @@ -477,7 +483,8 @@ def 
test_apc_single_prompt_block_align_alignment( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len) + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, @@ -526,6 +533,7 @@ def test_apc_single_prompt_block_align_alignment( # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_all_cached_outputs( @@ -538,6 +546,7 @@ def test_apc_multiple_prompts_all_cached_outputs( n_repetitions: int, enforce_eager: bool, num_logprobs: int, + tensor_parallel_size: int, cache_dtype_param: str, ) -> None: @@ -559,7 +568,8 @@ def test_apc_multiple_prompts_all_cached_outputs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len) + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, @@ -583,8 +593,7 @@ def test_apc_multiple_prompts_all_cached_outputs( name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - - + @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) @@ -593,6 +602,7 @@ def test_apc_multiple_prompts_all_cached_outputs( # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_block_align_alignment( @@ -605,6 +615,7 @@ def test_apc_multiple_prompts_block_align_alignment( n_repetitions: int, enforce_eager: bool, num_logprobs: int, + tensor_parallel_size: int, cache_dtype_param: str, ) -> None: @@ -617,43 +628,53 @@ def test_apc_multiple_prompts_block_align_alignment( compare_operator: Callable = check_logprobs_close \ if num_logprobs > 0 else check_outputs_equal # type: ignore - + MULTIPLE = 300 # Sample prompts. 
This custom prompt is used, as it causes the most issues prompt_text = "The president of the United States is " prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] - generated_prompts = [ - prompt_text[offset:] * MULTIPLE for offset in prompt_offsets - ] + generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len) + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, max_model_len, + tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" - - vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, - vllm_runner_kwargs, - generated_prompts, max_tokens, - num_logprobs) + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs) vllm_runner_kwargs['enable_prefix_caching'] = True - vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, - vllm_runner_kwargs, - generated_prompts, max_tokens, - num_logprobs, n_repetitions) + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ + mamba_block_size - for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): - # In the first repetition, the caches are filled - # In the second repetition, these caches are reused + mamba_block_size_multiplier = 10 + for offsets in [ + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 + ]: - compare_operator( - outputs_0_lst=vllm_outputs_no_cache[0], - outputs_1_lst=vllm_outputs_cache_itn, - name_0="vllm_no_cache", - name_1=f"vllm_cache_it_{r_idx + 1}", - ) + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ + offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate( + vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @@ -664,6 +685,7 @@ def test_apc_multiple_prompts_block_align_alignment( # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_partial_cached_outputs( @@ -676,6 +698,7 @@ def test_apc_multiple_prompts_partial_cached_outputs( n_repetitions: int, enforce_eager: bool, num_logprobs: int, + tensor_parallel_size: int, cache_dtype_param: str, ) -> None: @@ -697,7 +720,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len) + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, diff --git 
a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 48cc022343e5..94b4f3ed0497 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -577,6 +577,9 @@ def forward_cuda( ) if cache_enabled: + # If prefix caching is enabled, retrieve the relevant variables + # for prefill and decode + # Split decodes and prefills: last_state_idx_d, last_state_idx_p = torch.split( attn_metadata.last_computed_token_block_idx, @@ -623,8 +626,17 @@ def forward_cuda( # Process prefill requests if has_prefill: # 2. Convolution sequence transformation - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" + # - It will read the initial states for every sequence, + # that has "has_initial_states_p" == True, + # from "cache_indices", using "state_indices_tensor_p". + # - It updates the "conv_state" cache in positions pointed + # to by "state_indices_tensor_p". + # In particular, it will always write the state at the + # sequence end. + # In addition, "current_first_idx_p" and "current_last_idx_p" + # are provided (which are pointers into + # "state_indices_tensor_p"), it will write additional cache + # states aligned at "block_size_to_align". x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see hidden_states_B_C_p = causal_conv1d_fn( @@ -686,8 +698,8 @@ def forward_cuda( state_dtype=ssm_state.dtype) if cache_enabled: - n_blocks_to_fill = current_last_idx_p - current_first_idx_p # Save states for sequences with more than just the final state: + n_blocks_to_fill = current_last_idx_p - current_first_idx_p for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): cache_blocks_to_fill = state_indices_tensor_p[ seq_idx, current_first_idx_p[seq_idx]: diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 176b5c9148e4..6dd43ba97df5 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -84,8 +84,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual sequence length seqlen = sequence_end_index - sequence_start_index + B_size = (stride_block_m * BLOCK_M) + if IS_CACHE_ENABLED: - # Handle the case if prefix caching is enabled + # Handle the case if prefix caching is enabled. + # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" # Get the length of the completed sequence so far and compute the offset. 
current_first_index = tl.load(current_first_idx + idx_seq) @@ -94,7 +97,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Compute the offset where the first stride_block_m-aligned first full block is # Value in "token-space" - B_size = (stride_block_m * BLOCK_M) sequence_completed_offset_token = sequence_completed_index % B_size seq_completed_offset = B_size - sequence_completed_offset_token seq_end_offset = (seqlen - seq_completed_offset) % B_size @@ -114,7 +116,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching current_last_index = 0 conv_state_init_index = 0 current_first_index = 0 - sequence_offset_token_index = 0 last_full_block_token_index = 0 token_offset = BLOCK_M * chunk_offset @@ -202,7 +203,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # Compute the offset where the last block should be written in the conv_states @@ -219,7 +219,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) + tl.store(conv_states_ptrs_target, loaded_x, mask) else: if load_init_state: @@ -316,16 +316,18 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') - # Store intermediate states aligned with stride_cache_chunk - # The states are cached starting from the last stride_block_m. + # Store intermediate states aligned with stride_block_m + # The additional states are cached starting from the last stride_block_m. # For example: - # If n_block_to_fill = 0, then the state at the sequence is cached. - # If n_block_to_fill > 0, then the states at the sequence and at the n_block_to_fill-last stride_block_m are cached. + # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. + # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last + # stride_block_m are cached. 
+ # For example chunk_offset = n_block_to_fill stores the state at last_full_block if (chunk_offset - 1) < n_block_to_fill: # Store the states at the chunk boundaries from the start of the sequence idx_tokens_last = ( last_full_block_token_index - - (n_block_to_fill - chunk_offset) * stride_block_m * BLOCK_M - + (n_block_to_fill - chunk_offset) * B_size - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + ( idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] @@ -334,7 +336,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 1.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # cache_idx @@ -349,10 +350,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok)[:, None] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] + mask = (idx_tokens_conv < state_len)[:, None] & \ + (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) + tl.store(conv_states_ptrs_target, loaded_x, mask) if HAS_BIAS: bias = bias_ptr + idx_feats @@ -490,11 +491,11 @@ def causal_conv1d_fn( in this case, the kernel will not process entries at indices 0 and 3 current_first_idx: (batch) int32 - The first cache block to be filled. This tensor indexes into cache_indices + The pointer into cache_indices, which signifies the first cache block to be filled. current_last_idx: (batch) int32 - The last cache block to be filled. This tensor indexes into cache_indices + The pointer into cache_indices, which signifies the last cache block to be filled. last_state_idx: (batch) int32 - The cache block for the init values. This tensor indexes into cache_indices + The pointer into cache_indices, which signifies the cache block containing the initial state. seq_lens_completed: (batch) int32 The number of tokens already completed for each sequence block_size_to_align: int @@ -539,7 +540,6 @@ def causal_conv1d_fn( np2_statelen = triton.next_power_of_2(state_len) padded_batch = query_start_loc.size(0) - 1 - stride_x_seq = 0 stride_x_dim = x.stride(0) stride_x_token = x.stride(1) stride_w_dim = weight.stride(0) @@ -1057,9 +1057,9 @@ def causal_conv1d_update( and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. current_last_idx: (batch) int32 - The last cache block to be filled. This tensor indexes into cache_indices + The pointer into cache_indices, which signifies the last cache block to be filled. last_state_idx: (batch) int32 - The cache block for the init values. This tensor indexes into cache_indices + The pointer into cache_indices, which signifies the cache block containing the initial state. num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. 
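For reference, the pointer semantics documented above (current_first_idx, current_last_idx
and last_state_idx are indices into cache_indices, derived from how many tokens of a sequence
are already completed and how many are pending in the current step) come down to a small piece
of index arithmetic, mirroring what the Mamba2 metadata builder computes on tensors. The sketch
below is a simplified, standalone illustration using plain integers; the helper names and the
scalar signature are invented for the example and are not part of the patch.

def _cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(-a // b)

def mamba_block_pointers(seq_len_completed: int, seq_len_pending: int,
                         block_size: int) -> tuple[int, int, int]:
    # Illustrative helper (not from the patch): per-request block bookkeeping.
    # Block holding the last state already computed in a previous step
    # (clamped to 0 when nothing has been computed yet).
    last_computed = max(seq_len_completed // block_size - 1, 0)
    # First and last cache blocks that the current step's tokens will write.
    current_first = _cdiv(seq_len_completed + 1, block_size) - 1
    current_last = _cdiv(seq_len_completed + seq_len_pending, block_size) - 1
    return last_computed, current_first, current_last

# Example: block_size=16, 40 tokens already computed, 20 new prefill tokens
# -> (1, 2, 3); besides the final state at the sequence end, the kernel fills
# current_last - current_first = 1 additional block-aligned state.
print(mamba_block_pointers(40, 20, 16))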
From 6e8faf9bdfe1970580db28d23be4bf47cdf128a9 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 2 Oct 2025 16:43:10 -0400 Subject: [PATCH 075/105] Precommit fixes Signed-off-by: Stanislaw Wozniak --- .../models/language/generation/test_hybrid.py | 72 ++++++++++++------- .../layers/mamba/ops/causal_conv1d.py | 10 +-- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 7cf819c751a9..2dae53a8ee0c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -337,7 +337,10 @@ def test_fp32_cache_state( # Helper functions for the APC tests -def _get_vllm_runner_params(model, enforce_eager, max_model_len, tensor_parallel_size=1): +def _get_vllm_runner_params(model, + enforce_eager, + max_model_len, + tensor_parallel_size=1): return { 'model_name': model, 'enable_prefix_caching': False, @@ -413,9 +416,11 @@ def test_apc_single_prompt( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params( + model, + enforce_eager, + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, @@ -482,9 +487,11 @@ def test_apc_single_prompt_block_align_alignment( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params( + model, + enforce_eager, + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, @@ -567,9 +574,11 @@ def test_apc_multiple_prompts_all_cached_outputs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params( + model, + enforce_eager, + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, @@ -593,7 +602,8 @@ def test_apc_multiple_prompts_all_cached_outputs( name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - + + @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) @@ -628,22 +638,27 @@ def test_apc_multiple_prompts_block_align_alignment( compare_operator: Callable = check_logprobs_close \ if num_logprobs > 0 else check_outputs_equal # type: ignore - + MULTIPLE = 300 # Sample prompts. 
This custom prompt is used, as it causes the most issues prompt_text = "The president of the United States is " prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] - generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + generated_prompts = [ + prompt_text[offset:] * MULTIPLE for offset in prompt_offsets + ] max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, max_model_len, + vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, + max_model_len, tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" - - vllm_outputs_no_cache, _ = _get_vLLM_output( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs) + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) vllm_runner_kwargs['enable_prefix_caching'] = True with vllm_runner(**vllm_runner_kwargs) as vllm_model: @@ -657,14 +672,16 @@ def test_apc_multiple_prompts_block_align_alignment( ]: vllm_runner_kwargs[ - 'max_num_batched_tokens'] = mamba_block_size_multiplier * mamba_block_size - \ - offsets - vllm_outputs_cache_rep, _ = _get_vLLM_output( - vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs, n_repetitions) + 'max_num_batched_tokens'] = mamba_block_size_multiplier * \ + mamba_block_size - offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, num_logprobs, + n_repetitions) # Check alignment of the output logits when using APC - for r_idx, vllm_outputs_cache_itn in enumerate( - vllm_outputs_cache_rep): + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): # In the first repetition, the caches are filled # In the second repetition, these caches are reused @@ -674,7 +691,6 @@ def test_apc_multiple_prompts_block_align_alignment( name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) - @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @@ -719,9 +735,11 @@ def test_apc_multiple_prompts_partial_cached_outputs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs = _get_vllm_runner_params( + model, + enforce_eager, + max_model_len, + tensor_parallel_size=tensor_parallel_size) vllm_runner_kwargs[cache_dtype_param] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 6dd43ba97df5..8b733847953e 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -320,15 +320,15 @@ def _causal_conv1d_fwd_kernel( # continuous batching # The additional states are cached starting from the last stride_block_m. # For example: # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. - # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last + # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last # stride_block_m are cached. 
# For example chunk_offset = n_block_to_fill stores the state at last_full_block if (chunk_offset - 1) < n_block_to_fill: # Store the states at the chunk boundaries from the start of the sequence - idx_tokens_last = ( - last_full_block_token_index - - (n_block_to_fill - chunk_offset) * B_size - - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + idx_tokens_last = (last_full_block_token_index - + (n_block_to_fill - chunk_offset) * B_size - + state_len) + tl.arange( + 0, NP2_STATELEN) # [BLOCK_M] x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + ( idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] From 0141f15d61f48acc815182a6cb5136a0c2bd51ce Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 08:29:25 +0000 Subject: [PATCH 076/105] Fixed test_hybrid.py for models w/o mamba_block-size Signed-off-by: Thomas Ortner --- tests/models/language/generation/test_hybrid.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2dae53a8ee0c..2118bf58eb6c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -445,7 +445,7 @@ def test_apc_single_prompt( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[4]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) @pytest.mark.parametrize("enforce_eager", [True]) @@ -505,6 +505,11 @@ def test_apc_single_prompt_block_align_alignment( mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ mamba_block_size + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + mamba_block_size_multiplier = 10 for offsets in [ -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 @@ -666,6 +671,11 @@ def test_apc_multiple_prompts_block_align_alignment( mamba_block_size = vllm_model.llm.llm_engine.cache_config. \ mamba_block_size + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + mamba_block_size_multiplier = 10 for offsets in [ -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 From 8d0077a97980196a1c007650bf3f5e4c3fd9236f Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 3 Oct 2025 06:50:11 -0400 Subject: [PATCH 077/105] Addressing feedback and cleanup. 
Signed-off-by: Stanislaw Wozniak --- vllm/config/cache.py | 9 +---- .../layers/mamba/mamba_mixer2.py | 3 +- vllm/v1/attention/backends/mamba2_attn.py | 26 +++++++------- vllm/v1/core/single_type_kv_cache_manager.py | 35 ++++++------------- vllm/v1/kv_cache_interface.py | 8 ++--- vllm/v1/worker/gpu_model_runner.py | 5 ++- 6 files changed, 29 insertions(+), 57 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index d477a2eee8ca..17fe58fef016 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -25,7 +25,6 @@ CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] -MambaCacheStrategy = Literal["disabled", "all", "last"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -94,13 +93,7 @@ class CacheConfig: """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" mamba_block_size: Optional[int] = None - """Size of a contiguous cache block in number of tokens for mamba cache.""" - mamba_cache_strategy: MambaCacheStrategy = "all" - """Logic for mamba cache: - * disabled - turn off prefix caching - * all - keep states for all prefixes - * last - keep the states of the last full blocks after each request - """ + """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). If set to 'auto', the data type will be inferred from the model diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 94b4f3ed0497..eee4915eda40 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -506,8 +506,7 @@ def forward_cuda( cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p mamba_block_size = attn_metadata.cache_spec.block_size - cache_strategy = attn_metadata.cache_spec.cache_strategy - cache_enabled = (cache_strategy != 'disabled') + cache_enabled = attn_metadata.cache_spec.enable_prefix_caching # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index cdeec640be3f..53b88c162e29 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -75,7 +75,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") assert isinstance(kv_cache_spec, MambaSpec) - if kv_cache_spec.cache_strategy == "all": + if kv_cache_spec.enable_prefix_caching: self.state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, cdiv(vllm_config.model_config.max_model_len, @@ -129,17 +129,7 @@ def build(self, nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.cache_strategy == "disabled": - # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, - 0] - # Additional cache-related varaiables: - current_last_token_block_idx = None - current_first_token_block_idx = None - last_computed_token_block_idx = None - last_computed_token_block_offset = None - seq_lens_completed = None - else: + if self.kv_cache_spec.enable_prefix_caching: # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor @@ -168,6 +158,16 @@ def build(self, # -1 in case it's non-computed and causes later issues with indexing last_computed_token_block_idx = \ last_computed_token_block_idx.clamp(min=0) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, + 0] + # Additional cache-related varaiables: + current_last_token_block_idx = None + current_first_token_block_idx = None + last_computed_token_block_idx = None + last_computed_token_block_offset = None + seq_lens_completed = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -261,7 +261,7 @@ def build(self, state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.kv_cache_spec.cache_strategy != 'disabled': + if self.kv_cache_spec.enable_prefix_caching: self.current_last_token_block_idx[:num_decodes].copy_( current_last_token_block_idx, non_blocking=True) current_last_token_block_idx = \ diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 4fc354a957d8..8d529149c0c0 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -545,7 +545,7 @@ def find_longest_cache_hit( assert dcp_world_size == 1, "DCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) - if kv_cache_spec.cache_strategy == "disabled": + if kv_cache_spec.enable_prefix_caching == False: return computed_blocks #return empty list if cache is disabled max_num_blocks = max_length // kv_cache_spec.block_size @@ -567,25 +567,17 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: - # TODO: For "all" strategy, and potentially for "last" - # we should already start removing initial blocks + # Here unused blocks may be freed up for running requests. 
+ # Future enhancement: Free up all blocks that aren't needed by Mamba2 + # (for which find_longest_cache_hit returns block_pool.null_block) pass def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: - assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.cache_strategy == "disabled": - return 0 - - # Same as full attention logic: - blocks = self.req_to_blocks[request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return num_common_blocks + """ + cascade attention is not supported by mamba + """ + return 0 def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, @@ -629,7 +621,7 @@ def allocate_new_blocks(self, request_id: str, num_tokens += (self.kv_cache_spec.block_size * self.kv_cache_spec.num_speculative_blocks) - if self.kv_cache_spec.cache_strategy == "disabled": + if self.kv_cache_spec.enable_prefix_caching == False: new_blocks = super().allocate_new_blocks(request_id, num_tokens) assert len(self.req_to_blocks[request_id]) == 1, ( "MambaManager should only allocate 1 block for each request.") @@ -641,14 +633,7 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - if num_new_blocks > 2 and \ - self.kv_cache_spec.cache_strategy == "last": - # for the last strategy only - allocate 2 blocks: - # one for block_size aligned state - # and one for the last temporary state - new_blocks = self.block_pool.get_new_blocks(2) - else: - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) return new_blocks diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 48f51ea34557..4cf3b6a5d435 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,6 @@ from typing_extensions import Self from vllm.config import VllmConfig -from vllm.config.cache import MambaCacheStrategy from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size @@ -221,7 +220,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" - cache_strategy: MambaCacheStrategy = "disabled" + enable_prefix_caching: bool = False num_speculative_blocks: int = 0 @property @@ -235,10 +234,7 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - if self.cache_strategy == "last": - # Keeps the last full block and one non-full block state: - return 2 * self.page_size_bytes - elif self.cache_strategy == "all": + if self.enable_prefix_caching: # Keeps a state at every block boundary: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d0b158d020f0..c70fdd757784 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4214,7 +4214,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # Set block_size to max_model_len, so that mamba model # will always have only one block mamba_block_size = self.vllm_config.model_config.max_model_len - self.vllm_config.cache_config.mamba_cache_strategy = "disabled" page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) @@ -4224,8 +4223,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: 
shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), block_size=mamba_block_size, - cache_strategy=self.vllm_config.cache_config. - mamba_cache_strategy, + enable_prefix_caching=self.vllm_config.cache_config. + enable_prefix_caching, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( From 37dacff51983afff9cbe8834d41f98ba8d1f6421 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 3 Oct 2025 07:04:24 -0400 Subject: [PATCH 078/105] Pre-commit fixes. Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/ops/ssd_combined.py | 38 +++++++++---------- vllm/model_executor/models/falcon_h1.py | 4 ++ vllm/v1/core/single_type_kv_cache_manager.py | 4 +- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index e9e589115b8a..a35602834fce 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -159,25 +159,25 @@ def _mamba_chunk_scan_combined_fwd(x, def mamba_chunk_scan_combined_varlen( - x, - dt, - A, - B, - C, - chunk_size, - cu_seqlens, - cu_chunk_seqlens, - last_chunk_indices, - seq_idx, - out, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_intermediate_states=False, - state_dtype=None, + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, ): """ Argument: diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index a91a84ac5181..f382018e2222 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -540,8 +540,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "FalconH1 currently does not support prefix caching" + self.quant_config = vllm_config.quant_config super().__init__() diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8d529149c0c0..ec98a97fbe9b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -545,7 +545,7 @@ def find_longest_cache_hit( assert dcp_world_size == 1, "DCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] 
= tuple( [] for _ in range(len(kv_cache_group_ids))) - if kv_cache_spec.enable_prefix_caching == False: + if not kv_cache_spec.enable_prefix_caching: return computed_blocks #return empty list if cache is disabled max_num_blocks = max_length // kv_cache_spec.block_size @@ -621,7 +621,7 @@ def allocate_new_blocks(self, request_id: str, num_tokens += (self.kv_cache_spec.block_size * self.kv_cache_spec.num_speculative_blocks) - if self.kv_cache_spec.enable_prefix_caching == False: + if not self.kv_cache_spec.enable_prefix_caching: new_blocks = super().allocate_new_blocks(request_id, num_tokens) assert len(self.req_to_blocks[request_id]) == 1, ( "MambaManager should only allocate 1 block for each request.") From 17986f89468f172ed5c8f4064b38c9ce8e8cb9c1 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 3 Oct 2025 07:17:16 -0400 Subject: [PATCH 079/105] Pre-commit fixes. Signed-off-by: Stanislaw Wozniak --- .../layers/mamba/ops/ssd_combined.py | 38 +++++++++---------- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a35602834fce..e9e589115b8a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -159,25 +159,25 @@ def _mamba_chunk_scan_combined_fwd(x, def mamba_chunk_scan_combined_varlen( - x, - dt, - A, - B, - C, - chunk_size, - cu_seqlens, - cu_chunk_seqlens, - last_chunk_indices, - seq_idx, - out, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_intermediate_states=False, - state_dtype=None, + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, ): """ Argument: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c70fdd757784..0aae0e162f3f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4224,7 +4224,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtypes=mamba_module.get_state_dtype(), block_size=mamba_block_size, enable_prefix_caching=self.vllm_config.cache_config. - enable_prefix_caching, + enable_prefix_caching, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( From f71ad6de36457c1361bd250bff392008e13a97d8 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 11:23:40 +0000 Subject: [PATCH 080/105] Integrated code review comments. 
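Alongside the renames in causal_conv1d (e.g. IS_CACHE_ENABLED -> IS_APC_ENABLED,
last_state_idx -> initial_state_idx), the per-sequence index tensors now spell
out which cache_indices entry supplies the initial state and which entries this
step fills. A small illustrative example of how they can relate for one request
(block size, token counts and the cdiv helper are assumptions made up for this
sketch, not values from the kernels or tests):

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    mamba_block_size = 512    # assumed tokens per mamba cache block
    context_len = 1000        # tokens already computed for this request
    new_tokens = 300          # tokens scheduled in this step

    # one full block (tokens 0..511) is complete, so the kernel reads its
    # starting state from cache_indices entry 0
    initial_state_idx = max(context_len // mamba_block_size - 1, 0)       # 0
    # the new tokens finish block 1 (tokens 512..1023) and start block 2,
    # so entries 1..2 are the blocks this step writes states for
    current_first_idx = cdiv(context_len + 1, mamba_block_size) - 1       # 1
    current_last_idx = cdiv(context_len + new_tokens,
                            mamba_block_size) - 1                         # 2
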
NOTE: "cache_seqlens" has been removed from causal_conv1d_update Signed-off-by: Thomas Ortner --- .../layers/mamba/ops/causal_conv1d.py | 116 +++++++++--------- 1 file changed, 55 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 8b733847953e..ecef4fb5f8bc 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -20,15 +20,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # (dim, cu_seqlen) + cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains + # the block indices relevant for each sequence + # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, - current_first_idx, # (dim,) - current_last_idx, # (dim,) - last_state_idx, # (dim,) - seq_lens_completed, # (dim,) + current_first_idx, # (batch,) + current_last_idx, # (batch,) + initial_state_idx, # (batch,) + seq_lens_completed, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions dim: tl.constexpr, @@ -53,7 +55,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - IS_CACHE_ENABLED: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -84,9 +86,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual sequence length seqlen = sequence_end_index - sequence_start_index - B_size = (stride_block_m * BLOCK_M) + B_size: tl.constexpr = (stride_block_m * BLOCK_M) - if IS_CACHE_ENABLED: + if IS_APC_ENABLED: # Handle the case if prefix caching is enabled. 
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" @@ -110,7 +112,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching n_block_to_fill = current_last_index - current_first_index # Get the index of the init block - conv_state_init_index = tl.load(last_state_idx + idx_seq) + conv_state_init_index = tl.load(initial_state_idx + idx_seq) else: n_block_to_fill = 0 current_last_index = 0 @@ -125,16 +127,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices + - conv_state_init_index).to(tl.int64) + conv_states_input_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + conv_state_init_index).to(tl.int64) if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -206,13 +208,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # Compute the offset where the last block should be written in the conv_states - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices + - current_last_index).to(tl.int64) + conv_states_output_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_last_index).to(tl.int64) conv_states_ptrs_target = ( - conv_states_ptr + (conv_states_offset * stride_conv_state_seq) - + # Offset from seq + conv_states_ptr + (conv_states_output_coord * + stride_conv_state_seq) + # Offset from seq (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok)[:, None] @@ -228,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs_source = ( conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) + mask = ((conv_states_input_coord < num_cache_lines) & ((idx_tokens_conv + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) @@ -339,14 +341,14 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices + - current_first_index + - (chunk_offset - 1)).to(tl.int64) + conv_states_output_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_first_index + + (chunk_offset - 1)).to(tl.int64) conv_states_ptrs_target = ( - conv_states_ptr + (conv_states_offset * stride_conv_state_seq) - + # Offset from seq + conv_states_ptr + (conv_states_output_coord * + stride_conv_state_seq) + # Offset from seq (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok)[:, None] @@ -445,7 +447,7 @@ def causal_conv1d_fn( pad_slot_id: int 
= PAD_SLOT_ID, current_first_idx: Optional[torch.Tensor] = None, current_last_idx: Optional[torch.Tensor] = None, - last_state_idx: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, seq_lens_completed: Optional[torch.Tensor] = None, block_size_to_align=0, metadata=None, @@ -490,13 +492,13 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - current_first_idx: (batch) int32 + current_first_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the first cache block to be filled. - current_last_idx: (batch) int32 + current_last_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the last cache block to be filled. - last_state_idx: (batch) int32 + initial_state_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the cache block containing the initial state. - seq_lens_completed: (batch) int32 + seq_lens_completed: (batch,), dtype int32 The number of tokens already completed for each sequence block_size_to_align: int The block size to align the cached states to @@ -677,7 +679,7 @@ def grid(META): token_chunk_offset_ptr, current_first_idx, current_last_idx, - last_state_idx, + initial_state_idx, seq_lens_completed, out, # Matrix dimensions @@ -702,7 +704,7 @@ def grid(META): HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CACHE_ENABLED=current_last_idx is not None, + IS_APC_ENABLED=current_last_idx is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, #launch_cooperative_grid=True @@ -723,8 +725,8 @@ def _causal_conv1d_update_kernel( conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) - current_last_idx, - last_state_idx, + current_last_idx, # (batch,) + initial_state_idx, #(batch,) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -752,7 +754,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CACHE_ENABLED: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -766,21 +768,21 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CACHE_ENABLED: - # Get the state from the last_state_idx - conv_state_init = tl.load(last_state_idx + idx_seq) + if IS_APC_ENABLED: + # Get the state from the initial_state_idx + conv_state_init = tl.load(initial_state_idx + idx_seq) current_last_index = tl.load(current_last_idx + idx_seq) else: conv_state_init = 0 current_last_index = 0 # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) + conv_states_input_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return @@ -824,7 +826,7 @@ def _causal_conv1d_update_kernel( # STEP 1: READ init_state data conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) mask_w = idx_feats < dim @@ -852,12 +854,12 @@ 
def _causal_conv1d_update_kernel( # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) + mask = ((conv_states_input_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) @@ -876,7 +878,7 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) - # Get the state from the last_state_idx + # Get the state from the initial_state_idx # cache_idx conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + @@ -1026,14 +1028,13 @@ def causal_conv1d_update( weight: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, num_accepted_tokens: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, current_last_idx: Optional[torch.Tensor] = None, - last_state_idx: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, validate_data=False, ): """ @@ -1047,18 +1048,13 @@ def causal_conv1d_update( conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - current_last_idx: (batch) int32 + current_last_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the last cache block to be filled. - last_state_idx: (batch) int32 + initial_state_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the cache block containing the initial state. 
num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each @@ -1080,7 +1076,6 @@ def causal_conv1d_update( out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM assert pad_slot_id is not None assert x.stride(1) == 1 if isinstance(activation, bool): @@ -1120,7 +1115,6 @@ def causal_conv1d_update( assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' out = x @@ -1163,7 +1157,7 @@ def grid(META): num_accepted_tokens, query_start_loc, current_last_idx, - last_state_idx, + initial_state_idx, out, # Matrix dimensions batch, @@ -1191,7 +1185,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CACHE_ENABLED=current_last_idx is not None, + IS_APC_ENABLED=current_last_idx is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, From a22c8ab75e397c47e5ad1445e2727b111655a230 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 3 Oct 2025 07:24:41 -0400 Subject: [PATCH 081/105] Pre-commit fixes. Signed-off-by: Stanislaw Wozniak --- vllm/config/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 17fe58fef016..bdfa99cd79a3 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -93,7 +93,7 @@ class CacheConfig: """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" mamba_block_size: Optional[int] = None - """Size of a contiguous cache block in number of tokens for mamba cache.""" + """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). 
If set to 'auto', the data type will be inferred from the model From b49e33f42f866f06b09d87da645d19dc0cca3a94 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 11:25:13 +0000 Subject: [PATCH 082/105] Adjusted mamba_mixer2.py to new conv1D naming Signed-off-by: Thomas Ortner --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 94b4f3ed0497..2717947293c6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -649,7 +649,7 @@ def forward_cuda( cache_indices=state_indices_tensor_p, current_first_idx=current_first_idx_p, current_last_idx=current_last_idx_p, - last_state_idx=last_state_idx_p, + initial_state_idx=last_state_idx_p, seq_lens_completed=seq_lens_completed_p, block_size_to_align=mamba_block_size, metadata=attn_metadata, @@ -764,7 +764,7 @@ def forward_cuda( self.activation, conv_state_indices=state_indices_tensor_d, current_last_idx=current_last_idx_d, - last_state_idx=last_state_idx_d, + initial_state_idx=last_state_idx_d, ) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( From 0e574a776776bd9ecf3eef998b61ca018e7f2b19 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 11:39:30 +0000 Subject: [PATCH 083/105] Fix assertion for block_size_to_align Signed-off-by: Thomas Ortner --- vllm/model_executor/layers/mamba/ops/causal_conv1d.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ecef4fb5f8bc..fba97defc01f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -21,7 +21,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching bias_ptr, initial_states_ptr, # conv_states_ptr cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains - # the block indices relevant for each sequence + # the block indices relevant for each sequence # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, @@ -592,7 +592,9 @@ def causal_conv1d_fn( assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" if block_size_to_align is not None and block_size_to_align > 0: - assert block_size_to_align % BLOCK_M, "The mamba block size needs to be divisible by the BLOCK_M" + assert ( + block_size_to_align % BLOCK_M + ) == 0, "The mamba block size needs to be divisible by the BLOCK_M" else: block_size_to_align = BLOCK_M From ae21a8c2c2a7823c3134f35b64904f67f2db5eec Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 15:57:04 +0000 Subject: [PATCH 084/105] Integrated code review comments Signed-off-by: Thomas Ortner --- .../models/language/generation/test_hybrid.py | 66 ++++--------------- 1 file changed, 11 insertions(+), 55 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2118bf58eb6c..2ca7375b59ad 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -337,18 +337,12 @@ def test_fp32_cache_state( # Helper functions for the APC tests -def _get_vllm_runner_params(model, - enforce_eager, - max_model_len, - tensor_parallel_size=1): +def _get_vllm_runner_params(model, 
max_model_len, tensor_parallel_size=1): return { 'model_name': model, 'enable_prefix_caching': False, - 'enforce_eager': enforce_eager, 'max_model_len': max_model_len, 'tensor_parallel_size': tensor_parallel_size, - 'disable_cascade_attn': True, ## not verified yet - 'disable_log_stats': False, ## collect APC stats 'gpu_memory_utilization': 0.4 } @@ -377,14 +371,11 @@ def _get_vLLM_output(vllm_runner, @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) # If num_logprobs is set to -1, then the stringent version # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_single_prompt( hf_runner, vllm_runner, @@ -393,10 +384,8 @@ def test_apc_single_prompt( model: str, max_tokens: int, n_repetitions: int, - enforce_eager: bool, num_logprobs: int, tensor_parallel_size: int, - cache_dtype_param: str, ) -> None: try: @@ -417,11 +406,8 @@ def test_apc_single_prompt( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( - model, - enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) - vllm_runner_kwargs[cache_dtype_param] = "float32" + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, @@ -448,14 +434,11 @@ def test_apc_single_prompt( @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) # If num_logprobs is set to -1, then the stringent version # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_single_prompt_block_align_alignment( hf_runner, vllm_runner, @@ -464,10 +447,8 @@ def test_apc_single_prompt_block_align_alignment( model: str, max_tokens: int, n_repetitions: int, - enforce_eager: bool, num_logprobs: int, tensor_parallel_size: int, - cache_dtype_param: str, ) -> None: try: @@ -488,11 +469,8 @@ def test_apc_single_prompt_block_align_alignment( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( - model, - enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) - vllm_runner_kwargs[cache_dtype_param] = "float32" + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, @@ -540,14 +518,11 @@ def test_apc_single_prompt_block_align_alignment( @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) # If num_logprobs is set to -1, then the stringent version # of the test is executed using 
`check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_all_cached_outputs( hf_runner, vllm_runner, @@ -556,10 +531,8 @@ def test_apc_multiple_prompts_all_cached_outputs( model: str, max_tokens: int, n_repetitions: int, - enforce_eager: bool, num_logprobs: int, tensor_parallel_size: int, - cache_dtype_param: str, ) -> None: try: @@ -580,11 +553,8 @@ def test_apc_multiple_prompts_all_cached_outputs( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( - model, - enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) - vllm_runner_kwargs[cache_dtype_param] = "float32" + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, @@ -612,14 +582,11 @@ def test_apc_multiple_prompts_all_cached_outputs( @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) # If num_logprobs is set to -1, then the stringent version # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_block_align_alignment( hf_runner, vllm_runner, @@ -628,10 +595,8 @@ def test_apc_multiple_prompts_block_align_alignment( model: str, max_tokens: int, n_repetitions: int, - enforce_eager: bool, num_logprobs: int, tensor_parallel_size: int, - cache_dtype_param: str, ) -> None: try: @@ -655,10 +620,9 @@ def test_apc_multiple_prompts_block_align_alignment( max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) - vllm_runner_kwargs = _get_vllm_runner_params(model, enforce_eager, - max_model_len, + vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len, tensor_parallel_size) - vllm_runner_kwargs[cache_dtype_param] = "float32" + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, @@ -706,14 +670,11 @@ def test_apc_multiple_prompts_block_align_alignment( @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) -@pytest.mark.parametrize("enforce_eager", [True]) # If num_logprobs is set to -1, then the stringent version # of the test is executed using `check_outputs_equal` # instead of `check_logprobs_close` @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) def test_apc_multiple_prompts_partial_cached_outputs( hf_runner, vllm_runner, @@ -722,10 +683,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( model: str, max_tokens: int, n_repetitions: int, - enforce_eager: bool, num_logprobs: int, tensor_parallel_size: int, - cache_dtype_param: str, ) -> None: try: @@ -746,11 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( 
max_model_len = max( len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( - model, - enforce_eager, - max_model_len, - tensor_parallel_size=tensor_parallel_size) - vllm_runner_kwargs[cache_dtype_param] = "float32" + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, vllm_runner_kwargs, From 60cc1cd02d18099826b8e8417858b104cf5701e5 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:35:11 -0400 Subject: [PATCH 085/105] cache_enabled -> enable_prefix_caching Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5526222048cc..07bc33de4348 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -455,6 +455,8 @@ def __init__(self, self.cache_config = cache_config self.prefix = prefix + assert self.cache_config is not None + def forward_native( self, hidden_states: torch.Tensor, @@ -488,7 +490,8 @@ def forward_cuda( # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - cache_enabled = False + assert self.cache_config is not None + prefix_caching_enabled = self.cache_config.enable_prefix_caching if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -506,7 +509,6 @@ def forward_cuda( cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p mamba_block_size = attn_metadata.cache_spec.block_size - cache_enabled = attn_metadata.cache_spec.enable_prefix_caching # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -575,7 +577,7 @@ def forward_cuda( dim=0, ) - if cache_enabled: + if prefix_caching_enabled: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode @@ -662,7 +664,7 @@ def forward_cuda( initial_states = None if (has_initial_states_p is not None and prep_initial_states): kernel_ssm_indices = state_indices_tensor_p - if cache_enabled: + if prefix_caching_enabled: kernel_ssm_indices = state_indices_tensor_p.gather( 1, last_state_idx_p.unsqueeze(1)).squeeze(1) initial_states = torch.where( @@ -689,14 +691,14 @@ def forward_cuda( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_intermediate_states=cache_enabled, + return_intermediate_states=prefix_caching_enabled, dt_softplus=True, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype) - if cache_enabled: + if prefix_caching_enabled: # Save states for sequences with more than just the final state: n_blocks_to_fill = current_last_idx_p - current_first_idx_p for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): @@ -739,7 +741,7 @@ def forward_cuda( # Process decode requests if has_decode: - if cache_enabled: + if prefix_caching_enabled: state_indices_tensor_d_input = \ state_indices_tensor_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1) From 0d1b054b0a2ef5b5035d88ef60a38344dcc7e982 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:36:51 -0400 Subject: [PATCH 086/105] Reduce diff Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 07bc33de4348..b5f28f05633c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -455,8 +455,6 @@ def __init__(self, self.cache_config = cache_config self.prefix = prefix - assert self.cache_config is not None - def forward_native( self, hidden_states: torch.Tensor, @@ -490,6 +488,7 @@ def forward_cuda( # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata + assert self.cache_config is not None prefix_caching_enabled = self.cache_config.enable_prefix_caching if attn_metadata is not None: From e9f22570fc917ef5cf835a7f5a3f2c29897679d6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:42:42 -0400 Subject: [PATCH 087/105] Removed unused code Signed-off-by: Thomas Parnell --- vllm/v1/core/single_type_kv_cache_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index ec98a97fbe9b..2d052a44d04d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -545,8 +545,6 @@ def find_longest_cache_hit( assert dcp_world_size == 1, "DCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] 
= tuple( [] for _ in range(len(kv_cache_group_ids))) - if not kv_cache_spec.enable_prefix_caching: - return computed_blocks #return empty list if cache is disabled max_num_blocks = max_length // kv_cache_spec.block_size # Search from right to left and early stop when a match is found. From e733552ba4eb54381751ee34765660677ded1e5d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:44:15 -0400 Subject: [PATCH 088/105] Update comment Signed-off-by: Thomas Parnell --- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 2d052a44d04d..1f1e07070138 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -566,7 +566,7 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Here unused blocks may be freed up for running requests. - # Future enhancement: Free up all blocks that aren't needed by Mamba2 + # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 # (for which find_longest_cache_hit returns block_pool.null_block) pass From 6ebc97f73cb4ed44787117005f4ecd553fff68e0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:51:07 -0400 Subject: [PATCH 089/105] Remove duplicate code Signed-off-by: Thomas Parnell --- vllm/v1/core/single_type_kv_cache_manager.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 1f1e07070138..00911dcec069 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -619,21 +619,7 @@ def allocate_new_blocks(self, request_id: str, num_tokens += (self.kv_cache_spec.block_size * self.kv_cache_spec.num_speculative_blocks) - if not self.kv_cache_spec.enable_prefix_caching: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks - - req_blocks = self.req_to_blocks[request_id] - num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = num_required_blocks - len(req_blocks) - if num_new_blocks <= 0: - return [] - else: - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) - return new_blocks + return super().allocate_new_blocks(request_id, num_tokens) class CrossAttentionManager(SingleTypeKVCacheManager): From 19b33a0874b6077a09c130ad8a73b4a7ec9848ea Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 14:52:20 -0400 Subject: [PATCH 090/105] reduce diff Signed-off-by: Thomas Parnell --- vllm/v1/core/single_type_kv_cache_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 00911dcec069..c7bcadc11581 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -618,7 +618,6 @@ def allocate_new_blocks(self, request_id: str, if self.kv_cache_spec.num_speculative_blocks > 0: num_tokens += (self.kv_cache_spec.block_size * self.kv_cache_spec.num_speculative_blocks) - return super().allocate_new_blocks(request_id, num_tokens) From 9ab2dfc88287b9e13d15c6d350d11bd220178c4c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: 
Fri, 3 Oct 2025 15:05:38 -0400 Subject: [PATCH 091/105] Remove cache spec from mamba metadata Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- vllm/model_executor/models/config.py | 1 + vllm/v1/attention/backends/mamba2_attn.py | 12 ++++-------- vllm/v1/worker/gpu_model_runner.py | 9 +-------- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b5f28f05633c..4f923ceec1d0 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -490,6 +490,7 @@ def forward_cuda( attn_metadata: AttentionMetadata = forward_context.attn_metadata assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size prefix_caching_enabled = self.cache_config.enable_prefix_caching if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ -507,7 +508,6 @@ def forward_cuda( query_start_loc_p = attn_metadata.query_start_loc_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p - mamba_block_size = attn_metadata.cache_spec.block_size # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 1d7150f9effb..3dc446d40a32 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -411,6 +411,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # with mamba layers, use FlashInfer instead). attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) + cache_config.mamba_block_size = model_config.max_model_len # override attention block size if either (a) the # user has not set it or (b) the user has set it diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index c137f0dbce37..679578fe9209 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -15,7 +15,7 @@ CommonAttentionMetadata, compute_causal_conv1d_metadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec def compute_varlen_chunk_metadata( @@ -132,7 +132,6 @@ class Mamba2AttentionMetadata: nums_dict: Optional[dict] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None - cache_spec: Optional[MambaSpec] = None class Mamba2AttentionMetadataBuilder( @@ -144,8 +143,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") - assert isinstance(kv_cache_spec, MambaSpec) - if kv_cache_spec.enable_prefix_caching: + if self.vllm_config.cache_config.enable_prefix_caching: self.state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, cdiv(vllm_config.model_config.max_model_len, @@ -198,8 +196,7 @@ def build(self, # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.enable_prefix_caching: + if self.vllm_config.cache_config.enable_prefix_caching: # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = 
common_attn_metadata.block_table_tensor @@ -331,7 +328,7 @@ def build(self, state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.kv_cache_spec.enable_prefix_caching: + if self.vllm_config.cache_config.enable_prefix_caching: self.current_last_token_block_idx[:num_decodes].copy_( current_last_token_block_idx, non_blocking=True) current_last_token_block_idx = \ @@ -370,7 +367,6 @@ def build(self, seq_lens=seq_lens, prep_initial_states=prep_initial_states, chunk_size=self.chunk_size, - cache_spec=self.kv_cache_spec, has_initial_states_p=has_initial_states_p, seq_idx_p=seq_idx_p, state_indices_tensor=state_indices_tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0a47208f2791..f7b90b0bb7e4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4240,14 +4240,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - mamba_block_size = \ - self.vllm_config.cache_config.mamba_block_size - else: - # Set block_size to max_model_len, so that mamba model - # will always have only one block - mamba_block_size = self.vllm_config.model_config.max_model_len - + mamba_block_size = (self.vllm_config.cache_config.mamba_block_size) page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) From f6293daf913c676f618024b938c29354750dedbb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:12:13 -0400 Subject: [PATCH 092/105] remove enable_prefix_caching from MambaSpec Signed-off-by: Thomas Parnell --- vllm/v1/kv_cache_interface.py | 9 ++------- vllm/v1/worker/gpu_model_runner.py | 4 +--- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4cf3b6a5d435..054ab591b817 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -220,7 +220,6 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" - enable_prefix_caching: bool = False num_speculative_blocks: int = 0 @property @@ -234,12 +233,8 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - if self.enable_prefix_caching: - # Keeps a state at every block boundary: - max_model_len = vllm_config.model_config.max_model_len - return cdiv(max_model_len, self.block_size) * self.page_size_bytes - # By default keeps the last state only: - return self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes @dataclass(frozen=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f7b90b0bb7e4..11e24e4d13dc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4240,7 +4240,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") - mamba_block_size = (self.vllm_config.cache_config.mamba_block_size) + mamba_block_size = self.vllm_config.cache_config.mamba_block_size page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) @@ -4249,8 +4249,6 @@ def get_kv_cache_spec(self) -> dict[str, 
KVCacheSpec]: shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), block_size=mamba_block_size, - enable_prefix_caching=self.vllm_config.cache_config. - enable_prefix_caching, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( From 0806c064a2d78087e5831b352391f6f6b3dfa052 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:23:30 -0400 Subject: [PATCH 093/105] Disable prefix caching by default for hybrid models Signed-off-by: Thomas Parnell --- vllm/engine/arg_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bf293a4d2aa9..89a881675ad6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1563,7 +1563,12 @@ def _set_default_args(self, usage_context: UsageContext, self.enable_prefix_caching = False if self.enable_prefix_caching is None: - self.enable_prefix_caching = True + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + if model_config.is_hybrid: + self.enable_prefix_caching = False + else: + self.enable_prefix_caching = True else: pooling_type = model_config.pooler_config.pooling_type From ac31d4853a346277f860cebd0bc29b4f8e3e3c53 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:29:18 -0400 Subject: [PATCH 094/105] Add logging about disabling cascade attn Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 3dc446d40a32..fb67968ae90a 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -312,6 +312,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "support for prefix caching: disabling.") cache_config.enable_prefix_caching = False + # TODO(tdoublep): remove once cascade attention is supported + logger.info("Disabling cascade attention since it is not supported " + "for hybrid models.") + model_config.disable_cascade_attn = True + # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ "Lfm2ForCausalLM", @@ -375,13 +380,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes - # Cascade attn doesn't work with Mamba: - # * enable_prefix_caching = True -> fails - # * enable_prefix_caching = False -> cascade attention is triggered, - # but always terminates early, not raising any exception - # Thus, it's more effective to disable the cascade attention logic: - model_config.disable_cascade_attn = True - if cache_config.enable_prefix_caching: # With prefix caching, select attention block size to # optimize for mamba kernel performance From e57cf9cfb7b0cbe9c6e64e853cbc112b684193a8 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:36:37 -0400 Subject: [PATCH 095/105] rename seq_lens_completed -> context_lens Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 10 +++---- .../layers/mamba/ops/causal_conv1d.py | 10 +++---- vllm/v1/attention/backends/mamba2_attn.py | 26 +++++++++---------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 4f923ceec1d0..5c8e3a55b1c6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ 
b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -594,9 +594,9 @@ def forward_cuda( attn_metadata.current_first_token_block_idx, [num_decodes, num_prefills], dim=0) - _, seq_lens_completed_p = torch.split( - attn_metadata.seq_lens_completed, [num_decodes, num_prefills], - dim=0) + _, context_lens_p = torch.split(attn_metadata.context_lens, + [num_decodes, num_prefills], + dim=0) _, last_computed_offset_p = torch.split( attn_metadata.last_computed_token_block_offset, [num_decodes, num_prefills], @@ -605,7 +605,7 @@ def forward_cuda( last_state_idx_d, last_state_idx_p = None, None current_last_idx_d, current_last_idx_p = None, None _, current_first_idx_p = None, None - _, seq_lens_completed_p = None, None + _, context_lens_p = None, None # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -650,7 +650,7 @@ def forward_cuda( current_first_idx=current_first_idx_p, current_last_idx=current_last_idx_p, initial_state_idx=last_state_idx_p, - seq_lens_completed=seq_lens_completed_p, + context_lens=context_lens_p, block_size_to_align=mamba_block_size, metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index fba97defc01f..44be0dac0919 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -30,7 +30,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching current_first_idx, # (batch,) current_last_idx, # (batch,) initial_state_idx, # (batch,) - seq_lens_completed, # (batch,) + context_lens, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions dim: tl.constexpr, @@ -95,7 +95,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Get the length of the completed sequence so far and compute the offset. current_first_index = tl.load(current_first_idx + idx_seq) current_last_index = tl.load(current_last_idx + idx_seq) - sequence_completed_index = tl.load(seq_lens_completed + idx_seq) + sequence_completed_index = tl.load(context_lens + idx_seq) # Compute the offset where the first stride_block_m-aligned first full block is # Value in "token-space" @@ -448,7 +448,7 @@ def causal_conv1d_fn( current_first_idx: Optional[torch.Tensor] = None, current_last_idx: Optional[torch.Tensor] = None, initial_state_idx: Optional[torch.Tensor] = None, - seq_lens_completed: Optional[torch.Tensor] = None, + context_lens: Optional[torch.Tensor] = None, block_size_to_align=0, metadata=None, validate_data=False, @@ -498,7 +498,7 @@ def causal_conv1d_fn( The pointer into cache_indices, which signifies the last cache block to be filled. initial_state_idx: (batch,), dtype int32 The pointer into cache_indices, which signifies the cache block containing the initial state. 
- seq_lens_completed: (batch,), dtype int32 + context_lens: (batch,), dtype int32 The number of tokens already completed for each sequence block_size_to_align: int The block size to align the cached states to @@ -682,7 +682,7 @@ def grid(META): current_first_idx, current_last_idx, initial_state_idx, - seq_lens_completed, + context_lens, out, # Matrix dimensions dim, diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 679578fe9209..55fdd0379d6e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -125,7 +125,7 @@ class Mamba2AttentionMetadata: current_last_token_block_idx: torch.Tensor current_first_token_block_idx: torch.Tensor last_computed_token_block_idx: torch.Tensor - seq_lens_completed: torch.Tensor + context_lens: torch.Tensor last_computed_token_block_offset: torch.Tensor # The following attributes are for triton implementation of causal_conv1d @@ -166,7 +166,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device, ) - self.seq_lens_completed = torch.empty( + self.context_lens = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, @@ -205,10 +205,10 @@ def build(self, seq_lens_pending = ( torch.roll(common_attn_metadata.query_start_loc, -1, -1) - common_attn_metadata.query_start_loc)[:-1] - seq_lens_completed = common_attn_metadata.seq_lens - \ + context_lens = common_attn_metadata.seq_lens - \ seq_lens_pending last_computed_token_block_offset = \ - seq_lens_completed % mamba_block_size + context_lens % mamba_block_size # Indices: last_computed <= current_first <= current_last # Cases: # last_computed == current_first if last state was partially @@ -217,10 +217,10 @@ def build(self, # only one state will be stored # 0th based indexing leads to "-1" -> e.g. 
16 computed -> state[15]: current_last_token_block_idx = cdiv( - seq_lens_completed + seq_lens_pending, mamba_block_size) - 1 - current_first_token_block_idx = cdiv(seq_lens_completed + 1, + context_lens + seq_lens_pending, mamba_block_size) - 1 + current_first_token_block_idx = cdiv(context_lens + 1, mamba_block_size) - 1 - last_computed_token_block_idx = cdiv(seq_lens_completed, + last_computed_token_block_idx = cdiv(context_lens, mamba_block_size) - 1 # -1 in case it's non-computed and causes later issues with indexing last_computed_token_block_idx = \ @@ -234,7 +234,7 @@ def build(self, current_first_token_block_idx = None last_computed_token_block_idx = None last_computed_token_block_offset = None - seq_lens_completed = None + context_lens = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -347,10 +347,10 @@ def build(self, self.last_computed_token_block_idx[:num_input_tokens] last_computed_token_block_idx[num_decodes:] = 0 - self.seq_lens_completed[:num_decodes].copy_(seq_lens_completed, - non_blocking=True) - seq_lens_completed = self.seq_lens_completed[:num_input_tokens] - seq_lens_completed[num_decodes:] = 0 + self.context_lens[:num_decodes].copy_(context_lens, + non_blocking=True) + context_lens = self.context_lens[:num_input_tokens] + context_lens[num_decodes:] = 0 self.last_computed_token_block_offset[:num_decodes].copy_( last_computed_token_block_offset, non_blocking=True) @@ -378,7 +378,7 @@ def build(self, current_last_token_block_idx=current_last_token_block_idx, current_first_token_block_idx=current_first_token_block_idx, last_computed_token_block_idx=last_computed_token_block_idx, - seq_lens_completed=seq_lens_completed, + context_lens=context_lens, last_computed_token_block_offset=last_computed_token_block_offset, ) return attn_metadata From a49f94d94774ba93f8346f33b0e487863e31e81f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:47:28 -0400 Subject: [PATCH 096/105] Consistent naming between mamba_mixer2 and mamba2 metadata Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 20 +-- vllm/v1/attention/backends/mamba2_attn.py | 117 +++++++++--------- 2 files changed, 63 insertions(+), 74 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5c8e3a55b1c6..91bbc1a38526 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -582,25 +582,15 @@ def forward_cuda( # Split decodes and prefills: last_state_idx_d, last_state_idx_p = torch.split( - attn_metadata.last_computed_token_block_idx, - [num_decodes, num_prefills], + attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0) current_last_idx_d, current_last_idx_p = torch.split( - attn_metadata.current_last_token_block_idx, - [num_decodes, num_prefills], + attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0) # Prefill-only variables: - _, current_first_idx_p = torch.split( - attn_metadata.current_first_token_block_idx, - [num_decodes, num_prefills], - dim=0) - _, context_lens_p = torch.split(attn_metadata.context_lens, - [num_decodes, num_prefills], - dim=0) - _, last_computed_offset_p = torch.split( - attn_metadata.last_computed_token_block_offset, - [num_decodes, num_prefills], - dim=0) + current_first_idx_p = attn_metadata.current_first_idx_p + context_lens_p = attn_metadata.context_lens_p + last_computed_offset_p = attn_metadata.last_computed_offset_p else: 
last_state_idx_d, last_state_idx_p = None, None current_last_idx_d, current_last_idx_p = None, None diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 55fdd0379d6e..a8959daa4c10 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -122,11 +122,11 @@ class Mamba2AttentionMetadata: last_chunk_indices_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] - current_last_token_block_idx: torch.Tensor - current_first_token_block_idx: torch.Tensor - last_computed_token_block_idx: torch.Tensor - context_lens: torch.Tensor - last_computed_token_block_offset: torch.Tensor + current_last_idx: torch.Tensor + current_first_idx_p: torch.Tensor + last_state_idx: torch.Tensor + context_lens_p: torch.Tensor + last_computed_offset_p: torch.Tensor # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None @@ -151,27 +151,27 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device, ) - self.current_last_token_block_idx = torch.empty( + self.current_last_idx = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, ) - self.current_first_token_block_idx = torch.empty( + self.current_first_idx_p = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, ) - self.last_computed_token_block_idx = torch.empty( + self.last_state_idx = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, ) - self.context_lens = torch.empty( + self.context_lens_p = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, ) - self.last_computed_token_block_offset = torch.empty( + self.last_computed_offset_p = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, @@ -205,10 +205,10 @@ def build(self, seq_lens_pending = ( torch.roll(common_attn_metadata.query_start_loc, -1, -1) - common_attn_metadata.query_start_loc)[:-1] - context_lens = common_attn_metadata.seq_lens - \ + context_lens_p = common_attn_metadata.seq_lens - \ seq_lens_pending - last_computed_token_block_offset = \ - context_lens % mamba_block_size + last_computed_offset_p = \ + context_lens_p % mamba_block_size # Indices: last_computed <= current_first <= current_last # Cases: # last_computed == current_first if last state was partially @@ -216,25 +216,24 @@ def build(self, # current_first == current_last if no block crossing occurs, and # only one state will be stored # 0th based indexing leads to "-1" -> e.g. 
16 computed -> state[15]: - current_last_token_block_idx = cdiv( - context_lens + seq_lens_pending, mamba_block_size) - 1 - current_first_token_block_idx = cdiv(context_lens + 1, - mamba_block_size) - 1 - last_computed_token_block_idx = cdiv(context_lens, - mamba_block_size) - 1 + current_last_idx = cdiv(context_lens_p + seq_lens_pending, + mamba_block_size) - 1 + current_first_idx_p = cdiv(context_lens_p + 1, + mamba_block_size) - 1 + last_state_idx = cdiv(context_lens_p, mamba_block_size) - 1 # -1 in case it's non-computed and causes later issues with indexing - last_computed_token_block_idx = \ - last_computed_token_block_idx.clamp(min=0) + last_state_idx = \ + last_state_idx.clamp(min=0) else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] # Additional cache-related varaiables: - current_last_token_block_idx = None - current_first_token_block_idx = None - last_computed_token_block_idx = None - last_computed_token_block_offset = None - context_lens = None + current_last_idx = None + current_first_idx_p = None + last_state_idx = None + last_computed_offset_p = None + context_lens_p = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -329,34 +328,34 @@ def build(self, state_indices_tensor[num_decodes:] = PAD_SLOT_ID if self.vllm_config.cache_config.enable_prefix_caching: - self.current_last_token_block_idx[:num_decodes].copy_( - current_last_token_block_idx, non_blocking=True) - current_last_token_block_idx = \ - self.current_last_token_block_idx[:num_input_tokens] - current_last_token_block_idx[num_decodes:] = 0 - - self.current_first_token_block_idx[:num_decodes].copy_( - current_first_token_block_idx, non_blocking=True) - current_first_token_block_idx = \ - self.current_first_token_block_idx[:num_input_tokens] - current_first_token_block_idx[num_decodes:] = 0 - - self.last_computed_token_block_idx[:num_decodes].copy_( - last_computed_token_block_idx, non_blocking=True) - last_computed_token_block_idx = \ - self.last_computed_token_block_idx[:num_input_tokens] - last_computed_token_block_idx[num_decodes:] = 0 - - self.context_lens[:num_decodes].copy_(context_lens, - non_blocking=True) - context_lens = self.context_lens[:num_input_tokens] - context_lens[num_decodes:] = 0 - - self.last_computed_token_block_offset[:num_decodes].copy_( - last_computed_token_block_offset, non_blocking=True) - last_computed_token_block_offset = \ - self.last_computed_token_block_offset[:num_input_tokens] - last_computed_token_block_offset[num_decodes:] = 0 + self.current_last_idx[:num_decodes].copy_(current_last_idx, + non_blocking=True) + current_last_idx = \ + self.current_last_idx[:num_input_tokens] + current_last_idx[num_decodes:] = 0 + + self.current_first_idx_p[:num_decodes].copy_( + current_first_idx_p, non_blocking=True) + current_first_idx_p = \ + self.current_first_idx_p[:num_input_tokens] + current_first_idx_p[num_decodes:] = 0 + + self.last_state_idx[:num_decodes].copy_(last_state_idx, + non_blocking=True) + last_state_idx = \ + self.last_state_idx[:num_input_tokens] + last_state_idx[num_decodes:] = 0 + + self.context_lens_p[:num_decodes].copy_(context_lens_p, + non_blocking=True) + context_lens_p = self.context_lens_p[:num_input_tokens] + context_lens_p[num_decodes:] = 0 + + self.last_computed_offset_p[:num_decodes].copy_( + last_computed_offset_p, non_blocking=True) + last_computed_offset_p = \ + self.last_computed_offset_p[:num_input_tokens] + 
last_computed_offset_p[num_decodes:] = 0 attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, @@ -375,10 +374,10 @@ def build(self, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, - current_last_token_block_idx=current_last_token_block_idx, - current_first_token_block_idx=current_first_token_block_idx, - last_computed_token_block_idx=last_computed_token_block_idx, - context_lens=context_lens, - last_computed_token_block_offset=last_computed_token_block_offset, + current_last_idx=current_last_idx, + current_first_idx_p=current_first_idx_p, + last_state_idx=last_state_idx, + context_lens_p=context_lens_p, + last_computed_offset_p=last_computed_offset_p, ) return attn_metadata From 785ef4a4f86b3c4f486e03dcced6eb9753145663 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:53:55 -0400 Subject: [PATCH 097/105] Remove FCG-handling for prefill-only tensors Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/mamba2_attn.py | 32 ----------------------- 1 file changed, 32 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index a8959daa4c10..bd9be7dc6afd 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -156,26 +156,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device, ) - self.current_first_idx_p = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) self.last_state_idx = torch.empty( (self.decode_cudagraph_max_bs, ), dtype=torch.int32, device=device, ) - self.context_lens_p = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) - self.last_computed_offset_p = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) def build(self, common_prefix_len: int, @@ -334,29 +319,12 @@ def build(self, self.current_last_idx[:num_input_tokens] current_last_idx[num_decodes:] = 0 - self.current_first_idx_p[:num_decodes].copy_( - current_first_idx_p, non_blocking=True) - current_first_idx_p = \ - self.current_first_idx_p[:num_input_tokens] - current_first_idx_p[num_decodes:] = 0 - self.last_state_idx[:num_decodes].copy_(last_state_idx, non_blocking=True) last_state_idx = \ self.last_state_idx[:num_input_tokens] last_state_idx[num_decodes:] = 0 - self.context_lens_p[:num_decodes].copy_(context_lens_p, - non_blocking=True) - context_lens_p = self.context_lens_p[:num_input_tokens] - context_lens_p[num_decodes:] = 0 - - self.last_computed_offset_p[:num_decodes].copy_( - last_computed_offset_p, non_blocking=True) - last_computed_offset_p = \ - self.last_computed_offset_p[:num_input_tokens] - last_computed_offset_p[num_decodes:] = 0 - attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, From d68e4916aab5ba01daf4e92037c4850f3a1de701 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:57:50 -0400 Subject: [PATCH 098/105] minor cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 91bbc1a38526..56df9cf511e6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -579,8 +579,6 @@ def forward_cuda( if 
prefix_caching_enabled: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode - - # Split decodes and prefills: last_state_idx_d, last_state_idx_p = torch.split( attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0) @@ -594,8 +592,8 @@ def forward_cuda( else: last_state_idx_d, last_state_idx_p = None, None current_last_idx_d, current_last_idx_p = None, None - _, current_first_idx_p = None, None - _, context_lens_p = None, None + current_first_idx_p = None + context_lens_p = None # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs From 5e8a63a171ff42d9871315c3d6c0e15b04e943ea Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 15:59:21 -0400 Subject: [PATCH 099/105] Add comment Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index fb67968ae90a..4344b48cdefa 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -393,6 +393,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # End result: # attn_block_size = 512 # mamba_block_size = 512 (aligned to a multiple of chunk_size) + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. chunk_size = model_config.get_mamba_chunk_size() attn_tokens_per_mamba_state = \ cdiv(mamba_page_size, attn_page_size_1_token) From 36cc1ce098eac7698e3c34cb63682fd39fca5531 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 16:20:44 -0400 Subject: [PATCH 100/105] Enable other mamba2 models Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 4344b48cdefa..59906bf74976 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -292,15 +292,14 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config - # TODO: find a way to keep this list updated, or redundant + # TODO(@tdoublep) find a better way to do this than whitelist MAMBA2_MODELS = [ "BambaForCausalLM", - #"FalconH1ForCausalLM", + "FalconH1ForCausalLM", "GraniteMoeHybridForCausalLM", "Mamba2ForCausalLM", - #"NemotronHForCausalLM", - #"Plamo2ForCausalLM", - #"Zamba2ForCausalLM", + "NemotronHForCausalLM", + "Zamba2ForCausalLM", ] if cache_config.enable_prefix_caching: if model_config.architecture in MAMBA2_MODELS: From b006aba0b2bcc02e2bf46586d0d6ade75851db4a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 17:12:31 -0400 Subject: [PATCH 101/105] Fix computation of prefill-only tensors Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/mamba2_attn.py | 31 +++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index bd9be7dc6afd..49fe1584e79c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -181,6 +181,10 @@ def build(self, # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + context_lens, context_lens_p = None, None + current_first_idx, current_first_idx_p = None, 
None + last_computed_offset, last_computed_offset_p = None, None + if self.vllm_config.cache_config.enable_prefix_caching: # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor @@ -190,10 +194,10 @@ def build(self, seq_lens_pending = ( torch.roll(common_attn_metadata.query_start_loc, -1, -1) - common_attn_metadata.query_start_loc)[:-1] - context_lens_p = common_attn_metadata.seq_lens - \ + context_lens = common_attn_metadata.seq_lens - \ seq_lens_pending - last_computed_offset_p = \ - context_lens_p % mamba_block_size + last_computed_offset = \ + context_lens % mamba_block_size # Indices: last_computed <= current_first <= current_last # Cases: # last_computed == current_first if last state was partially @@ -201,24 +205,21 @@ def build(self, # current_first == current_last if no block crossing occurs, and # only one state will be stored # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: - current_last_idx = cdiv(context_lens_p + seq_lens_pending, + current_last_idx = cdiv(context_lens + seq_lens_pending, mamba_block_size) - 1 - current_first_idx_p = cdiv(context_lens_p + 1, - mamba_block_size) - 1 - last_state_idx = cdiv(context_lens_p, mamba_block_size) - 1 + current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1 + last_state_idx = cdiv(context_lens, mamba_block_size) - 1 # -1 in case it's non-computed and causes later issues with indexing last_state_idx = \ last_state_idx.clamp(min=0) + else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] # Additional cache-related varaiables: current_last_idx = None - current_first_idx_p = None last_state_idx = None - last_computed_offset_p = None - context_lens_p = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -238,6 +239,16 @@ def build(self, query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens + if self.vllm_config.cache_config.enable_prefix_caching: + assert context_lens is not None + context_lens_p = context_lens[num_reqs - num_prefills:num_reqs] + assert last_computed_offset is not None + last_computed_offset_p = last_computed_offset[ + num_reqs - num_prefills:num_reqs] + assert current_first_idx is not None + current_first_idx_p = current_first_idx[num_reqs - + num_prefills:num_reqs] + num_computed_tokens_p = \ common_attn_metadata.num_computed_tokens_cpu[ num_reqs - num_prefills:num_reqs] From b256275bcda76b5322dc997ef0db1979a86799f8 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 17:18:10 -0400 Subject: [PATCH 102/105] Ensure that mamba_block_size always get set Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 59906bf74976..354fda8c7a52 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -292,6 +292,10 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config + # Set mamba block size to max_model_len (this may get + # override by prefix caching logic later) + cache_config.mamba_block_size = model_config.max_model_len + # TODO(@tdoublep) find a better way to do this than whitelist MAMBA2_MODELS = [ "BambaForCausalLM", @@ -411,8 +415,8 @@ def 
verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # with mamba layers, use FlashInfer instead). attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) - cache_config.mamba_block_size = model_config.max_model_len + print("cache_config.mamba_block_size: ", cache_config.mamba_block_size) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. From 745af7385c90bd8060a8e65b90649be49a259bec Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 3 Oct 2025 21:52:12 +0000 Subject: [PATCH 103/105] Fixing argument description of conv1D Signed-off-by: Thomas Ortner --- .../layers/mamba/ops/causal_conv1d.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 44be0dac0919..a02bba5d4ddd 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -493,11 +493,11 @@ def causal_conv1d_fn( in this case, the kernel will not process entries at indices 0 and 3 current_first_idx: (batch,), dtype int32 - The pointer into cache_indices, which signifies the first cache block to be filled. + The pointer into cache_indices, where the first cache block to be filled is located. current_last_idx: (batch,), dtype int32 - The pointer into cache_indices, which signifies the last cache block to be filled. + The pointer into cache_indices, where the last cache block to be filled is located. initial_state_idx: (batch,), dtype int32 - The pointer into cache_indices, which signifies the cache block containing the initial state. + The pointer into cache_indices, where the cache block containing the initial state is located. context_lens: (batch,), dtype int32 The number of tokens already completed for each sequence block_size_to_align: int @@ -1055,9 +1055,9 @@ def causal_conv1d_update( and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. current_last_idx: (batch,), dtype int32 - The pointer into cache_indices, which signifies the last cache block to be filled. + The pointer into conv_state_indices, where the last cache block to be filled is located. initial_state_idx: (batch,), dtype int32 - The pointer into cache_indices, which signifies the cache block containing the initial state. + The pointer into conv_state_indices, where the cache block containing the initial state is located. num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. @@ -1070,9 +1070,9 @@ def causal_conv1d_update( If query_start_loc is not None, this indicates the maximum query length in the batch. 
pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded + if conv_state_indices is passed, lets the kernel identify padded entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` From dcfc5ad42981310e33405d1a3981f94d0bf1ae82 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 18:55:02 -0400 Subject: [PATCH 104/105] Remove assert for models that support prefix caching Signed-off-by: Thomas Parnell --- vllm/model_executor/models/falcon_h1.py | 3 --- vllm/model_executor/models/mamba2.py | 3 --- vllm/model_executor/models/nemotron_h.py | 3 --- vllm/model_executor/models/zamba2.py | 3 --- 4 files changed, 12 deletions(-) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index f382018e2222..ccea9add093f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -540,11 +540,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "FalconH1 currently does not support prefix caching" self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index f8a5a8f6081b..250698a61387 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -222,11 +222,8 @@ def get_mamba_state_shape_from_config( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 987920ecc331..c89550923938 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -505,11 +505,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "NemotronH currently does not support prefix caching" self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 1d68320bd9b2..1803fa259cf4 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -868,11 +868,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: (not supported by Mamba) """ config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not 
cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config From a0a4c40a8032b6aa773fc413b55e8abfb3e622d3 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 3 Oct 2025 22:35:48 -0400 Subject: [PATCH 105/105] remove debug print Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 354fda8c7a52..283cd2bb8b41 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -416,7 +416,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) - print("cache_config.mamba_block_size: ", cache_config.mamba_block_size) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small.
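
The hunks in mamba2_attn.py above derive three per-request block indices (last_state_idx, current_first_idx, current_last_idx) from the number of already-computed tokens (context_lens), the tokens scheduled this step (seq_lens_pending), and mamba_block_size, with the invariant last_state_idx <= current_first_idx <= current_last_idx. A minimal standalone sketch of that arithmetic in plain Python ints (not part of the patch; the helper names below are illustrative, only the formulas mirror the diff):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(-a // b)

def mamba_block_indices(context_len: int, pending: int, block_size: int):
    # Block holding the last computed (possibly partial) state; clamp so an
    # empty context does not produce index -1.
    last_state_idx = max(cdiv(context_len, block_size) - 1, 0)
    # First block that this step's tokens write into.
    current_first_idx = cdiv(context_len + 1, block_size) - 1
    # Last block that this step's tokens write into.
    current_last_idx = cdiv(context_len + pending, block_size) - 1
    return last_state_idx, current_first_idx, current_last_idx

# Example: block_size=16, 40 tokens already cached, 30 new prefill tokens.
# last_state_idx == current_first_idx because block 2 is only partially filled,
# and two further blocks (3 and 4) get written this step.
print(mamba_block_indices(context_len=40, pending=30, block_size=16))  # (2, 2, 4)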