@@ -173,14 +173,16 @@ def __init__(self,
173173 if self .enforce_eager is None :
174174 self .enforce_eager = False
175175
176- if (not self .disable_sliding_window
177- and self .hf_text_config .model_type == "gemma2"
178- and self .hf_text_config .sliding_window is not None ):
176+ has_interleaved_attention = isinstance (self .hf_text_config .sliding_window , list ) or (self .hf_text_config .model_type in ["gemma2" ] and self .hf_text_config .sliding_window is not None )
177+
178+ if (not self .disable_sliding_window and has_interleaved_attention ):
179+ sliding_window_len_min = get_min_sliding_window (self .hf_text_config .sliding_window )
180+
179181 print_warning_once (
180- "Gemma 2 uses sliding window attention for every odd layer , "
182+ f" { self . hf_text_config . model_type } has interleaved attention, "
181183 "which is currently not supported by vLLM. Disabling sliding "
182184 "window and capping the max length to the sliding window size "
183- f"({ self . hf_text_config . sliding_window } )." )
185+ f"({ sliding_window_len_min } )." )
184186 self .disable_sliding_window = True
185187
186188 self .max_model_len = _get_and_verify_max_len (
@@ -422,7 +424,7 @@ def verify_with_parallel_config(
422424 "pipeline parallelism currently. Disabling it." )
423425 self .use_async_output_proc = False
424426
425- def get_hf_config_sliding_window (self ) -> Optional [int ]:
427+ def get_hf_config_sliding_window (self ) -> Optional [Union [ int , List [ int ]] ]:
426428 """Get the sliding window size, or None if disabled."""
427429
428430 # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1680,7 +1682,7 @@ def _get_and_verify_max_len(
16801682 hf_config : PretrainedConfig ,
16811683 max_model_len : Optional [int ],
16821684 disable_sliding_window : bool ,
1683- sliding_window_len : Optional [int ],
1685+ sliding_window_len : Optional [Union [ int , List [ Optional [ int ]]] ],
16841686 spec_target_max_model_len : Optional [int ] = None ,
16851687) -> int :
16861688 """Get and verify the model's maximum length."""
@@ -1713,9 +1715,11 @@ def _get_and_verify_max_len(
17131715 # If sliding window is manually disabled, max_length should be less
17141716 # than the sliding window length in the model config.
17151717 if disable_sliding_window and sliding_window_len is not None :
1718+
1719+ sliding_window_len_min = get_min_sliding_window (sliding_window_len )
17161720 max_len_key = "sliding_window" \
1717- if sliding_window_len < derived_max_model_len else max_len_key
1718- derived_max_model_len = min (derived_max_model_len , sliding_window_len )
1721+ if sliding_window_len_min < derived_max_model_len else max_len_key
1722+ derived_max_model_len = min (derived_max_model_len , sliding_window_len_min )
17191723
17201724 # If none of the keys were found in the config, use a default and
17211725 # log a warning.
@@ -1803,6 +1807,13 @@ def _get_and_verify_max_len(
18031807 return int (max_model_len )
18041808
18051809
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by the config value.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which individual entries may be ``None``.

    Returns:
        ``sliding_window`` itself when it is not a list, otherwise the
        minimum of the non-``None`` entries of the list.

    Raises:
        ValueError: If ``sliding_window`` is a list that contains no
            non-``None`` entry, so no minimum can be determined.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    # ``None`` entries are ignored when computing the minimum
    # (presumably they mark layers without a sliding window -- TODO confirm).
    window_sizes = [size for size in sliding_window if size is not None]
    if not window_sizes:
        # A bare ``min([])`` would raise an opaque "min() arg is an empty
        # sequence"; name the offending value instead.
        raise ValueError(
            "Cannot determine the minimum sliding window size: "
            f"no non-None entries in sliding window list {sliding_window!r}.")
    return min(window_sizes)
1816+
18061817def get_served_model_name (model : str ,
18071818 served_model_name : Optional [Union [str , List [str ]]]):
18081819 """
0 commit comments