
Commit 8ece569

cynthieye authored and shreyankg committed

[Perf]: Optimize qwen2-vl to reduce cudaMemcpyAsync (vllm-project#14377)
Signed-off-by: cynthieye <[email protected]>
1 parent 03fbe82 commit 8ece569
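
The change hoists the host-synchronizing reductions on `cu_seqlens` (`.max().item()` for Flash Attention and `.tolist()` for xFormers) out of every vision transformer block's forward and computes them once per vision-tower forward pass, passing the results down as new keyword arguments. Each of those conversions forces a device-to-host `cudaMemcpyAsync` plus a stream synchronization, so doing them per block scales with the number of layers. A minimal sketch of the pattern (the `blocks`, `blk`, and `x` names here are illustrative, not the actual vLLM modules):

```python
import torch


# Before: every block converts a CUDA tensor to a host value, forcing a
# cudaMemcpyAsync + synchronization per layer.
def forward_per_block_sync(blocks, x, cu_seqlens):
    for blk in blocks:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # D2H copy every layer
        x = blk(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    return x


# After: compute the host-side value once and thread it through the blocks.
def forward_hoisted(blocks, x, cu_seqlens):
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # single D2H copy
    for blk in blocks:
        x = blk(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    return x
```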

File tree: 2 files changed (+70, −24 lines)

vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_vl.py

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 33 additions & 12 deletions
```diff
@@ -255,10 +255,12 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         return q, k, v
 
     def forward(
-            self,
-            x: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            rotary_pos_emb: torch.Tensor,
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -285,7 +287,6 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -321,7 +322,6 @@ def forward(
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -364,11 +364,20 @@ def __init__(
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
         x = x + self.attn(self.norm1(x),
                           cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+                          rotary_pos_emb=rotary_pos_emb,
+                          max_seqlen=max_seqlen,
+                          seqlens=seqlens)
+
         x = x + self.mlp(self.norm2(x))
         return x
 
@@ -528,6 +537,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -633,14 +643,25 @@ def forward(
 
         # transformers
         hidden_states = hidden_states.unsqueeze(1)
+
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            hidden_states = blk(hidden_states,
-                                cu_seqlens=cu_seqlens_now,
-                                rotary_pos_emb=rotary_pos_emb)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # For Qwen2.5-VL-3B, float16 will overflow at last block
         # for long visual tokens sequences.
```
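
To confirm the effect, one way is to count device-to-host copy events around the vision encoder with the PyTorch profiler: before this change the count grows with the number of blocks, afterwards it should stay at roughly one per forward pass. A rough sketch, assuming some callable `vision_forward` that runs the vision tower on already-prepared inputs (the wrapper is not part of this commit):

```python
import torch
from torch.profiler import ProfilerActivity, profile


def count_dtoh_copies(fn) -> int:
    """Run fn once under the profiler and count device-to-host copy events."""
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn()
        torch.cuda.synchronize()
    return sum(
        evt.count for evt in prof.key_averages()
        if "Memcpy DtoH" in evt.key or "cudaMemcpyAsync" in evt.key
    )


# Hypothetical usage: vision_forward wraps the Qwen2-VL vision tower call.
# print(count_dtoh_copies(vision_forward))
```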

vllm/model_executor/models/qwen2_vl.py

Lines changed: 37 additions & 12 deletions
```diff
@@ -303,10 +303,12 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         return q, k, v
 
     def forward(
-            self,
-            x: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            rotary_pos_emb: torch.Tensor,
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
 
         # [s, b, c] --> [s, b, 3 * head * head_dim]
@@ -329,7 +331,6 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -365,7 +366,6 @@ def forward(
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -409,11 +409,22 @@ def __init__(
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x),
-                          cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+
         x = x + self.mlp(self.norm2(x))
         return x
 
@@ -570,6 +581,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -624,8 +636,21 @@ def forward(
 
         # transformers
         x = x.unsqueeze(1)
+
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # adapter
         x = self.merger(x)
```
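
Note that the values moved out of the per-block path have to live on the host: `flash_attn_varlen_func` expects its maximum sequence length as a Python integer, and xFormers' `BlockDiagonalMask.from_seqlens` expects a Python list of lengths, so producing either one from the CUDA `cu_seqlens` tensor necessarily triggers a device-to-host copy. Computing them once in the vision tower's forward, keyed on the selected attention backend, keeps that cost to a single copy per forward pass instead of one per transformer block, and because the new keyword arguments default to `None`, existing callers of the block and attention modules are unaffected.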
