|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | import ast |
4 | 4 | from dataclasses import replace |
5 | | -from typing import Optional |
| 5 | +from importlib.util import find_spec |
| 6 | +from typing import Optional, Protocol |
6 | 7 |
|
7 | 8 | import numpy as np |
8 | 9 | import torch |
|
20 | 21 | from vllm.platforms import current_platform |
21 | 22 | from vllm.utils import is_pin_memory_available |
22 | 23 | from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata |
23 | | -from vllm.v1.attention.backends.rocm_aiter_fa import ( |
24 | | - AiterFlashAttentionMetadata) |
25 | 24 | from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, |
26 | 25 | TreeAttentionMetadataBuilder) |
27 | 26 | from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata |
|
34 | 33 | PADDING_SLOT_ID = -1 |
35 | 34 |
|
36 | 35 |
|
| 36 | +class EagleAttentionMetadata(Protocol): |
| 37 | + # Required attributes |
| 38 | + num_actual_tokens: int |
| 39 | + max_query_len: int |
| 40 | + query_start_loc: torch.Tensor |
| 41 | + max_seq_len: int |
| 42 | + seq_lens: torch.Tensor |
| 43 | + block_table: torch.Tensor |
| 44 | + slot_mapping: torch.Tensor |
| 45 | + |
| 46 | + |
37 | 47 | class EagleProposer: |
38 | 48 |
|
39 | 49 | def __init__( |
@@ -97,6 +107,20 @@ def __init__( |
97 | 107 | dtype=self.dtype, |
98 | 108 | device=device) |
99 | 109 |
|
| 110 | + # Determine allowed attention backends once during initialization. |
| 111 | + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] |
| 112 | + if current_platform.is_rocm(): |
| 113 | + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] |
| 114 | + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend |
| 115 | + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): |
| 116 | + from vllm.v1.attention.backends.rocm_aiter_fa import ( |
| 117 | + AiterFlashAttentionMetadata) |
| 118 | + rocm_types.append(AiterFlashAttentionMetadata) |
| 119 | + self.allowed_attn_types = tuple(rocm_types) |
| 120 | + else: |
| 121 | + self.allowed_attn_types = (FlashAttentionMetadata, |
| 122 | + TreeAttentionMetadata) |
| 123 | + |
100 | 124 | # Parse the speculative token tree. |
101 | 125 | spec_token_tree = self.speculative_config.speculative_token_tree |
102 | 126 | self.tree_choices: list[tuple[int, |
@@ -165,7 +189,7 @@ def propose( |
165 | 189 | for layer_name in self.attn_layer_names: |
166 | 190 | per_layer_attn_metadata[layer_name] = attn_metadata |
167 | 191 | if self.use_cuda_graph and \ |
168 | | - num_tokens <= self.cudagraph_batch_sizes[-1]: |
| 192 | + num_tokens <= self.cudagraph_batch_sizes[-1]: |
169 | 193 | num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) |
170 | 194 | else: |
171 | 195 | num_input_tokens = num_tokens |
@@ -225,25 +249,13 @@ def propose( |
225 | 249 | # TODO: Currently, MTP module released by deepseek only has |
226 | 250 | # one layer. Adapt this code to support multiple layers once |
227 | 251 | # there's a multi-layer MTP module. |
228 | | - |
229 | | - # On ROCm, both AiterFlashAttention and TritonAttention |
230 | | - # support multi-token eagle spec decode. |
231 | | - if current_platform.is_rocm(): |
232 | | - assert isinstance( |
233 | | - attn_metadata, |
234 | | - (TritonAttentionMetadata, AiterFlashAttentionMetadata, |
235 | | - FlashAttentionMetadata)) |
236 | | - else: |
237 | | - # Currently, only FlashAttention supports multi-token eagle spec |
238 | | - # decode. This is because the code below makes assumptions about |
239 | | - # attn_metadata attributes available. |
240 | | - assert isinstance(attn_metadata, FlashAttentionMetadata) |
| 252 | + assert isinstance(attn_metadata, self.allowed_attn_types) |
241 | 253 |
|
242 | 254 | # Generate the remaining draft tokens. |
243 | 255 | draft_token_ids_list = [draft_token_ids] |
244 | 256 |
|
245 | 257 | if self.use_cuda_graph and \ |
246 | | - batch_size <= self.cudagraph_batch_sizes[-1]: |
| 258 | + batch_size <= self.cudagraph_batch_sizes[-1]: |
247 | 259 | input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) |
248 | 260 | else: |
249 | 261 | input_batch_size = batch_size |
@@ -449,7 +461,7 @@ def propose_tree( |
449 | 461 | num_tokens, -1) |
450 | 462 |
|
451 | 463 | if self.use_cuda_graph and \ |
452 | | - num_tokens <= self.cudagraph_batch_sizes[-1]: |
| 464 | + num_tokens <= self.cudagraph_batch_sizes[-1]: |
453 | 465 | num_input_tokens = self.vllm_config.pad_for_cudagraph( |
454 | 466 | num_tokens) |
455 | 467 | else: |
@@ -508,19 +520,19 @@ def prepare_inputs( |
508 | 520 | """ |
509 | 521 | # E.g. |
510 | 522 | # common_attn_metadata.query_start_loc{_cpu}: |
511 | | - # [0, q1, q1 + q2, q1 + q2 + q3] |
| 523 | + # [0, q1, q1 + q2, q1 + q2 + q3] |
512 | 524 | # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] |
513 | 525 | # num_rejected_tokens: [n1, n2, n3] |
514 | 526 | # This function computes the intermediate values: |
515 | 527 | # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] |
516 | 528 | # And returns: |
517 | 529 | # common_attn_metadata.query_start_loc{_cpu}: |
518 | | - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] |
| 530 | + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] |
519 | 531 | # common_attn_metadata.seq_lens{_cpu}: |
520 | | - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] |
| 532 | + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] |
521 | 533 | # token_indices: [0, 1, ..., q1 - n1 - 1, |
522 | | - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, |
523 | | - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] |
| 534 | + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, |
| 535 | + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] |
524 | 536 |
|
525 | 537 | device = common_attn_metadata.query_start_loc.device |
526 | 538 | query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu |
@@ -564,9 +576,9 @@ def prepare_inputs( |
564 | 576 | old_query_start_locs_expanded = np.repeat( |
565 | 577 | query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) |
566 | 578 | # Final token indices are: |
567 | | - # [0, 1, // req 1 |
568 | | - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 |
569 | | - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 |
| 579 | + # [0, 1, // req 1 |
| 580 | + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 |
| 581 | + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 |
570 | 582 | token_indices_np = token_offests + old_query_start_locs_expanded |
571 | 583 | token_indices = torch.from_numpy(token_indices_np).to( |
572 | 584 | device, non_blocking=True) |
@@ -616,20 +628,18 @@ def load_model(self, target_model: nn.Module) -> None: |
616 | 628 | target_language_model = target_model |
617 | 629 | # share embed_tokens with the target model if needed |
618 | 630 | if get_pp_group().world_size == 1 \ |
619 | | - and self.model.model.embed_tokens.weight.shape \ |
620 | | - == target_language_model.model.embed_tokens.weight.shape: |
| 631 | + and self.model.model.embed_tokens.weight.shape \ |
| 632 | + == target_language_model.model.embed_tokens.weight.shape: |
621 | 633 | logger.info( |
622 | | - "Assuming the EAGLE head shares the same vocab embedding" \ |
623 | | - " with the target model." |
624 | | - ) |
| 634 | + "Assuming the EAGLE head shares the same vocab embedding" |
| 635 | + " with the target model.") |
625 | 636 | del self.model.model.embed_tokens |
626 | 637 | self.model.model.embed_tokens = ( |
627 | 638 | target_language_model.model.embed_tokens) |
628 | 639 | else: |
629 | 640 | logger.info( |
630 | | - "The EAGLE head's vocab embedding will be loaded separately" \ |
631 | | - " from the target model." |
632 | | - ) |
| 641 | + "The EAGLE head's vocab embedding will be loaded separately" |
| 642 | + " from the target model.") |
633 | 643 |
|
634 | 644 | # share lm_head with the target model if needed |
635 | 645 | # some model definition do not define lm_head explicitly |
|
0 commit comments