Skip to content

Commit 69ebbe1

Browse files
tlrmchlsmth authored and lulmer committed
Revert "[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of U… (vllm-project#14848)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent fc9448c commit 69ebbe1

File tree

1 file changed

+8
-22
lines changed

1 file changed

+8
-22
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -466,17 +466,10 @@ def forward_cuda(
466466
if has_prefill:
467467

468468
initial_states = None
469-
470-
if has_initial_states is not None and torch.any(
471-
has_initial_states):
472-
473-
# vectorized ssm_state zero init
474-
batched_zero_init_func = torch.vmap(
475-
lambda idx: mamba_cache_params.ssm_state[idx].zero_())
476-
batched_zero_init_func(
477-
mamba_cache_params.
478-
state_indices_tensor[~has_initial_states].unsqueeze(
479-
dim=-1), )
469+
if has_initial_states is not None and any(has_initial_states):
470+
for idx in mamba_cache_params.state_indices_tensor[
471+
~has_initial_states]:
472+
mamba_cache_params.ssm_state[idx].zero_()
480473
initial_states = mamba_cache_params.ssm_state[
481474
mamba_cache_params.state_indices_tensor]
482475

@@ -500,17 +493,10 @@ def forward_cuda(
500493
dt_limit=(0.0, float("inf")),
501494
)
502495

503-
# vectorized ssm state update using vmap
504-
# the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
505-
# limitation which doesn't allow use of `item()`
506-
# Note: the lambda capture can happen where ssm_state is initialized
507-
# instead of here
508-
batched_copy = torch.vmap(
509-
lambda idx, source_state: mamba_cache_params.ssm_state[
510-
idx].copy_(source_state))
511-
batched_copy(
512-
mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
513-
varlen_state)
496+
# update ssm states
497+
# - varlen state is a (batch, nheads, headdim, dstate) tensor
498+
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
499+
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
514500

515501
# - reshape
516502
hidden_states = scan_output.view(seq_len, -1)

0 commit comments

Comments
 (0)