 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
 from tests.ut.base import TestBase
 from vllm_ascend.core.scheduler import AscendScheduler
 from vllm_ascend.utils import vllm_version_is
 
+if not vllm_version_is("0.10.1.1"):
+    from vllm.v1.outputs import DraftTokenIds
+else:
+    DraftTokenIds = None
+
 EOS_TOKEN_ID = 50256
 MODEL = "Qwen3-0.6B"
 ENABLE_PREFIX_CACHING = None
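
# The sentinel import above makes DraftTokenIds None on vllm 0.10.1.1, so a
# test can gate itself on it directly. A minimal sketch of that pattern,
# assuming a unittest-style test class (the class and test names here are
# hypothetical, not part of this change):
import unittest


class DraftTokenIdsGuardSketch(unittest.TestCase):

    @unittest.skipIf(DraftTokenIds is None,
                     "DraftTokenIds is unavailable on vllm 0.10.1.1")
    def test_constructs_draft_token_ids(self):
        # Positional construction mirrors the usage later in this file:
        # DraftTokenIds(req_ids, spec_token_ids).
        draft = DraftTokenIds(["req-0"], [[1, 2, 3]])
        self.assertIsNotNone(draft)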
@@ -66,16 +71,33 @@ def create_requests(
 
 
 def make_output(scheduler):
-    return ModelRunnerOutput(
-        req_ids=[req.request_id for req in scheduler.running],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(scheduler.running)
-        },
-        sampled_token_ids=[[1000]] * len(scheduler.running),
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
+    req_ids = [req.request_id for req in scheduler.running]
+    req_id_to_index = {
+        req.request_id: i
+        for i, req in enumerate(scheduler.running)
+    }
+    sampled_token_ids = [[1000]] * len(scheduler.running)
+    logprobs = None
+    if vllm_version_is("0.10.1.1"):
+        modelrunner_output = ModelRunnerOutput(
+            req_ids=req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids,
+            spec_token_ids=None,
+            logprobs=logprobs,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+        )
+    else:
+        modelrunner_output = ModelRunnerOutput(
+            req_ids=req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids,
+            logprobs=logprobs,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+        )
+    return modelrunner_output
 
 
 class TestAscendScheduler(TestBase):
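
# The vllm_version_is("0.10.1.1") branch in make_output() above repeats for
# every ModelRunnerOutput built in the tests below. A minimal sketch of one
# way to centralize the branching, assuming only the spec_token_ids field
# separates the two constructor signatures (make_versioned_output is a
# hypothetical helper, not something this change introduces):
def make_versioned_output(req_ids, req_id_to_index, sampled_token_ids):
    kwargs = dict(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    if vllm_version_is("0.10.1.1"):
        # The 0.10.1.1 constructor still expects an explicit spec_token_ids.
        kwargs["spec_token_ids"] = None
    return ModelRunnerOutput(**kwargs)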
@@ -271,8 +293,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -291,18 +312,33 @@ def test_stop_via_update_from_output(self):
                                            free_encoder_input_ids=[],
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
-
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
-                               ],  # First request hits EOS, second continues
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[EOS_TOKEN_ID], [
+                    10, 11
+                ]],  # First request hits EOS, second continues
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[EOS_TOKEN_ID], [
+                    10, 11
+                ]],  # First request hits EOS, second continues
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -325,8 +361,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -346,18 +381,31 @@ def test_stop_via_update_from_output(self):
                                            free_encoder_input_ids=[],
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
-
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[10, 42, 12],
-                               [13, 14]],  # First request hits stop token
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 42, 12],
+                                   [13, 14]],  # First request hits stop token
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 42, 12],
+                                   [13, 14]],  # First request hits stop token
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -379,8 +427,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -401,18 +448,31 @@ def test_stop_via_update_from_output(self):
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
 
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[10, 11, 12],
-                               [13]],  # First request exceeds max_tokens
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
-
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 11, 12],
+                                   [13]],  # First request exceeds max_tokens
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 11, 12],
+                                   [13]],  # First request exceeds max_tokens
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
         scheduler.update_from_output(scheduler_output, model_output)
 
         # Verify first request stopped due to length
@@ -448,13 +508,24 @@ def test_stop_via_update_from_output(self):
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
 
-        model_output = ModelRunnerOutput(
-            req_ids=[requests[0].request_id],
-            req_id_to_index={requests[0].request_id: 0},
-            sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -505,13 +576,23 @@ def test_schedule_concurrent_batches(self):
                                                    512)
 
         # Model output of the first request.
-        model_runner_output = ModelRunnerOutput(
-            req_ids=[requests[0].request_id],
-            req_id_to_index={requests[0].request_id: 0},
-            sampled_token_ids=[[0]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[0]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output0,
                                      model_runner_output)
@@ -521,13 +602,23 @@ def test_schedule_concurrent_batches(self):
         # request is still running.
         scheduler.schedule()
         # Model output of the second request.
-        model_runner_output = ModelRunnerOutput(
-            req_ids=[requests[1].request_id],
-            req_id_to_index={requests[1].request_id: 0},
-            sampled_token_ids=[[0]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[1].request_id],
+                req_id_to_index={requests[1].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[1].request_id],
+                req_id_to_index={requests[1].request_id: 0},
+                sampled_token_ids=[[0]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output1,
                                      model_runner_output)
@@ -579,19 +670,29 @@ def test_schedule_spec_decoding_stats(self):
             req_id = requests[i].request_id
             self.assertEqual(output.num_scheduled_tokens[req_id], 1)
             self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_to_index,
-            sampled_token_ids=[[0] for _ in range(len(requests))],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
-        draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=[[0] for _ in range(len(requests))],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                spec_token_ids=spec_tokens,
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=[[0] for _ in range(len(requests))],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+            draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
 
         engine_core_outputs = scheduler.update_from_output(
             output, model_runner_output)
-        scheduler.update_draft_token_ids(draft_token_ids)
+        if not vllm_version_is("0.10.1.1"):
+            scheduler.update_draft_token_ids(draft_token_ids)
 
         for i in range(len(requests)):
             running_req = scheduler.running[i]
@@ -627,14 +728,23 @@ def test_schedule_spec_decoding_stats(self):
             else:
                 self.assertNotIn(req_id,
                                  output.scheduled_spec_decode_tokens)
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_to_index,
-            sampled_token_ids=output_tokens,
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=output_tokens,
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=output_tokens,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         engine_core_outputs = scheduler.update_from_output(
             output, model_runner_output)