@@ -82,8 +82,15 @@ def __init__(
8282 self .kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE [
8383 cache_config .cache_dtype ]
8484
85- self .is_multimodal_model = model_config .is_multimodal_model
85+ # NOTE(woosuk): sliding_window is None for models with interleaved
86+ # attention. Use interleaved_sliding_window instead.
8687 self .sliding_window = model_config .get_sliding_window ()
88+ self .interleaved_sliding_window = getattr (
89+ model_config .hf_text_config , "interleaved_sliding_window" , None )
90+ self .window_size = (self .sliding_window
91+ or self .interleaved_sliding_window )
92+
93+ self .is_multimodal_model = model_config .is_multimodal_model
8794 self .block_size = cache_config .block_size
8895 self .max_model_len = model_config .max_model_len
8996 self .max_num_blocks_per_req = cdiv (self .max_model_len , self .block_size )
@@ -674,7 +681,7 @@ def _compute_cascade_attn_prefix_len(
674681 num_query_heads = self .num_query_heads ,
675682 num_kv_heads = self .num_kv_heads ,
676683 use_alibi = False , # FIXME
677- use_sliding_window = self .sliding_window is not None ,
684+ use_sliding_window = self .window_size is not None ,
678685 num_sms = self .num_sms ,
679686 )
680687 return common_prefix_len if use_cascade else 0
0 commit comments