Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions vllm/attention/ops/vit_attn_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import einops
import torch
import torch.nn.functional as F

from vllm.utils.torch_utils import direct_register_custom_op

Expand Down Expand Up @@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
)


# TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops
def torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A dumb question: How does this fix the tensor slicing issue? It seems that you simply put the whole SDPA into a separate function but the code remains identical?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is when torch.compile tries to trace this code - by wrapping it in a custom_op and calling that, the internals of the function are not traced (so the tracing bug does not trigger)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it is not just that we move it into a function, but that we also leverage the custom_op mechanism here to make it opaque

k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer


def torch_sdpa_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
op_name="torch_sdpa_wrapper",
op_func=torch_sdpa_wrapper,
fake_impl=torch_sdpa_wrapper_fake,
)


def vit_torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
44 changes: 16 additions & 28 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
Expand Down Expand Up @@ -442,41 +443,28 @@ def forward(
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
elif self.attn_backend == _Backend.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

output, _ = self.proj(context_layer)
return output


# (FIXME): Enable this after dynamic slicing is fixed
# See https://github.com/vllm-project/vllm/pull/27760
# @support_torch_compile(
# dynamic_arg_dims={
# "x": 0,
# "cu_seqlens": 0,
# "rotary_pos_emb": 0,
# "seqlens": 0,
# },
# mark_unbacked_dims={"seqlens": 0},
# )
@support_torch_compile(
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
"rotary_pos_emb": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
self,
Expand Down