 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.platforms import current_platform
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
                                          initialize_ray_cluster, ray)
@@ -27,13 +28,17 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         self.vllm_config = vllm_config
         self.parallel_config = vllm_config.parallel_config
         self.model_config = vllm_config.model_config
+
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
 
         # Disable Ray usage stats collection.
         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
         if ray_usage != "1":
             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
 
+        self.device_str = "TPU" if current_platform.is_tpu() else "GPU"
+        self.use_dag = current_platform.is_cuda()
+
         initialize_ray_cluster(self.parallel_config)
         placement_group = self.parallel_config.placement_group
 
@@ -42,16 +47,16 @@ def __init__(self, vllm_config: VllmConfig) -> None:
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        # A list of workers to run a model.
-        self.workers: List[RayWorkerWrapper] = []
-        if self.parallel_config.ray_workers_use_nsight:
+        if (current_platform.is_cuda()
+                and self.parallel_config.ray_workers_use_nsight):
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                 ray_remote_kwargs)
 
         # Create the workers.
+        self.workers: List[RayWorkerWrapper] = []
         driver_ip = get_ip()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("GPU", 0):
+            if not bundle.get(self.device_str, 0):
                 # Skip bundles that don't have GPUs,
                 # as each worker needs one GPU.
                 continue
@@ -63,7 +68,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=1,
+                resources={self.device_str: 1},
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
@@ -279,11 +284,14 @@ def execute_model(
         self,
         scheduler_output,
     ) -> ModelRunnerOutput:
-        if self.forward_dag is None:
-            self.forward_dag = self._compiled_ray_dag()
-        # Only the first worker (with rank 0) returns the execution result.
-        # Others return None.
-        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
+        if self.use_dag:
+            if self.forward_dag is None:
+                self.forward_dag = self._compiled_ray_dag()
+
+            output = ray.get(self.forward_dag.execute(scheduler_output))[0]
+        else:
+            output = self._run_workers("execute_model", scheduler_output)[0]
+
         return output
 
     def profile(self, is_start=True):
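
For context, the `resources={self.device_str: 1}` change relies on Ray's custom-resource scheduling instead of the dedicated `num_gpus` argument. The standalone sketch below is not part of this diff; the worker class and the resource label are made up purely to illustrate the same pattern: an actor is placed against a named accelerator resource, then work is dispatched to it.

import ray


@ray.remote(num_cpus=0)
class EchoWorker:
    """Toy stand-in for RayWorkerWrapper, only to show resource-based placement."""

    def execute_model(self, scheduler_output):
        return scheduler_output


def main():
    # Hypothetical resource label; vLLM would pick "GPU" or "TPU" via current_platform.
    device_str = "MY_ACCELERATOR"

    # Advertise one unit of the custom resource on this local cluster so the
    # actor can be scheduled; a real GPU/TPU node reports its accelerators itself.
    ray.init(resources={device_str: 1})

    # Same pattern as the diff: request the accelerator by name instead of
    # hard-coding num_gpus=1.
    worker = EchoWorker.options(resources={device_str: 1}).remote()
    print(ray.get(worker.execute_model.remote("dummy scheduler output")))


if __name__ == "__main__":
    main()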