From 13454ade8409af1dd6a0f30a080d1b0e92ef49e4 Mon Sep 17 00:00:00 2001
From: Lucas Kabela
Date: Wed, 29 Oct 2025 11:46:53 -0700
Subject: [PATCH 1/2] Move SDPA to a custom op until tensor slicing is supported

Signed-off-by: Lucas Kabela
---
 vllm/attention/ops/vit_attn_wrappers.py  | 53 ++++++++++++++++++++++++
 vllm/model_executor/models/qwen2_5_vl.py | 24 ++++-------
 2 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index f71f49a1a31b..89aeb9d76f5b 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -14,6 +14,7 @@
 import einops
 import torch
+import torch.nn.functional as F
 
 from vllm.utils.torch_utils import direct_register_custom_op
 
@@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
     )
+
+
+# TODO: Once we have torch 2.10, we can use tensor slices,
+# so we won't need to wrap this in a custom op.
+def torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    outputs = []
+    for i in range(1, len(cu_seqlens)):
+        start_idx = cu_seqlens[i - 1]
+        end_idx = cu_seqlens[i]
+        q_i = q[:, start_idx:end_idx]
+        k_i = k[:, start_idx:end_idx]
+        v_i = v[:, start_idx:end_idx]
+        q_i, k_i, v_i = (
+            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+        )
+        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
+        outputs.append(output_i)
+    context_layer = torch.cat(outputs, dim=1)
+    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+    return context_layer
+
+
+def torch_sdpa_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    b, s, h, d = q.shape
+    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
+
+
+direct_register_custom_op(
+    op_name="torch_sdpa_wrapper",
+    op_func=torch_sdpa_wrapper,
+    fake_impl=torch_sdpa_wrapper_fake,
+)
+
+
+def vit_torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index dfaeb663bbe2..c33218dd2495 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -49,6 +49,7 @@
 )
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_torch_sdpa_wrapper,
     vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
@@ -436,23 +437,12 @@ def forward(
             q = q.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = einops.rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_torch_sdpa_wrapper(
+                q,
+                k,
+                v,
+                cu_seqlens,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

From 50ffd98e11093278b3708e0088b4e090e7c3db0b Mon Sep 17 00:00:00 2001
From: Lucas Kabela
Date: Wed, 29 Oct 2025 17:00:49 -0700
Subject: [PATCH 2/2] Re-enable compiling the vision block

Signed-off-by: Lucas Kabela
---
 vllm/model_executor/models/qwen2_5_vl.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index c33218dd2495..94a4f30cbae2 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -450,17 +450,15 @@ def forward(
         return output
 
 
-# (FIXME): Enable this after dynamic slicing is fixed
-# See https://github.com/vllm-project/vllm/pull/27760
-# @support_torch_compile(
-#     dynamic_arg_dims={
-#         "x": 0,
-#         "cu_seqlens": 0,
-#         "rotary_pos_emb": 0,
-#         "seqlens": 0,
-#     },
-#     mark_unbacked_dims={"seqlens": 0},
-# )
+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
        self,
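
Background note on the pattern used in this series: the per-sequence SDPA loop slices q/k/v at tensor-valued cu_seqlens boundaries, which torch.compile cannot yet trace, so patch 1 hides the loop behind a registered custom op with a fake (shape-only) implementation, and patch 2 can then restore support_torch_compile on the vision block. The sketch below is not the vLLM code from the diffs above; it is a minimal, standalone illustration of the same idea using the plain torch.library API with made-up names (demo::seq_sdpa) and toy shapes, whereas vLLM registers its op through its own direct_register_custom_op helper.

    # Sketch only: generic custom-op + fake-impl pattern, assuming PyTorch >= 2.4.
    import torch
    import torch.nn.functional as F


    @torch.library.custom_op("demo::seq_sdpa", mutates_args=())
    def seq_sdpa(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor
    ) -> torch.Tensor:
        # Runs in eager inside the op, so the data-dependent slicing is allowed.
        # q/k/v: (batch, seq, heads, head_dim); cu_seqlens: cumulative boundaries.
        outputs = []
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            q_i = q[:, start:end].transpose(1, 2)
            k_i = k[:, start:end].transpose(1, 2)
            v_i = v[:, start:end].transpose(1, 2)
            out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
            outputs.append(out_i.transpose(1, 2))
        return torch.cat(outputs, dim=1)


    @seq_sdpa.register_fake
    def _(q, k, v, cu_seqlens):
        # Shape-only implementation used during tracing/compilation; the
        # concatenated per-sequence outputs have the same shape as q.
        return torch.empty_like(q)


    @torch.compile(fullgraph=True)
    def block(q, k, v, cu_seqlens):
        # The compiler sees one opaque op call instead of the slicing loop.
        return seq_sdpa(q, k, v, cu_seqlens)


    if __name__ == "__main__":
        b, s, h, d = 1, 6, 2, 8
        q, k, v = (torch.randn(b, s, h, d) for _ in range(3))
        cu = torch.tensor([0, 3, 6])
        print(block(q, k, v, cu).shape)  # torch.Size([1, 6, 2, 8])

Under torch.compile only the fake implementation is traced, so the compiler just needs the output shape; the eager loop with its tensor-valued slice bounds never enters the graph, which is why the decorator in patch 2 becomes safe to re-enable.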