Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
f7a6356
replace narrow-usage set_cuda_visible_devices to general update_envir…
youkaichao Apr 12, 2024
bbdfc69
add warning when env is overwritten
youkaichao Apr 12, 2024
1e62614
use logger.warning
youkaichao Apr 12, 2024
37eb344
fix env copy
youkaichao Apr 12, 2024
6f64b48
avoid overwritten warning in ray
youkaichao Apr 12, 2024
0499106
fix lint
youkaichao Apr 12, 2024
d26672f
allow heterogeneous args in _run_workers; move update_environment_var…
youkaichao Apr 12, 2024
3a01337
unified init worker
youkaichao Apr 12, 2024
c85d040
fix recursion
youkaichao Apr 12, 2024
5e49b98
on the fly local rank calculation
youkaichao Apr 12, 2024
37ed6c9
post update kwargs
youkaichao Apr 12, 2024
b654ee2
add remote
youkaichao Apr 12, 2024
e11448e
fix update_environment_variables in ray worker
youkaichao Apr 12, 2024
97e6601
use staticmethod
youkaichao Apr 12, 2024
fd2cbe2
fix dummy worker local_rank
youkaichao Apr 12, 2024
a8d7504
fix dummy worker rank
youkaichao Apr 12, 2024
e659635
add WorkerWrapperBase
youkaichao Apr 12, 2024
778fb3f
add all_args to _run_workers
youkaichao Apr 12, 2024
d295107
refactor
youkaichao Apr 12, 2024
7ca22a4
fix dangling self
youkaichao Apr 12, 2024
5f6c8f3
fix execute_method in driver worker
youkaichao Apr 12, 2024
13de66e
withdraw changes in many workers
youkaichao Apr 12, 2024
32ef3bb
no need for init_worker in workerbase
youkaichao Apr 12, 2024
221f626
unify worker_node_and_gpu_ids
youkaichao Apr 12, 2024
0087773
use id rather than ip
youkaichao Apr 12, 2024
36a185e
unify init
youkaichao Apr 12, 2024
95ca917
fix lint
youkaichao Apr 12, 2024
ea5f2a5
finish todo
youkaichao Apr 12, 2024
d10ca88
rename to RayWorkerWrapper
youkaichao Apr 12, 2024
a164219
Merge remote-tracking branch 'origin' into update_env
youkaichao Apr 16, 2024
eb27be9
fix mypy typing
youkaichao Apr 16, 2024
74deb44
move init hf decision to each worker
youkaichao Apr 16, 2024
3bd2c98
use quotes to address white space in env var values
youkaichao Apr 16, 2024
21be004
add docstring
youkaichao Apr 16, 2024
1aee6a0
add config
youkaichao Apr 16, 2024
4337ac6
Merge remote-tracking branch 'origin' into update_env
youkaichao Apr 17, 2024
40d4560
fix _run_workers_async
youkaichao Apr 17, 2024
2509db4
move duplicate code to utils
youkaichao Apr 17, 2024
d1bda36
add docstring
youkaichao Apr 17, 2024
1e30d89
use docstring
youkaichao Apr 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import multiprocessing
import os

import pytest
import torch

from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
Expand All @@ -32,8 +32,7 @@ def update_env(fn):
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
import os
os.environ.update(env)
update_environment_variables(env)
fn()

return wrapper
Expand Down
36 changes: 6 additions & 30 deletions vllm/engine/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,26 @@

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

try:
import ray

class RayWorkerVllm:
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""

def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False

def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()

def __getattr__(self, name):
return getattr(self.worker, name)

def execute_method(self, method, *args, **kwargs):
try:
executor = getattr(self, method)
return executor(*args, **kwargs)
except Exception as e:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e

def get_node_ip(self) -> str:
return get_ip()

Expand All @@ -52,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)

def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
Expand All @@ -71,7 +47,7 @@ def execute_model_compiled_dag_remote(self, ignored):
"For distributed inference, please install Ray with "
"`pip install ray`.")
ray = None
RayWorkerVllm = None
RayWorkerWrapper = None


def initialize_ray_cluster(
Expand Down
137 changes: 71 additions & 66 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import copy
import os
import pickle
from collections import defaultdict
Expand All @@ -8,13 +7,13 @@
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices)
make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down Expand Up @@ -79,9 +78,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerVllm = None
self.driver_dummy_worker: RayWorkerWrapper = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []
self.workers: List[RayWorkerWrapper] = []

# Create the workers.
driver_ip = get_ip()
Expand All @@ -98,13 +97,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
)(RayWorkerWrapper).remote(
init_cached_hf_modules=self.model_config.trust_remote_code,
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
init_cached_hf_modules=self.model_config.trust_remote_code,
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)
else:
# Else, added to the list of workers.
self.workers.append(worker)
Expand All @@ -116,79 +124,55 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"GPU node.")

# Get the set of GPU IDs used on each node.
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)

node_workers = defaultdict(list)
node_gpus = defaultdict(list)

node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)

# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id]))
}])
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)

distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())

# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker

model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
vision_language_config = copy.deepcopy(self.vision_language_config)

# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids),
start=1,
):
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs

init_worker_all_kwargs = []

# Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote(
lambda rank=rank, local_rank=local_rank: Worker(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
init_worker_all_kwargs.append(
collect_arg_helper_func(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
))

# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=driver_local_rank,
rank=driver_rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
self._run_workers(
Expand Down Expand Up @@ -278,11 +262,26 @@ def _run_workers(
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[List[Any]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
all_args and all_kwargs are used to pass heterogeneous arguments,
i.e. different arguments for each worker.
"""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

if all_args is None:
all_args = [driver_args] + [args] * len(self.workers)
if all_kwargs is None:
all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)

if max_concurrent_workers:
raise NotImplementedError(
Expand All @@ -295,8 +294,10 @@ def _run_workers(
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_args[1:], all_kwargs[1:])
]

if driver_args is None:
Expand All @@ -305,9 +306,13 @@ def _run_workers(
driver_kwargs = kwargs

# Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker,
method)(*driver_args, **driver_kwargs)

if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *all_args[0], **all_kwargs[0])
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *all_args[0], **all_kwargs[0]))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
Expand Down
8 changes: 6 additions & 2 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,12 @@ def get_open_port() -> int:
return s.getsockname()[1]


def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
if k in os.environ:
logger.warning(f"Overwriting environment variable {k} "
f"from {os.environ[k]} to {v}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do {os.environ[k]=} and {v=} so whitespace is obvious

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add single quotes to make it clear. {os.environ[k]=} will print os.environ[k] to users, which is confusing and not informative.

os.environ[k] = v


def chunk_list(lst, chunk_size):
Expand Down
Loading