@@ -673,7 +673,9 @@ def _run_memory_efficient_xformers_forward(
 
                 # Cross-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                    attn_metadata.seq_lens,
+                    attn_metadata.encoder_seq_lens,
+                    device=query.device)
 
                 # Encoder branch of encoder-decoder model uses
                 # attn_metadata.encoder_seq_lens
@@ -683,7 +685,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Encoder self-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.encoder_seq_lens)
+                    attn_metadata.encoder_seq_lens, device=query.device)
 
                 # Self-attention block of encoder-only model just
                 # uses the seq_lens directly.
@@ -692,7 +694,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Encoder self-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                    attn_metadata.seq_lens, device=query.device)
 
                 # Self-attention block of decoder branch just
                 # uses the seq_lens directly
@@ -701,7 +703,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Decoder self-attention mask is causal
                 attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                    attn_metadata.seq_lens, device=query.device)
             else:
                 raise ValueError("Unknown AttentionType: %s", attn_type)
 
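
For reference, a minimal standalone sketch of the pattern this change enables, assuming an xformers release whose `from_seqlens` accepts a `device` keyword (which the edited calls rely on); the sequence lengths and the dense `materialize` calls are purely illustrative:

```python
# Illustrative sketch only; assumes xformers' BlockDiagonalMask /
# BlockDiagonalCausalMask.from_seqlens accept a `device` keyword, as the
# calls in this diff do. Sequence lengths are made up.
import torch
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          BlockDiagonalMask)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_lens = [3, 5]            # decoder tokens per packed sequence
encoder_seq_lens = [4, 4]    # encoder tokens per packed sequence

# Non-causal cross-attention bias, created directly on the target device
# so its internal tensors need no later host-to-device copy.
cross_bias = BlockDiagonalMask.from_seqlens(
    seq_lens, encoder_seq_lens, device=device)

# Causal decoder self-attention bias on the same device.
self_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens, device=device)

# Dense views just to inspect the block structure; the fused attention
# kernels consume the bias objects directly.
print(cross_bias.materialize((sum(seq_lens), sum(encoder_seq_lens))).shape)
print(self_bias.materialize((sum(seq_lens), sum(seq_lens))).shape)
```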