1616from vllm .distributed import (divide , get_tensor_model_parallel_rank ,
1717 get_tensor_model_parallel_world_size )
1818from vllm .forward_context import get_forward_context
19- from vllm .model_executor .layers .fla .ops . chunk import chunk_gated_delta_rule , RMSNormGated , chunk
19+ from vllm .model_executor .layers .fla .ops import chunk , chunk_gated_delta_rule
2020from vllm .model_executor .layers .fla .ops .fused_recurrent import \
2121 fused_recurrent_gated_delta_rule
2222from vllm .model_executor .layers .fused_moe import FusedMoE
@@ -579,7 +579,7 @@ def _forward_core(
579579 mixed_qkv_spec = mixed_qkv_spec .view (
580580 attn_metadata .num_spec_decodes , - 1 , mixed_qkv_spec .size (- 1 ))
581581 mixed_qkv_spec = rearrange (mixed_qkv_spec , 'b l d -> b d l' )
582- mixed_qkv_spec = causal_conv1d_update (
582+ mixed_qkv_spec = causal_conv1d . causal_conv1d_update (
583583 mixed_qkv_spec ,
584584 conv_state ,
585585 conv_weights ,
@@ -596,7 +596,7 @@ def _forward_core(
596596 if attn_metadata .num_prefills > 0 :
597597 # - "cache_indices" updates the conv_state cache in positions
598598 # pointed to by "mamba_cache_params.state_indices_tensor"
599- mixed_qkv_non_spec = causal_conv1d_fn (
599+ mixed_qkv_non_spec = causal_conv1d . causal_conv1d_fn (
600600 mixed_qkv_non_spec .transpose (0 , 1 ),
601601 conv_weights ,
602602 self .conv1d .bias ,
@@ -607,7 +607,7 @@ def _forward_core(
607607 query_start_loc = non_spec_query_start_loc ,
608608 ).transpose (0 , 1 )
609609 elif attn_metadata .num_decodes > 0 :
610- mixed_qkv_non_spec = causal_conv1d_update (
610+ mixed_qkv_non_spec = causal_conv1d . causal_conv1d_update (
611611 mixed_qkv_non_spec ,
612612 conv_state ,
613613 conv_weights ,
0 commit comments