Your current environment
Any distributed inference tasks with ray currently suffer from this issue.
🐛 Describe the bug
Basic background of ray
ray provides an easy-to-use asynchronous execution framework:
```python
def f():
    print(1)

import ray
ray.init()
marked_function = ray.remote(f)  # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote()  # schedule a worker to asynchronously execute the function, immediately return a handle
result = ray.get(handle)  # synchronously wait for the worker to finish and return the result
```
The way it deals with exceptions is noteworthy; see the comments below:
```python
def f():
    print(1)
    raise RuntimeError("test")
    # the following line will not be executed
    print(2)

import ray
ray.init()
marked_function = ray.remote(f)  # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote()  # schedule a worker to asynchronously execute the function, immediately return a handle
# ... do other work in the meantime ...
# the main process will not be notified if the worker fails
# only when we call `ray.get` will we be notified of the error
result = ray.get(handle)  # raises the error that was thrown in the worker, wrapped in a RayTaskError
```
The deadlock in distributed inference
The deadlock happens during initialization of distributed inference, i.e. when creating the process group for collaboration.
A minimal reproducible example looks like this:
```python
import torch
import torch.distributed as dist

def f(rank, world_size, distributed_init_method):
    # raise RuntimeError  # uncomment this line to see a deadlock
    dist.init_process_group(
        backend="gloo",
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    tensor = torch.zeros(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank} has data {tensor.item()}")

import ray
ray.init()
marked_function = ray.remote(f)

distributed_init_method = "tcp://127.0.0.1:29500"
world_size = 2

# start the first process
handle = marked_function.remote(rank=0, world_size=world_size, distributed_init_method=distributed_init_method)

# the main process is the second process
# wait for the first process to join here to initialize the process group for the distributed environment
dist.init_process_group(backend="gloo", init_method=distributed_init_method, world_size=world_size, rank=1)

# two processes are ready to communicate
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"Rank 1 has data {tensor.item()}")

result = ray.get(handle)
```
Normally it works, with the following output:
```
2024-03-17 10:24:23,293 INFO worker.py:1724 -- Started a local Ray instance.
Rank 1 has data 1.0
(f pid=14616) Rank 0 has data 1.0
```
However, if the function f throws an exception before calling dist.init_process_group, the worker is left in an error state, waiting for the main process to call ray.get to surface the error; meanwhile, the main process is stuck in dist.init_process_group, waiting for the worker process to join the process group. Together they cause a deadlock.
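For the minimal example above, one possible mitigation (only a sketch, not vLLM code; the timeout values are arbitrary illustrations) is to bound the rendezvous wait and then surface the worker's error via ray.get:

```python
import datetime

import ray
import torch.distributed as dist

# Continuing the minimal example above: `handle`, `world_size`, and
# `distributed_init_method` are assumed to be defined as before.
try:
    # Bound the wait instead of blocking forever. Whether the timeout covers
    # every phase of rendezvous depends on the backend and init method, but it
    # turns an indefinite hang into an error we can act on.
    dist.init_process_group(
        backend="gloo",
        init_method=distributed_init_method,
        world_size=world_size,
        rank=1,
        timeout=datetime.timedelta(seconds=60),
    )
except Exception:
    # If the worker already failed, ray.get re-raises its exception (wrapped
    # in a RayTaskError), so the root cause is visible instead of a bare
    # rendezvous timeout; otherwise this gives up after 10 seconds.
    ray.get(handle, timeout=10)
    raise
```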
How this is related to vLLM
vLLM uses ray for distributed inference, and the core code is attached below:
vllm/vllm/executor/ray_gpu_executor.py, lines 299 to 351 at commit 6b78837:
```python
def _run_workers(
    self,
    method: str,
    *args,
    driver_args: Optional[List[Any]] = None,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    max_concurrent_workers: Optional[int] = None,
    use_ray_compiled_dag: bool = False,
    **kwargs,
) -> Any:
    """Runs the given method on all workers."""
    if max_concurrent_workers:
        raise NotImplementedError(
            "max_concurrent_workers is not supported yet.")

    if use_ray_compiled_dag:
        # Right now, compiled DAG can only accept a single
        # input. TODO(sang): Fix it.
        output_channels = self.forward_dag.execute(1)
    else:
        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]

    if driver_args is None:
        driver_args = args
    if driver_kwargs is None:
        driver_kwargs = kwargs

    # Start the driver worker after all the ray workers.
    driver_worker_output = getattr(self.driver_worker,
                                   method)(*driver_args, **driver_kwargs)

    # Get the results of the ray workers.
    if self.workers:
        if use_ray_compiled_dag:
            try:
                ray_worker_outputs = [
                    pickle.loads(chan.begin_read())
                    for chan in output_channels
                ]
            finally:
                # Has to call end_read in order to reuse the DAG.
                for chan in output_channels:
                    chan.end_read()
        else:
            ray_worker_outputs = ray.get(ray_worker_outputs)

    return [driver_worker_output] + ray_worker_outputs
```
When calling init_model, both the ray workers and the main process (the driver worker) will reach the following function:
Lines 71 to 96 at commit abfc4f3:
```python
def init_model(self, cupy_port: Optional[int] = None) -> None:
    if self.device_config.device.type == "cuda":
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)
        torch.cuda.empty_cache()
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
    else:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")
    # Initialize the distributed environment.
    init_distributed_environment(self.parallel_config, self.rank,
                                 cupy_port, self.distributed_init_method)
    # Initialize the model.
    set_random_seed(self.model_config.seed)
```
And essentially we are back to the minimal reproducible example mentioned before: any exception raised before init_distributed_environment can cause a deadlock.
In my case, my GPU driver has a problem, so torch.cuda.set_device raises an exception, which causes the deadlock.
Solution to be discussed
Any suggestions to fix this are welcome.
Might be related: #2466.
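One direction that might help (a rough sketch under stated assumptions, not a tested patch; start_worker_watchdog and its parameters are hypothetical names): since a Ray task that raises produces a "ready" object ref whose ray.get re-raises the error, the driver could spawn a watchdog thread that polls the outstanding worker refs while the main thread blocks in init_distributed_environment, and fail fast instead of hanging:

```python
import os
import threading
import traceback

import ray

def start_worker_watchdog(worker_refs, poll_interval: float = 5.0):
    """Poll worker object refs in the background and abort the driver if any
    worker task has already failed (hypothetical helper, illustration only)."""

    def _watch():
        remaining = list(worker_refs)
        while remaining:
            # ray.wait returns refs of finished tasks, including tasks that
            # finished by raising an exception.
            ready, remaining = ray.wait(remaining, num_returns=1,
                                        timeout=poll_interval)
            for ref in ready:
                try:
                    ray.get(ref)  # re-raises the worker's exception, if any
                except Exception:
                    traceback.print_exc()
                    # A worker died before the collective setup finished;
                    # crash loudly rather than deadlocking in rendezvous.
                    os._exit(1)

    thread = threading.Thread(target=_watch, daemon=True)
    thread.start()
    return thread
```

In _run_workers, such a watchdog could be started right after the worker.execute_method.remote(...) calls and before the driver runs the method itself; a timeout on the process-group initialization (as in the earlier sketch) would be a complementary safeguard.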