diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index b3d10f75ab50..6d275a7c60f7 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -935,6 +935,7 @@ steps:
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
+  - (cd .. && VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=2 -m tests.distributed.test_ngram | grep 'successfully passed!')
   - pytest -v -s distributed/test_sequence_parallel.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
   - pytest -v -s v1/worker/test_worker_memory_snapshot.py
diff --git a/tests/distributed/test_ngram.py b/tests/distributed/test_ngram.py
new file mode 100644
index 000000000000..afca71b43462
--- /dev/null
+++ b/tests/distributed/test_ngram.py
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
+import torch
+
+from tests.utils import init_test_distributed_environment
+from tests.v1.spec_decode.test_ngram import test_ngram_proposer
+from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+
+if __name__ == "__main__":
+    pp_size = 1
+    local_rank = int(os.environ['LOCAL_RANK'])
+    tp_size = int(os.environ['WORLD_SIZE'])
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size=tp_size,
+                                      pp_size=pp_size,
+                                      rank=local_rank,
+                                      distributed_init_port=None,
+                                      local_rank=local_rank)
+
+    test_ngram_proposer()
+    cleanup_dist_env_and_memory()
+    print("test_ngram_distributed() successfully passed!")
diff --git a/tests/utils.py b/tests/utils.py
index ffdc0f732543..ae01bdeaa30d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -671,10 +671,13 @@ def init_test_distributed_environment(
     tp_size: int,
     pp_size: int,
     rank: int,
-    distributed_init_port: str,
+    distributed_init_port: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
-    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    if distributed_init_port:
+        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    else:
+        distributed_init_method = "env://"
     init_distributed_environment(
         world_size=pp_size * tp_size,
         rank=rank,
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 8f048775352e..9b2d9e27d0a3 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -80,7 +80,6 @@ def model_name():
 
 
 def test_ngram_correctness(
-    monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_name: str,
 ):
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index aed050a3540c..8d7602b3db61 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -6,6 +6,7 @@
 from numba import get_num_threads, jit, njit, prange, set_num_threads
 
 from vllm.config import VllmConfig
+from vllm.distributed import get_tp_group
 
 
 class NgramProposer:
@@ -35,7 +36,6 @@ def __init__(self, vllm_config: VllmConfig):
         # Threshold of total number of tokens in the batch to enable
         # multi-threading in numba batch propose.
         self.num_tokens_threshold = 8192
-        tp_size = vllm_config.parallel_config.tensor_parallel_size
         cpu_count = os.cpu_count()
         # Max number of threads for numba parallel processing.
         if cpu_count:
@@ -44,15 +44,18 @@ def __init__(self, vllm_config: VllmConfig):
             # Cap the number of threads to 8 to avoid using too many threads
             # since other components like frontend (incl tokenization)
             # and Structured Outputs also use multiple threads.
-            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
-            # when TP parallelization for ngram is implemented.
-            self.num_numba_thread_available = min(1, (cpu_count // 2))
-            # Divide by tp_size to ensure each tensor parallel rank
-            # has some threads since all ranks will run this.
-            self.num_numba_thread_available //= tp_size
+            self.num_numba_thread_available = min(8, (cpu_count // 2))
         else:
             self.num_numba_thread_available = 1
 
+        # Tensor parallel group for TP parallel ngram.
+        # Rank 0 will run the ngram proposer and broadcast the results
+        # to other ranks. This is done so that all CPU threads are available
+        # to rank 0 to run the ngram proposer, instead of dividing CPU threads
+        # among all TP ranks.
+        self.tp_group = get_tp_group()
+        self.leader_rank = 0
+
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
@@ -85,39 +88,51 @@ def batch_propose(
         """
         draft_token_ids: list[list[int]] = []
 
-        # Only run batch propose if there are requests needing ngram proposals.
-        # avoid calling numba function with empty list which causes error
-        # ValueError: cannot compute fingerprint of empty list
-        if num_ngram_requests := len(valid_ngram_requests):
-            original_num_numba_threads = get_num_threads()
-            # Ensure we use at least one thread.
-            # If total tokens is small, using multiple threads
-            # may slow down due to overhead.
-            total_tokens = np.sum(num_tokens_no_spec)
-            if total_tokens >= self.num_tokens_threshold:
-                final_num_threads = max(
-                    1, min(self.num_numba_thread_available,
-                           num_ngram_requests))
-                set_num_threads(final_num_threads)
-            else:
-                set_num_threads(1)
-
-            batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
-                                token_ids_cpu, self.min_n, self.max_n,
-                                self.max_model_len, self.k,
-                                self.valid_ngram_draft,
-                                self.valid_ngram_num_drafts)
-
-            # Restore original number of threads.
-            set_num_threads(original_num_numba_threads)
-
-        for i in range(num_requests):
-            if i in valid_ngram_requests and \
-                self.valid_ngram_num_drafts[i] > 0:
-                draft_token_ids.append(self.valid_ngram_draft[
-                    i, :self.valid_ngram_num_drafts[i]].tolist())
-            else:
-                draft_token_ids.append([])
+        # Only rank 0 will run the ngram proposer
+        # and broadcast the results to other ranks.
+        if self.tp_group is None or self.tp_group.rank == self.leader_rank:
+            # Only run batch propose if there are requests needing
+            # ngram proposals. Avoid calling numba function with empty
+            # list which causes error:
+            # ValueError: cannot compute fingerprint of empty list
+            if num_ngram_requests := len(valid_ngram_requests):
+                original_num_numba_threads = get_num_threads()
+                # Ensure we use at least one thread.
+                # If total tokens is small, using multiple threads
+                # may slow down due to overhead.
+                total_tokens = np.sum(num_tokens_no_spec)
+                if total_tokens >= self.num_tokens_threshold:
+                    final_num_threads = max(
+                        1,
+                        min(self.num_numba_thread_available,
+                            num_ngram_requests))
+                    set_num_threads(final_num_threads)
+                else:
+                    set_num_threads(1)
+
+                batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
+                                    token_ids_cpu, self.min_n, self.max_n,
+                                    self.max_model_len, self.k,
+                                    self.valid_ngram_draft,
+                                    self.valid_ngram_num_drafts)
+
+                # Restore original number of threads.
+                set_num_threads(original_num_numba_threads)
+
+            for i in range(num_requests):
+                if i in valid_ngram_requests and \
+                    self.valid_ngram_num_drafts[i] > 0:
+                    draft_token_ids.append(self.valid_ngram_draft[
+                        i, :self.valid_ngram_num_drafts[i]].tolist())
+                else:
+                    draft_token_ids.append([])
+        else:
+            draft_token_ids = [[] for _ in range(num_requests)]
+
+        # Broadcast from rank 0 to other ranks using GroupCoordinator
+        if self.tp_group is not None and self.tp_group.world_size > 1:
+            draft_token_ids = self.tp_group.broadcast_object_list(
+                draft_token_ids, src=self.leader_rank)
 
         return draft_token_ids
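
A minimal, standalone sketch of the leader-rank pattern the `ngram_proposer.py` hunk introduces: rank 0 computes the draft tokens and every other rank receives them via an object broadcast. It uses plain `torch.distributed` rather than vLLM's `GroupCoordinator`, and `propose_on_leader` is a hypothetical stand-in for the numba-backed `batch_propose_numba` step, so treat it as an illustration of the pattern, not the actual implementation.

```python
import torch.distributed as dist


def propose_on_leader(token_lists: list[list[int]]) -> list[list[int]]:
    # Hypothetical stand-in for the numba-backed batch_propose_numba call.
    return [tokens[-2:] for tokens in token_lists]


if __name__ == "__main__":
    # Launched via torchrun, so init from the env:// variables it sets.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    leader_rank = 0

    requests = [[1, 2, 3, 1, 2], [4, 5, 6, 7]]

    if rank == leader_rank:
        # Only the leader spends CPU threads on proposing.
        draft_token_ids = propose_on_leader(requests)
    else:
        # Followers allocate placeholders with the same number of entries.
        draft_token_ids = [[] for _ in requests]

    # broadcast_object_list overwrites the placeholder entries in place
    # on every non-source rank with the leader's results.
    dist.broadcast_object_list(draft_token_ids, src=leader_rank)
    print(f"rank {rank}: {draft_token_ids}")

    dist.destroy_process_group()
```

Run with `torchrun --nproc-per-node=2 sketch.py`; both ranks print the same draft lists, mirroring how every TP rank ends up with identical `draft_token_ids` after the broadcast in `batch_propose`.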