diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index ec8e1f2ee5a6..9fa76634e1fc 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -673,7 +673,9 @@ def _run_memory_efficient_xformers_forward(
 
                     # Cross-attention mask is non-causal
                     attn_bias = BlockDiagonalMask.from_seqlens(
-                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                        attn_metadata.seq_lens,
+                        attn_metadata.encoder_seq_lens,
+                        device=query.device)
 
                 # Encoder branch of encoder-decoder model uses
                 # attn_metadata.encoder_seq_lens
@@ -683,7 +685,7 @@ def _run_memory_efficient_xformers_forward(
 
                     # Encoder self-attention mask is non-causal
                     attn_bias = BlockDiagonalMask.from_seqlens(
-                        attn_metadata.encoder_seq_lens)
+                        attn_metadata.encoder_seq_lens, device=query.device)
 
                 # Self-attention block of encoder-only model just
                 # uses the seq_lens directly.
@@ -692,7 +694,7 @@ def _run_memory_efficient_xformers_forward(
 
                     # Encoder self-attention mask is non-causal
                     attn_bias = BlockDiagonalMask.from_seqlens(
-                        attn_metadata.seq_lens)
+                        attn_metadata.seq_lens, device=query.device)
 
                 # Self-attention block of decoder branch just
                 # uses the seq_lens directly
@@ -701,7 +703,7 @@ def _run_memory_efficient_xformers_forward(
 
                     # Decoder self-attention mask is causal
                     attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                        attn_metadata.seq_lens)
+                        attn_metadata.seq_lens, device=query.device)
                 else:
                     raise ValueError("Unknown AttentionType: %s", attn_type)
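
For context, a minimal standalone sketch (not part of the patch) of what the added `device=` argument changes: the block-diagonal attention bias is constructed directly on the query's device instead of defaulting to CPU, so its sequence-length metadata does not need a host-to-device copy later. The sequence lengths, tensor shapes, and variable names below are illustrative assumptions, and it presumes an xformers release whose `from_seqlens` accepts `device`, as the patched code does.

```python
import torch
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          BlockDiagonalMask)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seq_lens = [3, 5]          # per-sequence token counts (cf. attn_metadata.seq_lens)
encoder_seq_lens = [4, 6]  # cf. attn_metadata.encoder_seq_lens
# Hypothetical query tensor; only its .device matters here.
query = torch.randn(1, sum(seq_lens), 8, 64, device=device)

# Causal decoder self-attention mask, built on the same device as the query.
self_attn_bias = BlockDiagonalCausalMask.from_seqlens(
    seq_lens, device=query.device)

# Non-causal cross-attention mask (decoder queries over encoder keys/values).
cross_attn_bias = BlockDiagonalMask.from_seqlens(
    seq_lens, encoder_seq_lens, device=query.device)
```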