From 34ff477e4d13e913be05a0b77962a349e15ea6ea Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 27 Feb 2025 22:46:13 +0800 Subject: [PATCH 1/2] fix qwen2.5-vl overflow issue Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/minicpmo.py | 11 +++-------- vllm/model_executor/models/qwen2_5_vl.py | 7 ++++++- vllm/model_executor/models/utils.py | 10 ++++++++++ vllm/model_executor/models/whisper.py | 9 +++------ 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e354e5323327..e6111f46143d 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -47,7 +47,7 @@ MiniCPMVMultiModalDataParser, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, _minicpmv_field_config) -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -469,13 +469,8 @@ def forward( training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) outputs = (hidden_states, ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 858cf28d2b87..0dbff665b5d3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -63,7 +63,7 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend @@ -641,6 +641,11 @@ def forward( cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) + # For Qwen2.5-VL-3B, float16 will overflow at last block + # for long visual tokens sequences. + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) + # adapter hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index fff4be34ddbe..5cfef4690a77 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] + + +def cast_overflow_tensors( + tensors: torch.Tensor, + offset: float = 1000, +) -> Dict[str, torch.Tensor]: + if tensors.isinf().any() or tensors.isnan().any(): + clamp_value = torch.finfo(tensors.dtype).max - offset + tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) + return tensors diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index e5f77e08c403..a2eefbc6d899 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -35,7 +35,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription -from .utils import AutoWeightsLoader, WeightsMapper, make_layers +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + make_layers) logger = init_logger(__name__) @@ -285,11 +286,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - if hidden_states.isinf().any() or hidden_states.isnan().any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + hidden_states = cast_overflow_tensors(hidden_states) return hidden_states From d30d0c133c3f47a6436788f82398b21eb23df8c0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 27 Feb 2025 23:37:40 +0800 Subject: [PATCH 2/2] fix wrong annotation Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5cfef4690a77..f9aa5da39a5f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -646,7 +646,7 @@ def extract_layer_index(layer_name: str) -> int: def cast_overflow_tensors( tensors: torch.Tensor, offset: float = 1000, -) -> Dict[str, torch.Tensor]: +) -> torch.Tensor: if tensors.isinf().any() or tensors.isnan().any(): clamp_value = torch.finfo(tensors.dtype).max - offset tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)