
Commit 4988427

fix ut

Committed by MengqingCao
Signed-off-by: MengqingCao <[email protected]>
1 parent b643a28 commit 4988427

File tree

3 files changed: +230 -105 lines

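The change makes tests/ut/core/test_scheduler.py work against both vLLM 0.10.1.1 and newer vLLM: on 0.10.1.1, ModelRunnerOutput still accepts a spec_token_ids argument and DraftTokenIds does not exist, while on newer versions draft tokens are passed to the scheduler separately. The sketch below distills the compatibility pattern the diff repeats; it is an illustration only, and the build_output helper name is invented for the example (the tests inline this logic).

    # Sketch of the version-gating pattern used throughout this commit.
    # `build_output` is a hypothetical helper, not part of the change.
    from vllm.v1.outputs import ModelRunnerOutput

    from vllm_ascend.utils import vllm_version_is

    if not vllm_version_is("0.10.1.1"):
        # Newer vLLM: draft tokens travel in a separate DraftTokenIds object.
        from vllm.v1.outputs import DraftTokenIds
    else:
        DraftTokenIds = None


    def build_output(req_ids, req_id_to_index, sampled_token_ids):
        common = dict(req_ids=req_ids,
                      req_id_to_index=req_id_to_index,
                      sampled_token_ids=sampled_token_ids,
                      logprobs=None,
                      prompt_logprobs_dict={},
                      pooler_output=[])
        if vllm_version_is("0.10.1.1"):
            # 0.10.1.1 still expects spec_token_ids on the constructor.
            return ModelRunnerOutput(spec_token_ids=None, **common)
        return ModelRunnerOutput(**common)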

tests/ut/core/test_scheduler.py

Lines changed: 202 additions & 92 deletions
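A quick way to exercise just this file after the change (a typical invocation; this assumes pytest drives the unit tests under tests/ut, as elsewhere in the repo):

    # Run only the scheduler unit tests touched by this commit,
    # equivalent to `pytest -v tests/ut/core/test_scheduler.py`.
    import sys

    import pytest

    sys.exit(pytest.main(["-v", "tests/ut/core/test_scheduler.py"]))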
@@ -13,14 +13,19 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
 from tests.ut.base import TestBase
 from vllm_ascend.core.scheduler import AscendScheduler
 from vllm_ascend.utils import vllm_version_is
 
+if not vllm_version_is("0.10.1.1"):
+    from vllm.v1.outputs import DraftTokenIds
+else:
+    DraftTokenIds = None
+
 EOS_TOKEN_ID = 50256
 MODEL = "Qwen3-0.6B"
 ENABLE_PREFIX_CACHING = None
@@ -66,16 +71,33 @@ def create_requests(
 
 
 def make_output(scheduler):
-    return ModelRunnerOutput(
-        req_ids=[req.request_id for req in scheduler.running],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(scheduler.running)
-        },
-        sampled_token_ids=[[1000]] * len(scheduler.running),
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
+    req_ids = [req.request_id for req in scheduler.running]
+    req_id_to_index = {
+        req.request_id: i
+        for i, req in enumerate(scheduler.running)
+    }
+    sampled_token_ids = [[1000]] * len(scheduler.running)
+    logprobs = None
+    if vllm_version_is("0.10.1.1"):
+        modelrunner_output = ModelRunnerOutput(
+            req_ids=req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids,
+            spec_token_ids=None,
+            logprobs=logprobs,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+        )
+    else:
+        modelrunner_output = ModelRunnerOutput(
+            req_ids=req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids,
+            logprobs=logprobs,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+        )
+    return modelrunner_output
 
 
 class TestAscendScheduler(TestBase):
@@ -271,8 +293,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -291,18 +312,33 @@ def test_stop_via_update_from_output(self):
                                            free_encoder_input_ids=[],
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
-
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[EOS_TOKEN_ID], [10, 11
-                               ],  # First request hits EOS, second continues
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[EOS_TOKEN_ID], [
+                    10, 11
+                ]],  # First request hits EOS, second continues
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[EOS_TOKEN_ID], [
+                    10, 11
+                ]],  # First request hits EOS, second continues
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -325,8 +361,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -346,18 +381,31 @@ def test_stop_via_update_from_output(self):
                                            free_encoder_input_ids=[],
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
-
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[10, 42, 12],
-                               [13, 14]],  # First request hits stop token
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 42, 12],
+                                   [13, 14]],  # First request hits stop token
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 42, 12],
+                                   [13, 14]],  # First request hits stop token
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -379,8 +427,7 @@ def test_stop_via_update_from_output(self):
             req.num_computed_tokens = req.num_tokens
             scheduler.requests[req.request_id] = req
             scheduler.running.append(req)
-            if not vllm_version_is("0.9.2"):
-                req.status = RequestStatus.RUNNING
+            req.status = RequestStatus.RUNNING
 
         scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                            scheduled_cached_reqs=[],
@@ -401,18 +448,31 @@ def test_stop_via_update_from_output(self):
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
 
-        model_output = ModelRunnerOutput(
-            req_ids=[req.request_id for req in requests],
-            req_id_to_index={
-                req.request_id: i
-                for i, req in enumerate(requests)
-            },
-            sampled_token_ids=[[10, 11, 12],
-                               [13]],  # First request exceeds max_tokens
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
-
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 11, 12],
+                                   [13]],  # First request exceeds max_tokens
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[req.request_id for req in requests],
+                req_id_to_index={
+                    req.request_id: i
+                    for i, req in enumerate(requests)
+                },
+                sampled_token_ids=[[10, 11, 12],
+                                   [13]],  # First request exceeds max_tokens
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
         scheduler.update_from_output(scheduler_output, model_output)
 
         # Verify first request stopped due to length
@@ -448,13 +508,24 @@ def test_stop_via_update_from_output(self):
                                            structured_output_request_ids={},
                                            grammar_bitmask=None)
 
-        model_output = ModelRunnerOutput(
-            req_ids=[requests[0].request_id],
-            req_id_to_index={requests[0].request_id: 0},
-            sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+        else:
+            model_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output, model_output)
 
@@ -505,13 +576,23 @@ def test_schedule_concurrent_batches(self):
             512)
 
         # Model output of the first request.
-        model_runner_output = ModelRunnerOutput(
-            req_ids=[requests[0].request_id],
-            req_id_to_index={requests[0].request_id: 0},
-            sampled_token_ids=[[0]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[0]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output0,
                                      model_runner_output)
@@ -521,13 +602,23 @@ def test_schedule_concurrent_batches(self):
         # request is still running.
         scheduler.schedule()
         # Model output of the second request.
-        model_runner_output = ModelRunnerOutput(
-            req_ids=[requests[1].request_id],
-            req_id_to_index={requests[1].request_id: 0},
-            sampled_token_ids=[[0]],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[1].request_id],
+                req_id_to_index={requests[1].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[1].request_id],
+                req_id_to_index={requests[1].request_id: 0},
+                sampled_token_ids=[[0]],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         scheduler.update_from_output(scheduler_output1,
                                      model_runner_output)
@@ -579,19 +670,29 @@ def test_schedule_spec_decoding_stats(self):
             req_id = requests[i].request_id
             self.assertEqual(output.num_scheduled_tokens[req_id], 1)
             self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_to_index,
-            sampled_token_ids=[[0] for _ in range(len(requests))],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
-        draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=[[0] for _ in range(len(requests))],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                spec_token_ids=spec_tokens,
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=[[0] for _ in range(len(requests))],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+            draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
 
         engine_core_outputs = scheduler.update_from_output(
             output, model_runner_output)
-        scheduler.update_draft_token_ids(draft_token_ids)
+        if not vllm_version_is("0.10.1.1"):
+            scheduler.update_draft_token_ids(draft_token_ids)
 
         for i in range(len(requests)):
             running_req = scheduler.running[i]
@@ -627,14 +728,23 @@ def test_schedule_spec_decoding_stats(self):
             else:
                 self.assertNotIn(req_id,
                                  output.scheduled_spec_decode_tokens)
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_to_index,
-            sampled_token_ids=output_tokens,
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=[])
+        if vllm_version_is("0.10.1.1"):
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=output_tokens,
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+        else:
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=output_tokens,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
 
         engine_core_outputs = scheduler.update_from_output(
             output, model_runner_output)
