From bcb8a0f869df3a2b6270f54fff2efef02f158b45 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 20 Nov 2025 13:47:14 -0800 Subject: [PATCH 1/5] Revert "[Redo] #26368 (#28771)" This reverts commit 98b4d389ed27f09fd185ade889a02f640a3ff0b4. Signed-off-by: Jialin Ouyang --- tests/v1/core/test_async_scheduler.py | 3 +- .../v1/core/test_priority_scheduler_random.py | 6 +- tests/v1/core/test_scheduler.py | 88 ++++++++----------- .../kv_connector/unit/test_nixl_connector.py | 7 +- tests/v1/kv_connector/unit/utils.py | 3 +- tests/v1/spec_decode/test_eagle.py | 5 +- tests/v1/spec_decode/test_ngram.py | 18 ++-- vllm/v1/core/sched/scheduler.py | 4 +- vllm/v1/outputs.py | 4 +- vllm/v1/sample/rejection_sampler.py | 8 +- vllm/v1/spec_decode/eagle.py | 7 +- vllm/v1/spec_decode/ngram_proposer.py | 6 +- vllm/v1/spec_decode/suffix_decoding.py | 10 +-- vllm/v1/worker/gpu_model_runner.py | 46 +++------- vllm/v1/worker/tpu_model_runner.py | 8 +- 15 files changed, 92 insertions(+), 131 deletions(-) diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 1d80ee987591..e0645ed43015 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import deque -import numpy as np import pytest from vllm.v1.core.sched.output import SchedulerOutput @@ -22,7 +21,7 @@ def _make_model_runner_output( return ModelRunnerOutput( req_ids=req_ids, req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, - sampled_token_ids=[np.array([i]) for i in range(len(req_ids))], + sampled_token_ids=[[i] for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py index ba0b703302e3..b4805be80272 100644 --- a/tests/v1/core/test_priority_scheduler_random.py +++ b/tests/v1/core/test_priority_scheduler_random.py @@ -3,7 +3,6 @@ import random import uuid -import numpy as np import pytest from vllm.config import VllmConfig @@ -100,7 +99,8 @@ def _mock_execute_model( random.randint(*num_output_tokens_range) for _ in range(len(request_ids)) ] sampled_token_ids = [ - np.random.randint(0, 100, size=num_tokens) for num_tokens in num_output_tokens + [random.randint(0, 100) for _ in range(num_tokens)] + for num_tokens in num_output_tokens ] return ModelRunnerOutput( @@ -196,8 +196,6 @@ def test_priority_scheduling_blast( num_blocks: int, ): random.seed(42) - np.random.seed(42) - seen_request_prompt_length = dict[str, int]() seen_request_ids = set[str]() seen_mm_hashes = set[str]() diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0570c0854c67..04e738293cd7 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,7 +3,6 @@ import dataclasses from unittest.mock import Mock -import numpy as np import pytest import torch @@ -170,7 +169,7 @@ def test_schedule_partial_requests(): req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. - sampled_token_ids=[np.array([0]), np.array([]), np.array([])], + sampled_token_ids=[[0], [], []], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -217,7 +216,7 @@ def test_no_mm_input_chunking(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[np.array([]) for _ in range(len(requests))], + sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -277,7 +276,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[np.array([]) for _ in range(len(requests))], + sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -301,8 +300,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[np.array([0]), np.array([0])] - + [np.array([]) for _ in range(len(requests) - 2)], + sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -349,8 +347,8 @@ def test_stop_via_update_from_output(): req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[ - np.array([EOS_TOKEN_ID]), - np.array([10, 11]), + [EOS_TOKEN_ID], + [10, 11], ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, @@ -394,10 +392,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[ - np.array([10, 42, 12]), - np.array([13, 14]), - ], # First request hits stop token + sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -441,10 +436,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[ - np.array([10, 11, 12]), - np.array([13]), - ], # First request exceeds max_tokens + sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -483,7 +475,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])], + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -624,7 +616,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -641,7 +633,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -678,7 +670,7 @@ def test_preempt_during_execution(): model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -695,7 +687,7 @@ def test_preempt_during_execution(): model_runner_output1 = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[np.array([42])], + sampled_token_ids=[[42]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -712,18 +704,14 @@ def test_preempt_during_execution(): @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match - ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch - ( - [[1, 2], [3]], - [np.array([1, 2, 5]), np.array([3, 4])], - (2, 3, 3, [2, 1]), - ), # multiple sequences - ([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence - ([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence + ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences + ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence + ([[]], [[5]], (0, 0, 0, [0])), # empty sequence ( [[1, 2, 3], [4, 5, 6]], - [np.array([1, 2, 7]), np.array([4, 8])], + [[1, 2, 7], [4, 8]], (2, 6, 3, [2, 1, 0]), ), # multiple mismatches ], @@ -757,7 +745,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([0]) for _ in range(len(requests))], + sampled_token_ids=[[0] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -984,7 +972,7 @@ def test_kv_connector_basic(is_async: bool): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([1000])] * len(req_ids), + sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1037,7 +1025,7 @@ def test_kv_connector_basic(is_async: bool): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([1000])] * len(req_ids), + sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1100,7 +1088,7 @@ def test_external_prefix_cache_metrics(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[r.request_id for r in requests], req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, - sampled_token_ids=[np.array([1000])] * NUM_REQUESTS, + sampled_token_ids=[[1000]] * NUM_REQUESTS, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1166,7 +1154,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([1000])] * len(req_ids), + sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1251,7 +1239,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([1000])] * len(req_ids), + sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1344,7 +1332,7 @@ def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, - sampled_token_ids=[np.array([1000])] * len(scheduler.running), + sampled_token_ids=[[1000]] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1761,7 +1749,7 @@ def test_priority_scheduling_preemption(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[np.array([100]) for _ in low_priority_requests], + sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1830,7 +1818,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[np.array([100]) for _ in low_priority_requests], + sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2076,7 +2064,7 @@ def test_priority_scheduling_heap_property(): model_output = ModelRunnerOutput( req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, - sampled_token_ids=[np.array([100])], + sampled_token_ids=[[100]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2162,7 +2150,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[np.array([100])], + sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2193,7 +2181,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[np.array([100]) for _ in requests], + sampled_token_ids=[[100] for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2219,7 +2207,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[np.array([]), np.array([100])], + sampled_token_ids=[[], [100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2636,7 +2624,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): model_output = ModelRunnerOutput( req_ids=[request1.request_id], req_id_to_index={request1.request_id: 0}, - sampled_token_ids=[np.array([100])], + sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2842,7 +2830,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[np.array([1000])] * len(req_ids), + sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2955,7 +2943,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[np.array([100])], + sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -3006,7 +2994,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[np.array([100]) for _ in requests], + sampled_token_ids=[[100] for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -3041,7 +3029,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[np.array([100]), np.array([100, 200])], + sampled_token_ids=[[100], [100, 200]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -3227,7 +3215,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto model_output = ModelRunnerOutput( req_ids=[request1.request_id, request2.request_id], req_id_to_index={request1.request_id: 0, request2.request_id: 1}, - sampled_token_ids=[np.array([100]), np.array([121])], + sampled_token_ids=[[100], [121]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b264e5108c16..b7d7a10057b8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -11,7 +11,6 @@ from collections import defaultdict from unittest.mock import patch -import numpy as np import pytest import ray import torch @@ -827,7 +826,7 @@ def test_kv_connector_stats_aggregation(): output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, - sampled_token_ids=[np.array([123])], # dummy token + sampled_token_ids=[[123]], # dummy token logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], @@ -908,7 +907,7 @@ def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats: output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, - sampled_token_ids=[np.array([123])], + sampled_token_ids=[[123]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], @@ -966,7 +965,7 @@ def test_scheduler_kv_connector_stats_aggregation(): model_output = ModelRunnerOutput( req_ids=["req_0"], req_id_to_index={"req_0": 0}, - sampled_token_ids=[np.array([123])], + sampled_token_ids=[[123]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index c248104d5b5e..f35f91bb3adf 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -7,7 +7,6 @@ from itertools import chain, count from typing import Any -import numpy as np import torch from vllm import SamplingParams @@ -229,7 +228,7 @@ def create_model_runner_output( # Make sampled tokens. sampled_token = EOS_TOKEN_ID if use_eos else token_id - sampled_token_ids = [np.array([sampled_token]) for _ in req_ids] + sampled_token_ids = [[sampled_token] for _ in req_ids] kv_connector_output = ( None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 805b8c86b080..c93c59d1f4c4 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -3,7 +3,6 @@ from unittest import mock -import numpy as np import pytest import torch @@ -113,9 +112,7 @@ def test_prepare_next_token_ids(): sampled_token_ids_tensor = torch.tensor( sampled_token_ids, dtype=torch.int32, device=device ) - sampled_token_ids_cpu = [ - np.array([i for i in seq if i != -1]) for seq in sampled_token_ids - ] + sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] expected_next_token_ids_cpu = [1, 4, 30, 40] expected_next_token_ids_tensor = torch.tensor( diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 563bc1d957f4..692c39282c37 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,7 +77,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -88,7 +88,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,7 +99,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -111,7 +111,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -122,7 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -133,7 +133,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -144,7 +144,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[np.array([0])], + sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -157,7 +157,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[np.array([0]), np.array([1])], + sampled_token_ids=[[0], [1]], req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -181,7 +181,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( - sampled_token_ids=[np.array([0]), np.array([1])], + sampled_token_ids=[[0], [1]], req_ids=["0", "1"], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4cc4c29591cc..1ac8520a8ed2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1013,8 +1013,8 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids: list[int] = ( - sampled_token_ids[req_index].tolist() if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] ) scheduled_spec_token_ids = ( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index c0b2835c3124..e32d5bb608b1 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -158,7 +158,7 @@ class ModelRunnerOutput: # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. - sampled_token_ids: list[np.ndarray] + sampled_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] @@ -220,7 +220,7 @@ def make_empty_encoder_model_runner_output( req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)} # No tokens generated yet ⇒ one empty list per request - sampled_token_ids: list[list[int]] = [np.array([0]) for _ in req_ids] + sampled_token_ids: list[list[int]] = [[0] for _ in req_ids] # Pooler outputs are not available yet ⇒ use None placeholders pooler_output: list[torch.Tensor | None] = [None for _ in req_ids] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index f31a0cddda9a..926305d25f56 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -3,7 +3,6 @@ from dataclasses import replace -import numpy as np import torch import torch.nn as nn @@ -205,7 +204,7 @@ def _get_logprobs_tensors( def parse_output( output_token_ids: torch.Tensor, vocab_size: int, - ) -> list[np.ndarray]: + ) -> list[list[int]]: """Parse the output of the rejection sampler. Args: output_token_ids: The sampled token IDs in shape @@ -221,7 +220,10 @@ def parse_output( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( output_token_ids_np < vocab_size ) - return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)] + outputs = [ + row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) + ] + return outputs def apply_logits_processors( self, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ba37bc81607f..0df9cd3214e5 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -496,7 +496,7 @@ def propose( def prepare_next_token_ids_cpu( self, - sampled_token_ids: list[np.ndarray], + sampled_token_ids: list[list[int]], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, num_scheduled_tokens: dict[str, int], @@ -511,7 +511,7 @@ def prepare_next_token_ids_cpu( req_ids = gpu_input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): - if token_ids.shape[0] > 0: + if token_ids: # Common case. next_token_id = token_ids[-1] else: @@ -522,9 +522,10 @@ def prepare_next_token_ids_cpu( seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - return torch.tensor( + next_token_ids = torch.tensor( next_token_ids, dtype=torch.int32, device=self.input_ids.device ) + return next_token_ids def prepare_next_token_ids_padded( self, diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 378937dba988..e2f83cb24aa9 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -54,7 +54,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - [np.array([])] * 1024, + [[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -131,7 +131,7 @@ def batch_propose( def propose( self, - sampled_token_ids: list[np.ndarray], + sampled_token_ids: list[list[int]], req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -140,7 +140,7 @@ def propose( # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = sampled_ids.shape[0] + num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. continue diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index d76e0ffe778d..049e335db325 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import numpy as np - from vllm.config import VllmConfig from vllm.v1.worker.gpu_input_batch import InputBatch @@ -34,16 +32,16 @@ def __init__(self, vllm_config: VllmConfig): def propose( self, input_batch: InputBatch, - sampled_token_ids: list[np.ndarray], + sampled_token_ids: list[list[int]], ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ - draft_token_ids: list[np.ndarray] = [] + draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): - if sampled_ids.shape[0] == 0: + if not sampled_ids: # Skip speculative decoding for partial prefills. draft_token_ids.append([]) continue @@ -72,7 +70,7 @@ def propose( self.suffix_cache.start_request(req_id, prompt_token_ids) # Append the newly sampled ids to the suffix cache for this request. - self.suffix_cache.add_active_response(req_id, sampled_ids.tolist()) + self.suffix_cache.add_active_response(req_id, sampled_ids) # Suffix decoding only uses the most recent tokens up to max_tree_depth, so # we extract the pattern from the end of the input. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4b0a08ab57e1..5a8f208c4748 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -217,18 +217,10 @@ def get_output(self) -> ModelRunnerOutput: # Release the device tensors once the copy has completed. del self._logprobs_tensors del self._sampled_token_ids - max_gen_len = self.sampled_token_ids_cpu.shape[-1] - if max_gen_len == 1: - valid_sampled_token_ids: list[np.ndarray] = [ - row for row in self.sampled_token_ids_cpu.numpy() - ] - else: - valid_sampled_token_ids = RejectionSampler.parse_output( - self.sampled_token_ids_cpu, - self.vocab_size, - ) + + valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: - valid_sampled_token_ids[i] = np.array([]) + valid_sampled_token_ids[i].clear() output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids @@ -2453,7 +2445,7 @@ def _bookkeeping_sync( ) -> tuple[ dict[str, int], LogprobsLists | None, - list[np.ndarray], + list[list[int]], dict[str, LogprobsTensors | None], list[str], dict[str, int], @@ -2479,7 +2471,6 @@ def _bookkeeping_sync( num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] - valid_sampled_token_ids: list[np.ndarray] if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2494,7 +2485,7 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)] = np.array([]) + valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() @@ -2524,24 +2515,19 @@ def _bookkeeping_sync( [0] if spec_decode_metadata and logprobs_tensors else None ) for req_idx in range(num_sampled_tokens): - sampled_ids: np.ndarray | None if self.use_async_scheduling: - sampled_ids = ( - np.array([-1]) if req_idx not in invalid_req_indices_set else None - ) + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] - num_sampled_ids: int = ( - sampled_ids.shape[0] if sampled_ids is not None else 0 - ) + num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 if cu_num_accepted_tokens is not None: cu_num_accepted_tokens.append( cu_num_accepted_tokens[-1] + num_sampled_ids ) - if sampled_ids is None or num_sampled_ids == 0: + if not sampled_ids: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] @@ -2923,11 +2909,7 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) - self.input_batch.prev_sampled_token_ids = None - - def propose_draft_token_ids( - sampled_token_ids: torch.Tensor | list[np.ndarray], - ) -> None: + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -3096,14 +3078,14 @@ def _get_valid_sampled_token_count(self) -> list[int]: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: torch.Tensor | list[np.ndarray], + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> torch.Tensor | list[list[int]]: + ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) @@ -3135,7 +3117,7 @@ def propose_draft_token_ids( for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids ): - indices.append(offset + tokens.shape[0] - 1) + indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] @@ -5114,7 +5096,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754. # `tolist` would trigger a cuda wise stream sync, which @@ -5127,4 +5109,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() - return [row for row in pinned.numpy()] + return pinned.tolist() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 923c31c187f3..24d1cc2162f0 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1251,15 +1251,13 @@ def concat_lists(input_lists): max_gen_len = selected_token_ids.shape[-1] if max_gen_len == 1: - valid_sampled_token_ids: list[np.ndarray] = [ - row for row in selected_token_ids.numpy() - ] + valid_sampled_token_ids = selected_token_ids.tolist() # Mask out the sampled tokens that should not be sampled. # TODO: Keep in sync with gpu_model_runner.py, in particular # the "else" case here for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i] = np.array([]) + valid_sampled_token_ids[i].clear() # Append sampled tokens for i, req_state, seq_len in request_seq_lens: @@ -1272,7 +1270,7 @@ def concat_lists(input_lists): valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ - seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens) + seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) ] self.input_batch.num_tokens[:num_reqs] += gen_lens for i, req_state, seq_len in request_seq_lens: From fc89e61c152f1b2661e777ee246a2758b1fc9c8f Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 20 Nov 2025 14:06:38 -0800 Subject: [PATCH 2/5] fix merge conflict Signed-off-by: Jialin Ouyang --- vllm/v1/worker/gpu_model_runner.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a8f208c4748..85cdfcb23076 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -218,9 +218,16 @@ def get_output(self) -> ModelRunnerOutput: del self._logprobs_tensors del self._sampled_token_ids - valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + max_gen_len = self.sampled_token_ids_cpu.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + else: + valid_sampled_token_ids = RejectionSampler.parse_output( + self.sampled_token_ids_cpu, + self.vocab_size, + ) for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = [] output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids From b0c973fe0ff3a07b4523790cfb364a1983fd24f7 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 20 Nov 2025 14:08:04 -0800 Subject: [PATCH 3/5] simply the code Signed-off-by: Jialin Ouyang --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 85cdfcb23076..bcb3985689ec 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -227,7 +227,7 @@ def get_output(self) -> ModelRunnerOutput: self.vocab_size, ) for i in self._invalid_req_indices: - valid_sampled_token_ids[i] = [] + valid_sampled_token_ids[i].clear() output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids From aab2160a40e80b9b3bfefe4100f7d419dc2b0616 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 20 Nov 2025 14:09:21 -0800 Subject: [PATCH 4/5] revert blank line Signed-off-by: Jialin Ouyang --- vllm/v1/worker/gpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bcb3985689ec..0e7a0b00d690 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -217,7 +217,6 @@ def get_output(self) -> ModelRunnerOutput: # Release the device tensors once the copy has completed. del self._logprobs_tensors del self._sampled_token_ids - max_gen_len = self.sampled_token_ids_cpu.shape[-1] if max_gen_len == 1: valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() From f92d39f05488f90a6cf185b03ac47ce7d68acc59 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 20 Nov 2025 16:46:58 -0800 Subject: [PATCH 5/5] Fix async scheduling Signed-off-by: Jialin Ouyang --- vllm/utils/gc_utils.py | 13 ++++++++----- vllm/v1/worker/gpu_model_runner.py | 2 ++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 3436e450a269..c56b1794230e 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -53,6 +53,7 @@ def __init__(self, config: GCDebugConfig) -> None: self.config = config # Start time in micro second of this GC cycle self.start_time_ns: int = time.monotonic_ns() + self.num_objects: int = 0 # If config.top_objects is positive, # compute top collected objects by object types self.gc_top_collected_objects: str = "" @@ -68,19 +69,21 @@ def handle(self, phase: str, info: dict[str, int]) -> None: # Before GC started, record GC start time # and top collected objects self.start_time_ns = time.monotonic_ns() - if (top_objects := self.config.top_objects) > 0: - self.gc_top_collected_objects = _compute_top_gc_collected_objects( - gc.get_objects(generation), top_objects - ) + objects = gc.get_objects(generation) + self.num_objects = len(objects) + self.gc_top_collected_objects = _compute_top_gc_collected_objects( + objects, self.config.top_objects + ) elif phase == "stop": # After GC finished, Record GC elapsed time and # optionally top collected objects elpased_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6 logger.info( "GC took %.3fms to complete. " - "Collected %s objects in GC generation %d.%s", + "Collected %s objects (out of %d) in GC generation %d.%s", elpased_ms, str(info.get("collected", "?")), + self.num_objects, generation, ( f" Top collected objects: \n{self.gc_top_collected_objects}" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0e7a0b00d690..dd3a3ee174cb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2915,6 +2915,8 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self.input_batch.prev_sampled_token_ids = None + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("gpu_model_runner: draft"):