@youkaichao youkaichao commented Apr 22, 2024

NCCL initialization requires broadcasting a unique ID, which lives in CPU memory. Previously, we had only one process group, backed by NCCL, so we had to move the unique ID to the GPU, broadcast it, and then move it back to the CPU.

After #3904, we always have a CPU/gloo backend process group, so we no longer need to move the unique ID around. We can broadcast it directly in CPU memory.
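
To illustrate the CPU-only path, here is a minimal sketch that treats the unique ID as an opaque 128-byte buffer (the size of `ncclUniqueId` in NCCL) and round-trips it through plain bytes. The names `NcclUniqueId`, `unique_id_to_bytes`, `unique_id_from_bytes`, and `exchange_unique_id` are hypothetical, and the `broadcast` argument is an assumed stand-in for a CPU-side collective such as a broadcast on the gloo group. It is not necessarily the code in this PR.

```python
import ctypes
from typing import Callable

NCCL_UNIQUE_ID_BYTES = 128  # size of the opaque ncclUniqueId buffer


class NcclUniqueId(ctypes.Structure):
    # Hypothetical mirror of `ncclUniqueId`: an opaque byte buffer.
    _fields_ = [("internal", ctypes.c_byte * NCCL_UNIQUE_ID_BYTES)]


def unique_id_to_bytes(uid: NcclUniqueId) -> bytes:
    # ctypes structures expose the buffer protocol, so this copies
    # the raw 128 bytes without touching GPU memory.
    return bytes(uid)


def unique_id_from_bytes(payload: bytes) -> NcclUniqueId:
    if len(payload) != ctypes.sizeof(NcclUniqueId):
        raise ValueError("unexpected unique ID size")
    return NcclUniqueId.from_buffer_copy(payload)


def exchange_unique_id(
    rank: int,
    local_id: NcclUniqueId,
    broadcast: Callable[[bytes, int], bytes],
) -> NcclUniqueId:
    # `broadcast(payload, src)` is an assumption: any CPU-side collective
    # that returns the payload held by rank `src` on every rank.
    if rank == 0:
        payload = unique_id_to_bytes(local_id)
    else:
        payload = bytes(NCCL_UNIQUE_ID_BYTES)  # placeholder, overwritten
    return unique_id_from_bytes(broadcast(payload, 0))
```

Because the ID never leaves host memory in this sketch, no CUDA device needs to be selected before the exchange.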

@zhuohan123 zhuohan123 left a comment

LGTM! Left some small comments.

Comment on lines +250 to +258
current_device = torch.cuda.current_device()
try:
    torch.cuda.set_device(device)
    NCCL_CHECK(
        _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                            self.unique_id, self.rank))
    self.stream = torch.cuda.Stream()
finally:
    torch.cuda.set_device(current_device)
Why do we need a try...finally block here? Can the program continue to run when there is an exception in the try block?

Member Author

It can, but I think it is better for a function to be pure, i.e., it should not implicitly modify global state.
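
As a rough illustration of that point, the same save/restore pattern can be wrapped in a context manager. This is a sketch under stated assumptions: `get_device`, `set_device`, and `use_device` are hypothetical stand-ins that track a module-level integer, not the `torch.cuda` API.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the process-wide "current device" state.
_current_device = 0


def get_device() -> int:
    return _current_device


def set_device(device: int) -> None:
    global _current_device
    _current_device = device


@contextmanager
def use_device(device: int):
    # Remember the caller's device, switch, and always switch back,
    # even if the body raises.
    previous = get_device()
    try:
        set_device(device)
        yield
    finally:
        set_device(previous)
```

With this shape, a caller that catches an initialization error still sees the device it had selected before the call.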

@youkaichao youkaichao merged commit 91f50a6 into vllm-project:main Apr 24, 2024
@youkaichao youkaichao deleted the pynccl_init_improve branch April 24, 2024 01:32
xjpang pushed a commit to xjpang/vllm that referenced this pull request Apr 25, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
akondrat-amd pushed a commit to akondrat-amd/ci-vllm that referenced this pull request May 1, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024