
Commit d25ec0e

Add TP support
1 parent b65ed98 commit d25ec0e

4 files changed: +21 -12 lines

examples/offline_inference/offline_inference.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=256, max_num_seqs=16, tensor_parallel_size=4)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
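
In context, this change makes the offline example shard the 1.5B model across four chips while halving the context length. A minimal end-to-end sketch of running the new configuration (assuming a single host with four TPU chips visible to vLLM; the prompts are illustrative):

    from vllm import LLM, SamplingParams

    prompts = ["Hello, my name is", "The capital of France is"]
    sampling_params = SamplingParams()  # default sampling, as in the example

    # tensor_parallel_size=4 splits each layer's weight matrices across
    # 4 chips, so every forward pass runs collectively on all of them.
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
              max_model_len=256,
              max_num_seqs=16,
              tensor_parallel_size=4)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(output.prompt, "->", output.outputs[0].text)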

requirements-tpu.txt

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ setuptools-scm>=8
 wheel
 jinja2
 ray[default]
+ray[adag] # TODO: Remove this
 
 # Install torch_xla
 --pre
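
The ray[adag] extra pulls in the dependencies for Ray's experimental compiled-graph ("accelerated DAG") API, which backs the executor's _compiled_ray_dag() path on CUDA. A hedged sketch of how such a graph is typically built and compiled (the InputNode/MultiOutputNode/experimental_compile names follow Ray's experimental interface at the time and may change, hence the TODO):

    import ray
    from ray.dag import InputNode, MultiOutputNode

    @ray.remote
    class EchoWorker:
        # Stand-in actor; vLLM binds RayWorkerWrapper.execute_model instead.
        def execute_model(self, scheduler_output):
            return f"ran {scheduler_output}"

    workers = [EchoWorker.remote() for _ in range(2)]

    # Bind every worker's execute_model to one shared input node, then
    # compile the whole graph once so each step is a single dispatch.
    with InputNode() as scheduler_output:
        nodes = [w.execute_model.bind(scheduler_output) for w in workers]
    forward_dag = MultiOutputNode(nodes).experimental_compile()

    print(ray.get(forward_dag.execute("step-0")))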

vllm/platforms/tpu.py

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TRUWorker"
+                parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
                 if scheduler_config.is_multi_step:
                     parallel_config.worker_cls = \
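
This hunk fixes a typo in the V1 worker class path ("TRUWorker" -> "TPUWorker"). The string matters because worker_cls is resolved dynamically from the dotted path when workers are created, so a misspelling only surfaces at import time. A minimal sketch of that kind of resolution (resolve_worker_cls is a hypothetical helper, not the actual vLLM loader):

    import importlib

    def resolve_worker_cls(dotted_path: str):
        # "vllm.v1.worker.tpu_worker.TPUWorker" -> module path + class name.
        module_name, class_name = dotted_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        # A typo such as "TRUWorker" raises AttributeError here.
        return getattr(module, class_name)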

vllm/v1/executor/ray_executor.py

Lines changed: 18 additions & 10 deletions

@@ -7,6 +7,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.platforms import current_platform
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
                                         initialize_ray_cluster, ray)

@@ -27,13 +28,17 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         self.vllm_config = vllm_config
         self.parallel_config = vllm_config.parallel_config
         self.model_config = vllm_config.model_config
+
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
 
         # Disable Ray usage stats collection.
         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
         if ray_usage != "1":
             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
 
+        self.device_str = "TPU" if current_platform.is_tpu() else "GPU"
+        self.use_dag = current_platform.is_cuda()
+
         initialize_ray_cluster(self.parallel_config)
         placement_group = self.parallel_config.placement_group

@@ -42,16 +47,16 @@ def __init__(self, vllm_config: VllmConfig) -> None:
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        # A list of workers to run a model.
-        self.workers: List[RayWorkerWrapper] = []
-        if self.parallel_config.ray_workers_use_nsight:
+        if (current_platform.is_cuda()
+                and self.parallel_config.ray_workers_use_nsight):
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                 ray_remote_kwargs)
 
         # Create the workers.
+        self.workers: List[RayWorkerWrapper] = []
         driver_ip = get_ip()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("GPU", 0):
+            if not bundle.get(self.device_str, 0):
                 # Skip bundles that don't have GPUs,
                 # as each worker needs one GPU.
                 continue

@@ -63,7 +68,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=1,
+                resources={self.device_str: 1},
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)

@@ -279,11 +284,14 @@ def execute_model(
         self,
         scheduler_output,
     ) -> ModelRunnerOutput:
-        if self.forward_dag is None:
-            self.forward_dag = self._compiled_ray_dag()
-        # Only the first worker (with rank 0) returns the execution result.
-        # Others return None.
-        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
+        if self.use_dag:
+            if self.forward_dag is None:
+                self.forward_dag = self._compiled_ray_dag()
+
+            output = ray.get(self.forward_dag.execute(scheduler_output))[0]
+        else:
+            output = self._run_workers("execute_model", scheduler_output)[0]
+
         return output
 
     def profile(self, is_start=True):
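
The executor change keeps the lazily compiled Ray DAG on CUDA and falls back to plain per-step _run_workers RPCs on TPU, where the compiled-graph path is not used. The control flow, reduced to a self-contained toy (ToyExecutor and its callables are illustrative stand-ins, not the vLLM classes):

    from typing import Callable, List, Optional

    class ToyExecutor:
        def __init__(self, workers: List[Callable], use_dag: bool) -> None:
            self.workers = workers      # rank 0 first, like vLLM's self.workers
            self.use_dag = use_dag      # True on CUDA, False on TPU
            self._dag: Optional[Callable] = None

        def _compiled_dag(self) -> Callable:
            # Stand-in for ray.dag compilation: fuse all worker calls
            # into a single callable that is built once and then reused.
            def dag(inp):
                return [w(inp) for w in self.workers]
            return dag

        def execute_model(self, scheduler_output):
            if self.use_dag:
                if self._dag is None:
                    self._dag = self._compiled_dag()  # compile once, reuse
                results = self._dag(scheduler_output)
            else:
                # TPU path: invoke every worker directly on each step.
                results = [w(scheduler_output) for w in self.workers]
            # Only the first worker (rank 0) returns the execution result;
            # the others return None, so take index 0.
            return results[0]

    workers = [lambda s: f"rank 0 ran {s}"] + [(lambda s: None)] * 3
    print(ToyExecutor(workers, use_dag=False).execute_model("step-1"))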
