@@ -466,17 +466,10 @@ def forward_cuda(
         if has_prefill:

             initial_states = None
-
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-
-                # vectorized ssm_state zero init
-                batched_zero_init_func = torch.vmap(
-                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
-                batched_zero_init_func(
-                    mamba_cache_params.
-                    state_indices_tensor[~has_initial_states].unsqueeze(
-                        dim=-1), )
+            if has_initial_states is not None and any(has_initial_states):
+                for idx in mamba_cache_params.state_indices_tensor[
+                        ~has_initial_states]:
+                    mamba_cache_params.ssm_state[idx].zero_()
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]

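For reference, the loop added in the hunk above only clears the cache slots of prefill sequences that carry no prior SSM state, and only when at least one sequence does have one. The snippet below is a minimal, self-contained sketch of that indexing pattern; the shapes and the bare tensors standing in for `mamba_cache_params.ssm_state`, `state_indices_tensor`, and `has_initial_states` are made up for illustration.

```python
import torch

# Made-up shapes: 8 cache slots, 2 heads, head_dim 4, dstate 3.
ssm_state = torch.randn(8, 2, 4, 3)             # per-slot SSM cache
state_indices_tensor = torch.tensor([5, 0, 7])  # batch position -> cache slot
has_initial_states = torch.tensor([True, False, True])

# Zero only the slots of sequences without an initial state, mirroring the
# loop form introduced by the diff.
if has_initial_states is not None and any(has_initial_states):
    for idx in state_indices_tensor[~has_initial_states]:
        ssm_state[idx].zero_()

assert torch.all(ssm_state[0] == 0)      # slot of the second sequence cleared
assert not torch.all(ssm_state[5] == 0)  # slots with initial state untouched
```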
@@ -500,17 +493,10 @@ def forward_cuda(
                 dt_limit=(0.0, float("inf")),
             )

-            # vectorized ssm state update using vmap
-            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
-            # limitation which doesn't allow use of `item()`
-            # Note: the lambda capture can happen where ssm_state is initialized
-            # instead of here
-            batched_copy = torch.vmap(
-                lambda idx, source_state: mamba_cache_params.ssm_state[
-                    idx].copy_(source_state))
-            batched_copy(
-                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
-                varlen_state)
+            # update ssm states
+            # - varlen state is a (batch, nheads, headdim, dstate) tensor
+            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
+                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])

             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
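The second hunk replaces the vmap-based scatter with a direct loop: row `i` of `varlen_state` (the final chunked-scan state of prefill sequence `i`) is copied into cache slot `state_indices_tensor[i]`. Below is a standalone sketch of that write-back with made-up shapes; the `index_copy_` comparison is only an illustration that the per-row loop and a single batched indexed copy land the same data in the same slots, not something the diff itself uses.

```python
import torch

# Made-up shapes: 8 cache slots, batch of 3 prefill sequences.
nslots, batch, nheads, headdim, dstate = 8, 3, 2, 4, 3
ssm_state = torch.zeros(nslots, nheads, headdim, dstate)
state_indices_tensor = torch.tensor([5, 0, 7])
varlen_state = torch.randn(batch, nheads, headdim, dstate)

# Loop form, as in the diff: copy each sequence's final state into its slot.
for i, idx in enumerate(state_indices_tensor):
    ssm_state[idx].copy_(varlen_state[i])

# Same result via one batched indexed copy (slot indices are distinct here).
expected = torch.zeros(nslots, nheads, headdim, dstate).index_copy_(
    0, state_indices_tensor, varlen_state)
assert torch.equal(ssm_state, expected)
```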