@@ -50,20 +50,15 @@ def copy_blocks(
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                        AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -111,7 +106,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +135,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
 
                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],
@@ -194,32 +174,28 @@ def forward(
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
                         scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                     start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")
 
-        if decode_meta := attn_metadata.decode_metadata:
+        else:
             # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                 attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
                 kv_scale,
             )
-            assert out.shape == output[num_prefill_tokens:].shape
-            output[num_prefill_tokens:]
 
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
@@ -241,7 +217,7 @@ def _make_alibi_bias(
         bias = bias[None, :] - bias[:, None]
 
         num_heads = alibi_slopes.shape[0]
-        bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+        bias = bias[None, :].repeat((num_heads, 1, 1))
         bias.mul_(alibi_slopes[:, None, None])
         inf_mask = torch.empty(
             (1, prompt_len, prompt_len),
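
A note on the last hunk: `Tensor.expand` returns a stride-0 view in which every attention head aliases the same bias buffer, so the in-place `mul_` that follows cannot scale each head by its own ALiBi slope; `repeat` materializes one copy per head. Below is a standalone sketch (not part of this commit, plain PyTorch only, with made-up sizes) illustrating the difference.

import torch

# Hypothetical sizes, purely for illustration.
num_heads, prompt_len = 4, 8
alibi_slopes = torch.rand(num_heads)

bias = torch.arange(prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]  # (prompt_len, prompt_len) relative positions

# `expand` only broadcasts: the head dimension has stride 0, so all heads
# share one copy of `bias` and cannot safely be scaled in place per head.
expanded = bias[None, :].expand(num_heads, prompt_len, prompt_len)
print(expanded.stride())              # (0, 8, 1)

# `repeat` materializes num_heads independent copies, so the in-place
# multiply scales each head by its own slope, as _make_alibi_bias needs.
repeated = bias[None, :].repeat((num_heads, 1, 1))
repeated.mul_(alibi_slopes[:, None, None])
print(repeated.shape)                 # torch.Size([4, 8, 8])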