Commit 0631414

[vllm] Support speculative decoding in vllm rolling batch (#2413)
1 parent 77041b5 commit 0631414

5 files changed, 79 additions and 0 deletions

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

Lines changed: 16 additions & 0 deletions

@@ -59,6 +59,22 @@ class VllmRbProperties(Properties):
     enable_prefix_caching: Optional[bool] = False
     disable_sliding_window: Optional[bool] = False
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    use_v2_block_manager: bool = False
+
+    # Speculative decoding configuration.
+    speculative_model: Optional[str] = None
+    speculative_model_quantization: Optional[str] = None
+    speculative_draft_tensor_parallel_size: Optional[int] = None
+    num_speculative_tokens: Optional[int] = None
+    speculative_max_model_len: Optional[int] = None
+    speculative_disable_by_batch_size: Optional[int] = None
+    ngram_prompt_lookup_max: Optional[int] = None
+    ngram_prompt_lookup_min: Optional[int] = None
+    spec_decoding_acceptance_method: str = 'rejection_sampler'
+    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
+    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
+    qlora_adapter_name_or_path: Optional[str] = None
+    disable_logprobs_during_spec_decoding: Optional[bool] = None

     @field_validator('engine')
     def validate_engine(cls, engine):
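
Illustrative only, not part of this commit: in LMI serving these new fields would typically be set through "option."-prefixed properties, which the properties manager strips before populating VllmRbProperties. The Medusa-style values below mirror the integration-test configs added later in this commit; the acceptance-sampler keys and values are hypothetical examples.

speculative_options = {
    "option.model_id": "s3://djl-llm/llama-68m/",
    "option.speculative_model": "s3://djl-llm/llama-2-tiny/",
    "option.num_speculative_tokens": 4,
    "option.use_v2_block_manager": True,
    # Optional: swap the default rejection sampler for the typical
    # acceptance sampler and tune its thresholds (example values).
    "option.spec_decoding_acceptance_method": "typical_acceptance_sampler",
    "option.typical_acceptance_sampler_posterior_threshold": 0.09,
    "option.typical_acceptance_sampler_posterior_alpha": 0.3,
}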

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

Lines changed: 21 additions & 0 deletions

@@ -266,6 +266,27 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
         enable_prefix_caching=config.enable_prefix_caching,
         disable_sliding_window=config.disable_sliding_window,
         max_num_seqs=config.max_rolling_batch_size,
+        use_v2_block_manager=config.use_v2_block_manager,
+        speculative_model=config.speculative_model,
+        speculative_model_quantization=config.
+        speculative_model_quantization,
+        speculative_draft_tensor_parallel_size=config.
+        speculative_draft_tensor_parallel_size,
+        num_speculative_tokens=config.num_speculative_tokens,
+        speculative_max_model_len=config.speculative_max_model_len,
+        speculative_disable_by_batch_size=config.
+        speculative_disable_by_batch_size,
+        ngram_prompt_lookup_max=config.ngram_prompt_lookup_max,
+        ngram_prompt_lookup_min=config.ngram_prompt_lookup_min,
+        spec_decoding_acceptance_method=config.
+        spec_decoding_acceptance_method,
+        typical_acceptance_sampler_posterior_threshold=config.
+        typical_acceptance_sampler_posterior_threshold,
+        typical_acceptance_sampler_posterior_alpha=config.
+        typical_acceptance_sampler_posterior_alpha,
+        qlora_adapter_name_or_path=config.qlora_adapter_name_or_path,
+        disable_logprobs_during_spec_decoding=config.
+        disable_logprobs_during_spec_decoding,
     )
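
A minimal sketch (not taken from this commit) of what the pass-through above enables end to end: the same speculative-decoding arguments handed straight to vLLM, assuming a vLLM build whose LLM entry point accepts these EngineArgs fields. The model IDs are the ones used by the EAGLE integration test added below.

# Sketch only: target model plus an EAGLE draft model, 4 speculative tokens.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
    num_speculative_tokens=4,
    use_v2_block_manager=True,
    tensor_parallel_size=1,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)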

tests/integration/llm/client.py

Lines changed: 12 additions & 0 deletions

@@ -463,6 +463,18 @@ def get_model_name():
         "seq_length": [256],
         "tokenizer": "tiiuae/falcon-11B"
     },
+    "llama-68m-speculative-medusa": {
+        "max_memory_per_gpu": [25.0],
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "JackFram/llama-68m"
+    },
+    "llama-68m-speculative-eagle": {
+        "max_memory_per_gpu": [25.0],
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "JackFram/llama-68m"
+    },
     "llama-7b-unmerged-lora": {
         "max_memory_per_gpu": [15.0, 15.0],
         "batch_size": [3],

tests/integration/llm/prepare.py

Lines changed: 18 additions & 0 deletions

@@ -625,6 +625,24 @@
         "option.tensor_parallel_degree": 4,
         "option.enable_chunked_prefill": "true",
     },
+    "llama-68m-speculative-medusa": {
+        "option.model_id": "s3://djl-llm/llama-68m/",
+        "option.task": "text-generation",
+        "option.speculative_model": "s3://djl-llm/llama-2-tiny/",
+        "option.num_speculative_tokens": 4,
+        "option.use_v2_block_manager": True,
+        "option.tensor_parallel_degree": 1,
+        "option.max_rolling_batch_size": 4,
+    },
+    "llama-68m-speculative-eagle": {
+        "option.model_id": "s3://djl-llm/llama-68m/",
+        "option.task": "text-generation",
+        "option.speculative_model": "abhigoyal/vllm-eagle-llama-68m-random",
+        "option.num_speculative_tokens": 4,
+        "option.use_v2_block_manager": True,
+        "option.tensor_parallel_degree": 1,
+        "option.max_rolling_batch_size": 4,
+    },
     "llama-7b-unmerged-lora": {
         "option.model_id": "s3://djl-llm/huggyllama-llama-7b",
         "option.tensor_parallel_degree": "max",

tests/integration/tests.py

Lines changed: 12 additions & 0 deletions

@@ -604,6 +604,18 @@ def test_falcon_11b_chunked_prefill(self):
             client.run(
                 "vllm falcon-11b-chunked-prefill --in_tokens 1200".split())

+    def test_llama_68m_speculative_medusa(self):
+        with Runner('lmi', 'llama-68m-speculative-medusa') as r:
+            prepare.build_vllm_model("llama-68m-speculative-medusa")
+            r.launch()
+            client.run("vllm llama-68m-speculative-medusa".split())
+
+    def test_llama_68m_speculative_eagle(self):
+        with Runner('lmi', 'llama-68m-speculative-eagle') as r:
+            prepare.build_vllm_model("llama-68m-speculative-eagle")
+            r.launch()
+            client.run("vllm llama-68m-speculative-eagle".split())
+

 @pytest.mark.vllm
 @pytest.mark.lora
