1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -935,6 +935,7 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- (cd .. && VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=2 -m tests.distributed.test_ngram | grep 'successfully passed!')
- pytest -v -s distributed/test_sequence_parallel.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
25 changes: 25 additions & 0 deletions tests/distributed/test_ngram.py
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import torch

from tests.utils import init_test_distributed_environment
from tests.v1.spec_decode.test_ngram import test_ngram_proposer
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

if __name__ == "__main__":
pp_size = 1
local_rank = int(os.environ['LOCAL_RANK'])
tp_size = int(os.environ['WORLD_SIZE'])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size=tp_size,
pp_size=pp_size,
rank=local_rank,
distributed_init_port=None,
local_rank=local_rank)

test_ngram_proposer()
cleanup_dist_env_and_memory()
print("test_ngram_distributed() successfully passed!")
7 changes: 5 additions & 2 deletions tests/utils.py
@@ -671,10 +671,13 @@ def init_test_distributed_environment(
tp_size: int,
pp_size: int,
rank: int,
distributed_init_port: str,
distributed_init_port: Optional[str] = None,
local_rank: int = -1,
) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if distributed_init_port:
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
else:
distributed_init_method = "env://"
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
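
With distributed_init_port now optional, callers can either pass an explicit port (the existing multi-process pytest path) or omit it and rely on a launcher such as torchrun. A sketch of the two call forms; the port value and sizes are illustrative, and the two calls are alternatives rather than something to run back to back:

from tests.utils import init_test_distributed_environment

# Explicit TCP rendezvous, as existing callers do (port value is illustrative).
init_test_distributed_environment(tp_size=2, pp_size=1, rank=0,
                                  distributed_init_port="29500", local_rank=0)

# Launcher-provided rendezvous: omit the port and "env://" is used instead,
# so MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE must already be set (e.g. by torchrun).
init_test_distributed_environment(tp_size=2, pp_size=1, rank=0, local_rank=0)
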
1 change: 0 additions & 1 deletion tests/v1/e2e/test_spec_decode.py
@@ -80,7 +80,6 @@ def model_name():


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
):
95 changes: 55 additions & 40 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -6,6 +6,7 @@
from numba import get_num_threads, jit, njit, prange, set_num_threads

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group


class NgramProposer:
@@ -35,7 +36,6 @@ def __init__(self, vllm_config: VllmConfig):
# Threshold of total number of tokens in the batch to enable
# multi-threading in numba batch propose.
self.num_tokens_threshold = 8192
tp_size = vllm_config.parallel_config.tensor_parallel_size
cpu_count = os.cpu_count()
# Max number of threads for numba parallel processing.
if cpu_count:
@@ -44,15 +44,18 @@ def __init__(self, vllm_config: VllmConfig):
# Cap the number of threads to 8 to avoid using too many threads
# since other components like frontend (incl tokenization)
# and Structured Outputs also use multiple threads.
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
# when TP parallelization for ngram is implemented.
self.num_numba_thread_available = min(1, (cpu_count // 2))
# Divide by tp_size to ensure each tensor parallel rank
# has some threads since all ranks will run this.
self.num_numba_thread_available //= tp_size
self.num_numba_thread_available = min(8, (cpu_count // 2))
else:
self.num_numba_thread_available = 1

# Tensor parallel group for TP parallel ngram.
# Rank 0 will run the ngram proposer and broadcast the results
# to other ranks. This is done so that all CPU threads are available
# to rank 0 to run the ngram proposer, instead of dividing CPU threads
# among all TP ranks.
self.tp_group = get_tp_group()
self.leader_rank = 0

# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
@@ -85,39 +88,51 @@ def batch_propose(
"""
draft_token_ids: list[list[int]] = []

# Only run batch propose if there are requests needing ngram proposals.
# avoid calling numba function with empty list which causes error
# ValueError: cannot compute fingerprint of empty list
if num_ngram_requests := len(valid_ngram_requests):
original_num_numba_threads = get_num_threads()
# Ensure we use at least one thread.
# If total tokens is small, using multiple threads
# may slow down due to overhead.
total_tokens = np.sum(num_tokens_no_spec)
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(
1, min(self.num_numba_thread_available,
num_ngram_requests))
set_num_threads(final_num_threads)
else:
set_num_threads(1)

batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
token_ids_cpu, self.min_n, self.max_n,
self.max_model_len, self.k,
self.valid_ngram_draft,
self.valid_ngram_num_drafts)

# Restore original number of threads.
set_num_threads(original_num_numba_threads)

for i in range(num_requests):
if i in valid_ngram_requests and \
self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(self.valid_ngram_draft[
i, :self.valid_ngram_num_drafts[i]].tolist())
else:
draft_token_ids.append([])
# Only rank 0 will run the ngram proposer
# and broadcast the results to other ranks.
if self.tp_group is None or self.tp_group.rank == self.leader_rank:
# Only run batch propose if there are requests needing
# ngram proposals. Avoid calling numba function with empty
# list which causes error:
# ValueError: cannot compute fingerprint of empty list
if num_ngram_requests := len(valid_ngram_requests):
original_num_numba_threads = get_num_threads()
# Ensure we use at least one thread.
# If the total token count is small, using multiple
# threads may slow things down due to overhead.
total_tokens = np.sum(num_tokens_no_spec)
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(
1,
min(self.num_numba_thread_available,
num_ngram_requests))
set_num_threads(final_num_threads)
else:
set_num_threads(1)

batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
token_ids_cpu, self.min_n, self.max_n,
self.max_model_len, self.k,
self.valid_ngram_draft,
self.valid_ngram_num_drafts)

# Restore original number of threads.
set_num_threads(original_num_numba_threads)

for i in range(num_requests):
if i in valid_ngram_requests and \
self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(self.valid_ngram_draft[
i, :self.valid_ngram_num_drafts[i]].tolist())
else:
draft_token_ids.append([])
else:
draft_token_ids = [[] for _ in range(num_requests)]

# Broadcast from rank 0 to other ranks using GroupCoordinator
if self.tp_group is not None and self.tp_group.world_size > 1:
draft_token_ids = self.tp_group.broadcast_object_list(
draft_token_ids, src=self.leader_rank)

return draft_token_ids

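
A hypothetical walk-through of the thread-count selection above, with made-up values (the real code reads os.cpu_count() and the current numba thread pool):

# Illustrative numbers only; mirrors the selection logic in __init__/batch_propose.
cpu_count = 32
num_numba_thread_available = min(8, cpu_count // 2)   # capped at 8

num_tokens_threshold = 8192
total_tokens = 12_000                                  # batch is large enough
num_ngram_requests = 5                                 # requests needing ngram drafts

if total_tokens >= num_tokens_threshold:
    # -> max(1, min(8, 5)) = 5 numba threads for this batch
    final_num_threads = max(1, min(num_numba_thread_available, num_ngram_requests))
else:
    final_num_threads = 1                              # small batches stay single-threaded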
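
And a standalone sketch of the leader-compute-and-broadcast pattern used above, written against plain torch.distributed rather than vLLM's GroupCoordinator (so the names here are illustrative and the broadcast call has a slightly different, in-place signature): rank 0 does the CPU-bound work and every rank ends up with identical draft lists.

# Minimal sketch: the leader rank computes, then broadcasts Python objects to the group.
# Run with e.g.: torchrun --nproc-per-node=2 this_script.py
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend="gloo", init_method="env://")
    rank = dist.get_rank()

    # Only the leader computes the (stand-in) ngram drafts.
    local_result = [[1, 2, 3], [4, 5]] if rank == 0 else None

    # torch.distributed.broadcast_object_list mutates the list in place,
    # so wrap the payload in a single-element container.
    payload = [local_result]
    dist.broadcast_object_list(payload, src=0)
    drafts = payload[0]

    assert drafts == [[1, 2, 3], [4, 5]]   # identical on every rank
    dist.destroy_process_group()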