|
1 | 1 | """Attention layer ROCm GPUs.""" |
2 | 2 | from dataclasses import dataclass |
3 | | -from typing import Any, Dict, List, Optional, Tuple, Type |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type |
4 | 4 |
|
5 | 5 | import torch |
6 | 6 |
|
|
15 | 15 | from vllm.logger import init_logger |
16 | 16 | from vllm.platforms import current_platform |
17 | 17 |
|
| 18 | +if TYPE_CHECKING: |
| 19 | + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata |
| 20 | + |
18 | 21 | logger = init_logger(__name__) |
19 | 22 |
|
20 | 23 | _PARTITION_SIZE_ROCM = 512 |
@@ -186,6 +189,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: |
186 | 189 | ) |
187 | 190 | return self._cached_decode_metadata |
188 | 191 |
|
| 192 | + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", |
| 193 | + sampled_token_ids: Optional[torch.Tensor], |
| 194 | + block_size: int, num_seqs: int, num_queries: int): |
| 195 | + """ |
| 196 | + Update metadata in-place to advance one decode step. |
| 197 | + """ |
| 198 | + # When using CUDA graphs, num_seqs is padded to the next captured
| 199 | + # batch size, but num_queries tracks the actual number of requests in
| 200 | + # the batch. For --enforce-eager mode, num_seqs == num_queries.
| 201 | + if num_seqs != num_queries: |
| 202 | + assert num_seqs > num_queries |
| 203 | + assert self.use_cuda_graph |
| 204 | + |
| 205 | + assert self.num_prefills == 0 |
| 206 | + assert self.num_prefill_tokens == 0 |
| 207 | + assert self.num_decode_tokens == num_seqs |
| 208 | + assert self.slot_mapping.shape == (num_seqs, ) |
| 209 | + |
| 210 | + assert self.seq_lens is not None |
| 211 | + assert len(self.seq_lens) == num_seqs |
| 212 | + assert self.seq_lens_tensor is not None |
| 213 | + assert self.seq_lens_tensor.shape == (num_seqs, ) |
| 214 | + assert self.max_query_len == 1 |
| 215 | + assert self.max_prefill_seq_len == 0 |
| 216 | + assert self.max_decode_seq_len == max(self.seq_lens) |
| 217 | + |
| 218 | + assert self.query_start_loc is not None |
| 219 | + assert self.query_start_loc.shape == (num_queries + 1, ) |
| 220 | + assert self.seq_start_loc is not None |
| 221 | + assert self.seq_start_loc.shape == (num_seqs + 1, ) |
| 222 | + |
| 223 | + assert self.context_lens_tensor is not None |
| 224 | + assert self.context_lens_tensor.shape == (num_queries, ) |
| 225 | + |
| 226 | + assert self.block_tables is not None |
| 227 | + assert self.block_tables.shape[0] == num_seqs |
| 228 | + |
| 229 | + # Advance the sequence length of each active query. Only the first
| 230 | + # num_queries entries are updated; the rest may be CUDA graph padding.
| 231 | + for i in range(num_queries): |
| 232 | + self.seq_lens[i] += 1 |
| 233 | + self.max_decode_seq_len = max(self.seq_lens) |
| 234 | + |
| 235 | + ops.advance_step_flashattn(num_seqs=num_seqs, |
| 236 | + num_queries=num_queries, |
| 237 | + block_size=block_size, |
| 238 | + input_tokens=model_input.input_tokens, |
| 239 | + sampled_token_ids=sampled_token_ids, |
| 240 | + input_positions=model_input.input_positions, |
| 241 | + seq_lens=self.seq_lens_tensor, |
| 242 | + slot_mapping=self.slot_mapping, |
| 243 | + block_tables=self.block_tables) |
| 244 | + |
189 | 245 |
|
190 | 246 | class ROCmFlashAttentionMetadataBuilder( |
191 | 247 | CommonMetadataBuilder[ROCmFlashAttentionMetadata]): |
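
The added advance_step path delegates the per-step tensor updates to the fused ops.advance_step_flashattn kernel, which keeps them on the GPU between decode steps; the Python loop above only keeps the CPU-side seq_lens list and max_decode_seq_len in sync. The snippet below is a minimal CPU-side sketch of what that step advance does conceptually, assuming the conventional paged-KV slot computation. The function name is hypothetical, the tensor names mirror the kernel arguments, shapes are simplified (sampled_token_ids is treated as a flat vector), and it is an illustration rather than the kernel's actual implementation.

import torch

def advance_decode_step_sketch(num_queries: int, block_size: int,
                               input_tokens: torch.Tensor,
                               sampled_token_ids: torch.Tensor,
                               input_positions: torch.Tensor,
                               seq_lens: torch.Tensor,
                               slot_mapping: torch.Tensor,
                               block_tables: torch.Tensor) -> None:
    # Illustrative CPU loop; entries past num_queries (CUDA graph padding)
    # are deliberately left untouched.
    for i in range(num_queries):
        # The token sampled at the previous step becomes the next input token.
        input_tokens[i] = sampled_token_ids[i]
        # The new token lands at position == old sequence length (0-indexed),
        # and the sequence grows by one token.
        next_pos = int(seq_lens[i])
        input_positions[i] = next_pos
        seq_lens[i] = next_pos + 1
        # Translate the logical position into a physical KV-cache slot via the
        # per-sequence block table (assumed paged-KV layout).
        block_idx, block_off = divmod(next_pos, block_size)
        slot_mapping[i] = int(block_tables[i, block_idx]) * block_size + block_off
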
|