30 changes: 25 additions & 5 deletions Dockerfile.rocm
@@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -70,16 +73,33 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
&& cd ..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps

# Work around an odd state for numpy 1.20.3 where there is no METADATA etc., only an extra LICENSES_bundled.txt.
# Manually remove it so that the later numpy upgrade steps can continue.
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
Collaborator:

Do we need MPI for CuPy? For NVIDIA GPUs, we use TCP store instead of MPI.

Collaborator:

vLLM uses a hack that terminates the TCP store used by cupy right after the cupy nccl backend is initialized:

# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
    _NCCL_BACKEND._store.stop()

I did this because I found that otherwise the worker processes hang when they are terminated. If ROCm cupy uses MPI, then vLLM might need a similar hack to prevent deadlocks at termination time.
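
For reference only: if an MPI-backed CuPy did turn out to need an analogous guard, it would presumably be a few lines of explicit teardown rather than the TCP-store stop() above. This is a purely hypothetical sketch using plain mpi4py, not vLLM or CuPy internals; the discussion below suggests it may not be needed at all.

# Hypothetical analogue of the TCP-store stop() hack for an MPI-based backend.
# Not part of vLLM; mpi4py only. Whether any explicit teardown is needed is
# exactly what the testing discussed below was meant to establish.
from mpi4py import MPI

def shutdown_mpi_backend() -> None:
    # mpi4py normally finalizes MPI automatically at interpreter exit; an
    # explicit barrier-then-finalize would only help if workers were observed
    # to hang on termination, mirroring _NCCL_BACKEND._store.stop() above.
    if MPI.Is_initialized() and not MPI.Is_finalized():
        MPI.COMM_WORLD.Barrier()
        MPI.Finalize()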

Contributor:

Hi @WoosukKwon, we prefer to use MPI for ROCm CuPy. Is there any specific reason to choose the TCP store instead of MPI on the vLLM side?

@hongxiayang (Collaborator, Author), Mar 4, 2024:

> vLLM uses a hack that terminates the TCP store used by cupy right after the cupy nccl backend is initialized:
>
>     # Stop the TCP store to prevent the deadlock issues at termination time.
>     # FIXME(woosuk): This is hacky. Find a more robust solution.
>     if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
>         _NCCL_BACKEND._store.stop()
>
> I did this because I found that otherwise the worker processes hang when they are terminated. If ROCm cupy uses MPI, then vLLM might need a similar hack to prevent deadlocks at termination time.

@WoosukKwon (1) Can you give more context about this deadlock issue when the processes are terminated? We will need to test it to see whether we hit the deadlock, so that we can determine whether we need this "stop" hack, and to verify the patch afterwards if it is needed. (2) What is the reason that the TCP store is used instead of MPI? Did you observe any performance issue with MPI? As @lcskrishna mentioned, we have tested the MPI path more.

Collaborator:

Hi @lcskrishna @hongxiayang, we used the TCP store just to reduce dependencies. We can use MPI if you prefer it over the TCP store. Although I haven't tested MPI + CuPy on NVIDIA GPUs, I believe it works.

The deadlock issue is that, when the main process is terminated, the process hangs waiting for other processes spawned by cupy TCP store. The _NCCL_BACKEND._store.stop() hack is to avoid this.

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 357, in _exit_function
Process ExceptionAwareProcess-1:
    p.join()
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 38, in join
    super().join()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 32, in run
    super().run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 87, in _server_loop
    c_socket, addr = s.accept()
  File "/usr/lib/python3.10/socket.py", line 293, in accept
    fd, addr = self._accept()

I'm not sure whether this also happens for the MPI backend. Could you please test it out and see whether it happens?
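
The failure mode in the traceback is generic: multiprocessing's atexit handler joins every non-daemon child, and the TCP-store child is blocked in socket.accept(), so the parent never finishes exiting. A minimal, self-contained illustration of that mechanism (not vLLM or CuPy code):

# Standalone sketch of the hang: a non-daemon child stuck in accept() keeps
# the parent from exiting, because multiprocessing joins it at shutdown.
import multiprocessing
import socket

def server_loop():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("127.0.0.1", 0))
    s.listen()
    s.accept()  # blocks forever if no client ever connects

if __name__ == "__main__":
    p = multiprocessing.Process(target=server_loop)  # non-daemon, like the cupy store process
    p.start()
    print("main is done")
    # The interpreter now tries to join p at exit and appears to hang;
    # p.terminate() here (the moral equivalent of _store.stop()) avoids it.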

@hongxiayang (Collaborator, Author):

> Hi @lcskrishna @hongxiayang , we used TCP store just to reduce dependencies. We can use MPI if you prefer it over TCP store. Although I haven't tested MPI + CuPy on NVIDIA GPUs, I believe it works.
>
> The deadlock issue is that, when the main process is terminated, the process hangs waiting for other processes spawned by cupy TCP store. The _NCCL_BACKEND._store.stop() hack is to avoid this.
>
> I'm not sure whether this also happens for the MPI backend. Could you please test it out and see whether it happens?

@WoosukKwon Quick question for verification: regarding "when the main process is terminated", do you mean it was killed manually before the throughput benchmarking script completed?

Collaborator:

@hongxiayang Not really. Without the two-line hack on cupy TCP store, the process hangs when it normally terminates (e.g., after running python examples/llm_engine_example.py -tp 2).

@hongxiayang (Collaborator, Author):

> @hongxiayang Not really. Without the two-line hack on cupy TCP store, the process hangs when it normally terminates (e.g., after running python examples/llm_engine_example.py -tp 2).

@WoosukKwon It seems that we do not need to do anything for this situation.
(1) I tested examples/llm_engine_example.py -tp 2; it completed fine, with no deadlock or hang and without any patch to call mpi_comm.Abort(). (2) I also ran the throughput benchmarking script and pressed Ctrl-C in the middle of the run; the script stopped cleanly. (3) I discussed this with @lcskrishna earlier, and he also thought that MPI might not need any additional hack.

Collaborator:

Got it. Thanks for testing!

&& pip install scipy==1.9.3 \
&& pip install cython==0.29.* \
&& env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
&& export CUPY_INSTALL_USE_HIP=1 \
&& export ROCM_HOME=/opt/rocm \
&& export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
&& pip install . \
&& cd ..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps

RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
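
Not part of the PR: with BUILD_CUPY left at its default of "1" (the stage can be skipped with --build-arg BUILD_CUPY=0), a quick way to confirm inside the built image that the HIP-enabled CuPy actually works is a small smoke test along these lines. NCCLBackend is CuPy's public distributed backend, presumably RCCL-backed on ROCm builds.

# Hypothetical smoke test for the CuPy build inside the ROCm image.
import cupy

x = cupy.arange(10) ** 2
print("sum computed on GPU:", int(x.sum()))  # exercises the HIP runtime

# The distributed backend vLLM relies on for tensor parallelism:
from cupyx.distributed import NCCLBackend
print("NCCLBackend importable:", NCCLBackend.__name__)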
4 changes: 1 addition & 3 deletions vllm/worker/worker.py
@@ -19,7 +19,6 @@
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip


class Worker:
@@ -269,8 +268,7 @@ def init_distributed_environment(
"cupy.distributed is already initialized but the cupy world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif (parallel_config.world_size > 1 and cupy_port is not None
and not is_hip()):
elif (parallel_config.world_size > 1 and cupy_port is not None):
# NOTE(woosuk): We don't initialize CuPy process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.