Skip to content

Commit b6e20a7

Browse files
WIP
1 parent e9d517f commit b6e20a7

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

vllm/config.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,16 @@ def __init__(self,
173173
if self.enforce_eager is None:
174174
self.enforce_eager = False
175175

176-
if (not self.disable_sliding_window
177-
and self.hf_text_config.model_type == "gemma2"
178-
and self.hf_text_config.sliding_window is not None):
176+
has_interleaved_attention = isinstance(self.hf_text_config.sliding_window, list) or (self.hf_text_config.model_type in ["gemma2"] and self.hf_text_config.sliding_window is not None)
177+
178+
if (not self.disable_sliding_window and has_interleaved_attention):
179+
sliding_window_len_min = get_min_sliding_window(self.hf_text_config.sliding_window)
180+
179181
print_warning_once(
180-
"Gemma 2 uses sliding window attention for every odd layer, "
182+
f"{self.hf_text_config.model_type} has interleaved attention, "
181183
"which is currently not supported by vLLM. Disabling sliding "
182184
"window and capping the max length to the sliding window size "
183-
f"({self.hf_text_config.sliding_window}).")
185+
f"({sliding_window_len_min}).")
184186
self.disable_sliding_window = True
185187

186188
self.max_model_len = _get_and_verify_max_len(
@@ -422,7 +424,7 @@ def verify_with_parallel_config(
422424
"pipeline parallelism currently. Disabling it.")
423425
self.use_async_output_proc = False
424426

425-
def get_hf_config_sliding_window(self) -> Optional[int]:
427+
def get_hf_config_sliding_window(self) -> Optional[Union[int, List[int]]]:
426428
"""Get the sliding window size, or None if disabled."""
427429

428430
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1680,7 +1682,7 @@ def _get_and_verify_max_len(
16801682
hf_config: PretrainedConfig,
16811683
max_model_len: Optional[int],
16821684
disable_sliding_window: bool,
1683-
sliding_window_len: Optional[int],
1685+
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
16841686
spec_target_max_model_len: Optional[int] = None,
16851687
) -> int:
16861688
"""Get and verify the model's maximum length."""
@@ -1713,9 +1715,11 @@ def _get_and_verify_max_len(
17131715
# If sliding window is manually disabled, max_length should be less
17141716
# than the sliding window length in the model config.
17151717
if disable_sliding_window and sliding_window_len is not None:
1718+
1719+
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
17161720
max_len_key = "sliding_window" \
1717-
if sliding_window_len < derived_max_model_len else max_len_key
1718-
derived_max_model_len = min(derived_max_model_len, sliding_window_len)
1721+
if sliding_window_len_min < derived_max_model_len else max_len_key
1722+
derived_max_model_len = min(derived_max_model_len, sliding_window_len_min)
17191723

17201724
# If none of the keys were found in the config, use a default and
17211725
# log a warning.
@@ -1803,6 +1807,13 @@ def _get_and_verify_max_len(
18031807
return int(max_model_len)
18041808

18051809

1810+
def get_min_sliding_window(sliding_window: Union[int, List[Optional[int]]]) -> int:
1811+
if isinstance(sliding_window, list):
1812+
return min([s for s in sliding_window if s is not None])
1813+
else:
1814+
return sliding_window
1815+
1816+
18061817
def get_served_model_name(model: str,
18071818
served_model_name: Optional[Union[str, List[str]]]):
18081819
"""

0 commit comments

Comments
 (0)