|
190 | 190 | import functools |
191 | 191 | from abc import abstractmethod |
192 | 192 | from dataclasses import dataclass, field |
193 | | -from typing import Generic, Optional, TypeVar, Union |
| 193 | +from typing import ClassVar, Generic, Optional, TypeVar, Union |
194 | 194 |
|
195 | 195 | import torch |
196 | 196 | from tqdm import tqdm |
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): |
454 | 454 | understand this class |
455 | 455 | """ |
456 | 456 |
|
| 457 | + # Whether the backend supports reordering the batch such that |
| 458 | + # short sequences (i.e. verification for speculative decoding) are |
| 459 | + # classified as decode requests. |
| 460 | + # If True, this will increase `reorder_batch_threshold` (below) when |
| 461 | + # speculative decoding is enabled, and set `require_uniform=True` |
| 462 | + # when reordering the batch. Non-uniform decode requests will |
| 463 | + # fall back to prefill in this case. |
| 464 | + supports_uniform_spec_as_decode: ClassVar[bool] = False |
| 465 | + |
| 466 | + # The threshold for reordering the batch into decode and prefill requests. |
| 467 | + # If > 1, the batch will be reordered such that requests with |
| 468 | + # query length <= threshold are classified as decode requests. |
| 469 | + # Use `supports_uniform_spec_as_decode` (above) to set this automatically |
| 470 | + # when speculative decoding is enabled. |
457 | 471 | reorder_batch_threshold: int = 1 |
458 | 472 |
|
459 | 473 | @staticmethod |
@@ -503,6 +517,7 @@ def __init__( |
503 | 517 | self.model_config = vllm_config.model_config |
504 | 518 | parallel_config = vllm_config.parallel_config |
505 | 519 | self.compilation_config = vllm_config.compilation_config |
| 520 | + self.vllm_config = vllm_config |
506 | 521 | self.device = device |
507 | 522 |
|
508 | 523 | self.num_heads = self.model_config.get_num_attention_heads(parallel_config) |
@@ -578,6 +593,11 @@ def __init__( |
578 | 593 | device=device, |
579 | 594 | ) |
580 | 595 |
|
| 596 | + supports_spec_as_decode = self.supports_uniform_spec_as_decode |
| 597 | + self._init_reorder_batch_threshold( |
| 598 | + self.reorder_batch_threshold, supports_spec_as_decode |
| 599 | + ) |
| 600 | + |
581 | 601 | def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): |
582 | 602 | qo_indptr = prefill.query_start_loc |
583 | 603 |
|
@@ -714,7 +734,9 @@ def build( |
714 | 734 |
|
715 | 735 | num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( |
716 | 736 | split_decodes_and_prefills( |
717 | | - common_attn_metadata, decode_threshold=self.reorder_batch_threshold |
| 737 | + common_attn_metadata, |
| 738 | + decode_threshold=self.reorder_batch_threshold, |
| 739 | + require_uniform=self.supports_uniform_spec_as_decode, |
718 | 740 | ) |
719 | 741 | ) |
720 | 742 |
|
|
0 commit comments