@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer for ROCm GPUs."""
+import itertools
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
@@ -342,28 +343,27 @@ def _get_seq_len_block_table_args(
     Decoder attn -> select entirely decoder self-attention-related fields
     Encoder/decoder cross-attn -> select encoder sequence lengths
     Encoder attn -> select encoder sequence lengths fields
+    Encoder-only attn -> select prefill sequence lengths with
+                         bidirectional attention
 
     Arguments:
 
     * attn_metadata: Attention metadata structure associated with attention op
     * attn_type: encoder attention, decoder self-attention,
-                 encoder/decoder cross-attention
+                 encoder/decoder cross-attention, encoder-only
 
     Returns:
 
     * Appropriate sequence-lengths tensors for query and key
     * Appropriate max sequence-length scalar
+    * Causal masking flag
     '''
 
-    partial_prefix_sum = 0
     if attn_type == AttentionType.ENCODER:
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.encoder_seq_lens_tensor.device,
             dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
         causal_mask = False
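
# --- Illustrative aside (not part of the diff) ---
# itertools.accumulate reproduces exactly what the removed walrus-based
# prefix sum computed: cumulative start offsets for each sequence.
# The lengths below are hypothetical; only the pattern matters.
import itertools

seq_lens = [3, 5, 2]
starts = list(itertools.accumulate([0] + seq_lens))
assert starts == [0, 3, 8, 10]  # sequence i spans starts[i]:starts[i + 1]
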
@@ -372,16 +372,29 @@ def _get_seq_len_block_table_args(
         return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 attn_metadata.encoder_seq_lens, causal_mask)
+
+    elif attn_type == AttentionType.ENCODER_ONLY:
+        # For encoder-only models, we use the prefill sequence lengths
+        assert attn_metadata.seq_lens is not None
+        assert attn_metadata.seq_lens_tensor is not None
+        query_seq_start_loc = torch.tensor(
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
+            device=attn_metadata.seq_lens_tensor.device,
+            dtype=attn_metadata.seq_lens_tensor.dtype)
+        max_seq_len = attn_metadata.max_prefill_seq_len
+        # Encoder-only models typically use bidirectional attention
+        causal_mask = False
+
+        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
+                max_seq_len, attn_metadata.seq_lens, causal_mask)
+
     elif attn_type == AttentionType.DECODER:
         # Decoder self-attention
         # Choose max_seq_len based on whether we are in prompt_run
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         max_seq_len = attn_metadata.max_prefill_seq_len
@@ -393,21 +406,14 @@ def _get_seq_len_block_table_args(
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
             device=attn_metadata.encoder_seq_lens_tensor.device,
             dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
 
-        partial_prefix_sum = 0
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         key_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         causal_mask = False
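
# --- Illustrative aside (not part of the diff) ---
# In cross-attention the query offsets come from the decoder sequence
# lengths and the key offsets from the encoder sequence lengths, so the
# two start-location tensors generally differ. Hypothetical values:
import itertools

decoder_seq_lens = [2, 4]
encoder_seq_lens = [7, 3]
q_starts = list(itertools.accumulate([0] + decoder_seq_lens))  # [0, 2, 6]
k_starts = list(itertools.accumulate([0] + encoder_seq_lens))  # [0, 7, 10]
assert q_starts != k_starts
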
@@ -584,6 +590,8 @@ def forward(
           will match encoder sequence lengths, pass encoder sequence
           attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
           max_encoder_seq_len)
+        * ENCODER_ONLY: bidirectional attention with no KV caching;
+          use prefill sequence attributes
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
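
# --- Illustrative aside (not part of the diff) ---
# Summary of how this backend treats each attention type, as described in
# the docstring above (an informal table, not a vLLM structure):
ATTN_TYPE_BEHAVIOR = {
    "DECODER": {"causal": True, "updates_kv_cache": True},
    "ENCODER": {"causal": False, "updates_kv_cache": False},
    "ENCODER_ONLY": {"causal": False, "updates_kv_cache": False},
    "ENCODER_DECODER": {"causal": False, "updates_kv_cache": True},
}
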
@@ -608,7 +616,11 @@ def forward(
         else:
             assert value is None
 
-        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
+        # Only update KV cache for decoder self-attention
+        # and encoder-decoder cross-attention
+        if self.attn_type not in [
+                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
+        ] and kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
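
# --- Illustrative aside (not part of the diff) ---
# A minimal sketch of the gating condition above, with plain strings
# standing in for AttentionType members:
def should_update_kv_cache(attn_type: str, kv_cache_numel: int) -> bool:
    return (attn_type not in ("ENCODER", "ENCODER_ONLY")
            and kv_cache_numel > 0)

assert should_update_kv_cache("DECODER", 1024)
assert not should_update_kv_cache("ENCODER_ONLY", 1024)
assert not should_update_kv_cache("DECODER", 0)  # no cache allocated yet
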
@@ -632,6 +644,9 @@ def forward(
 
-        if self.attn_type != AttentionType.ENCODER:
+        if self.attn_type == AttentionType.ENCODER_ONLY:
+            # For encoder-only models, all tokens are processed in one go
+            num_prefill_tokens = query.shape[0]
+        elif self.attn_type != AttentionType.ENCODER:
             num_prefill_tokens = attn_metadata.num_prefill_tokens
         else:
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
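
# --- Illustrative aside (not part of the diff) ---
# Why ENCODER_ONLY is tested before the generic "!= ENCODER" branch:
# ENCODER_ONLY also satisfies "!= ENCODER", so testing it second would
# leave its branch unreachable. Minimal stand-in enum:
from enum import Enum, auto

class _AttnType(Enum):
    DECODER = auto()
    ENCODER = auto()
    ENCODER_ONLY = auto()

def _num_prefill_tokens(attn_type, n_prefill, n_encoder, n_query):
    if attn_type is _AttnType.ENCODER_ONLY:
        return n_query  # all tokens processed in one pass
    if attn_type is not _AttnType.ENCODER:
        return n_prefill
    return n_encoder

assert _num_prefill_tokens(_AttnType.ENCODER_ONLY, 4, 6, 9) == 9
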
@@ -642,8 +657,13 @@ def forward(
             # QKV for prefill.
             query = query[:num_prefill_tokens]
 
+            # Encoder and encoder-only runs process all tokens at once,
+            # and cross-attention keys/values always come from the encoder,
+            # so only decoder self-attention may need key/value limited
+            # to the prefill tokens.
             if key is not None and value is not None \
-                and self.attn_type != AttentionType.ENCODER_DECODER:
+                and self.attn_type not in [AttentionType.ENCODER_DECODER,
+                                           AttentionType.ENCODER_ONLY]:
                 key = key[:num_prefill_tokens]
                 value = value[:num_prefill_tokens]
 
@@ -678,7 +698,7 @@ def forward(
                         self.alibi_slopes,
                         query.dtype,
                         seq_lens,
-                        make_attn_mask=False)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -703,7 +723,7 @@ def forward(
                         self.alibi_slopes,
                         query.dtype,
                         attn_metadata.seq_lens,
-                        make_attn_mask=True)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
@@ -729,7 +749,7 @@ def forward(
                     max_seqlen_q=prefill_meta.max_prefill_seq_len,
                     max_seqlen_k=key_max_seq_len,
                     softmax_scale=self.scale,
-                    causal=True,
+                    causal=causal_mask,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
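
# --- Illustrative aside (not part of the diff) ---
# The threaded causal_mask flag selects between a lower-triangular mask
# (decoder self-attention) and no mask (bidirectional attention). A toy
# additive-bias version of that choice:
import torch

def toy_attn_bias(seq_len: int, causal: bool) -> torch.Tensor:
    bias = torch.zeros(seq_len, seq_len)
    if causal:
        upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), 1)
        bias = bias.masked_fill(upper, float("-inf"))
    return bias

print(toy_attn_bias(3, causal=True))   # -inf above the diagonal
print(toy_attn_bias(3, causal=False))  # all zeros: every token sees all
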
@@ -742,25 +762,29 @@ def forward(
             else:
                 output = out
         else:
-            # prefix-enabled attention
-            output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                query,
-                key,
-                value,
-                self.kv_cache_dtype,
-                key_cache,
-                value_cache,
-                prefill_meta.block_tables,
-                prefill_meta.query_start_loc,
-                prefill_meta.seq_lens_tensor,
-                prefill_meta.max_query_len,
-                self.alibi_slopes,
-                self.sliding_window[0],
-                layer._k_scale,
-                layer._v_scale,
-            )
-
-        if decode_meta := attn_metadata.decode_metadata:
+            # Prefix-enabled attention is not applicable to
+            # encoder-only models, which never read the KV cache
+            if self.attn_type != AttentionType.ENCODER_ONLY:
+                output[:
+                       num_prefill_tokens] = PagedAttention.forward_prefix(
+                           query,
+                           key,
+                           value,
+                           self.kv_cache_dtype,
+                           key_cache,
+                           value_cache,
+                           prefill_meta.block_tables,
+                           prefill_meta.query_start_loc,
+                           prefill_meta.seq_lens_tensor,
+                           prefill_meta.max_query_len,
+                           self.alibi_slopes,
+                           self.sliding_window[0],
+                           layer._k_scale,
+                           layer._v_scale,
+                       )
+        # Skip the decode phase for encoder-only models
+        if (decode_meta := attn_metadata.decode_metadata) and (
+                self.attn_type != AttentionType.ENCODER_ONLY):
             # Decoding run.
             # Whether to use rocm custom paged attention or not
             num_seqs, num_heads, head_size = decode_query.shape
@@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
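
# --- Illustrative aside (not part of the diff) ---
# The visible clauses of _use_rocm_custom_paged_attention, restated as a
# standalone predicate (the real function may check more conditions):
import torch

def toy_use_custom_paged_attn(qtype: torch.dtype, head_size: int,
                              block_size: int, gqa_ratio: int,
                              max_seq_len: int) -> bool:
    return (qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)

assert toy_use_custom_paged_attn(torch.half, 128, 16, 8, 4096)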