95 changes: 77 additions & 18 deletions examples/offline_inference/rlhf_colocate.py
@@ -28,12 +28,15 @@
https://docs.ray.io/en/latest/placement-groups.html
"""

import gc
import os

import ray
import torch
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.multiprocessing.reductions import reduce_tensor

from vllm import LLM

@@ -86,20 +89,72 @@ def __init__(self):
        from vllm.platforms import current_platform

        self.device_uuid = current_platform.get_device_uuid(0)
        self.zmq_context = zmq.Context()
        self.zmq_address_counter = 0
        self.zmq_handle = None

    def report_device_id(self) -> str:
        return self.device_uuid

-    def get_weight_ipc_handles(self):
-        from torch.multiprocessing.reductions import reduce_tensor
-
-        data = {}
-        for name, p in self.model.named_parameters():
-            # A training actor might hold only a subset of the weights and may
-            # need to gather weights from other actors. For demonstration
-            # purposes, each training actor owns the full weight set.
-            data[name] = reduce_tensor(p.detach())
-        return {self.device_uuid: data}

    def get_zmq_handles(self) -> dict[str, str]:
        suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
        self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
        self.zmq_address_counter += 1
        return {self.device_uuid: self.zmq_handle}

    def update_weights(self):
        # Round every tensor size up to a 256-byte boundary so that each
        # tensor starts at an aligned offset inside the staging buffer.
        align_size = 256

        def get_size(p: torch.Tensor) -> int:
            return (p.nbytes + align_size - 1) // align_size * align_size

        named_parameters: dict[str, torch.nn.Parameter] = dict(
            self.model.named_parameters()
        )
        max_tensor_size = max(get_size(p) for p in named_parameters.values())
        # Allocate a staging buffer twice the size of the largest tensor and
        # expose it to the inference worker via a CUDA IPC handle.
        buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
        # REQ socket: the trainer drives the hand-off; the worker replies to
        # every message with an empty ack.
        s = self.zmq_context.socket(zmq.REQ)
        s.bind(self.zmq_handle)
        handle = reduce_tensor(buffer)

        # Group parameters into buckets whose aligned sizes fit in the buffer.
        offset = 0
        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
        named_tensors: list[dict] = []
        real_tensors: list[torch.Tensor] = []
        for name, p in named_parameters.items():
            size = get_size(p)
            if offset + size > buffer.numel():
                buckets.append((named_tensors, real_tensors))
                named_tensors, real_tensors = [], []
                offset = 0
            # assume tensors are contiguous so a flat byte copy is valid
            named_tensors.append(
                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
            )
            real_tensors.append(p)
            offset += size
        if named_tensors:
            buckets.append((named_tensors, real_tensors))
        # Send the IPC handle of the staging buffer first, then stream each
        # bucket: copy the weights into the buffer and send the metadata.
        s.send_pyobj(handle)
        s.recv()
        for named_tensors, real_tensors in buckets:
            offset = 0
            for p in real_tensors:
                buffer[offset : offset + p.nbytes].data.copy_(
                    p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
                )
                offset += get_size(p)
            torch.cuda.synchronize()
            s.send_pyobj(named_tensors)
            s.recv()
        # A None payload tells the worker that the update is complete.
        s.send_pyobj(None)
        s.recv()
        s.close()
        del buffer
        gc.collect()
        torch.cuda.empty_cache()


# Ray manages four GPUs.
@@ -175,18 +230,22 @@ def get_weight_ipc_handles(self):
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

print("Gather all the IPC handles from the training actors.")
ipc_handles = {}
print("Gather all the ZMQ handles from the training actors.")
zmq_handles = {}
for actor in training_actors:
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))

print(f"ZMQ handles: {zmq_handles}")

print("Update the weights of the inference engines.")
for llm in inference_engines:
ray.get(
llm.collective_rpc.remote(
"update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
ray.get(
[actor.update_weights.remote() for actor in training_actors]
+ [
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
for llm in inference_engines
]
)

print("Check if the weights are updated.")
for llm in inference_engines:
    assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
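The REQ/REP hand-off between update_weights and update_weights_from_ipc may be easier to follow in isolation. Below is a minimal, CPU-only sketch of the same message sequence (buffer hand-off, one message per bucket, then a None sentinel). It mocks the shared CUDA staging buffer by shipping the flattened bytes inline, and the trainer/worker helpers are illustrative names, not part of the example.

import threading

import torch
import zmq


def trainer(addr: str, params: dict[str, torch.Tensor]) -> None:
    ctx = zmq.Context.instance()
    s = ctx.socket(zmq.REQ)
    s.bind(addr)
    # step 1: hand over the staging buffer (mocked here by a placeholder)
    s.send_pyobj("buffer-handle-placeholder")
    s.recv()
    # step 2: one message per bucket; here each bucket holds a single tensor
    for name, p in params.items():
        meta = {"name": name, "dtype": p.dtype, "shape": p.shape}
        raw = p.detach().contiguous().view(-1).view(torch.uint8).numpy().tobytes()
        s.send_pyobj((meta, raw))
        s.recv()
    # step 3: a None payload signals that the update is complete
    s.send_pyobj(None)
    s.recv()
    s.close()


def worker(addr: str, out: dict[str, torch.Tensor]) -> None:
    ctx = zmq.Context.instance()
    s = ctx.socket(zmq.REP)
    s.connect(addr)
    s.recv_pyobj()  # the (mock) buffer handle
    s.send(b"")
    while True:
        payload = s.recv_pyobj()
        if payload is None:
            s.send(b"")
            break
        meta, raw = payload
        t = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
        out[meta["name"]] = t.view(meta["dtype"]).view(meta["shape"])
        s.send(b"")
    s.close()


if __name__ == "__main__":
    addr = "ipc:///tmp/rl-colocate-zmq-demo.sock"
    params = {"w": torch.randn(4, 4), "b": torch.zeros(4)}
    received: dict[str, torch.Tensor] = {}
    t = threading.Thread(target=worker, args=(addr, received))
    t.start()
    trainer(addr, params)
    t.join()
    assert all(torch.equal(received[k], params[k]) for k in params)

The real example keeps the same three-step sequence but shares one GPU staging buffer via reduce_tensor and ships only per-bucket metadata (name, dtype, shape, offset) over the socket.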
90 changes: 75 additions & 15 deletions examples/offline_inference/rlhf_utils.py
@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from typing import Callable, Optional, TypedDict

import torch
import zmq


def stateless_init_process_group(master_address, master_port, rank, world_size, device):
Expand Down Expand Up @@ -66,6 +70,27 @@ def check_weights_changed(self):
return weights_updated


def rebuild_ipc(
    handle: tuple[Callable, tuple], device_id: Optional[int] = None
) -> torch.Tensor:
    func, args = handle
    list_args = list(args)
    if device_id is not None:
        # the key is to change device id to the current device id
        # in case two processes have different CUDA_VISIBLE_DEVICES
        list_args[6] = device_id
    buffer = func(*list_args)
    return buffer


class FlattenedTensorMetadata(TypedDict):
    name: str
    shape: torch.Size
    dtype: torch.dtype
    # start offset of this tensor within the shared IPC buffer
    offset: int


class ColocateWorkerExtension:
    """
    The class for vLLM's worker to inherit from, in the colocate setting.
@@ -76,27 +101,62 @@ class ColocateWorkerExtension:
    should pass the full qualified name as `worker_extension_cls` argument.
    """

    def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
        from vllm.model_executor.model_loader.utils import process_weights_after_loading

        assert self.device is not None
        if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        socket = self._zmq_ctx.socket(zmq.REP)
        socket.connect(zmq_handles[self.report_device_id()])
        buffer: Optional[torch.Tensor] = None
        while True:
            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
                socket.recv_pyobj()
            )
            if payload is None:
                # a None payload means the update is done
                process_weights_after_loading(
                    self.model_runner.model, self.model_config, self.device
                )
                torch.cuda.synchronize()
                socket.send(b"")
                break
            if isinstance(payload, tuple):
                # an IPC handle: unpack it with `func, args = handle` and call
                # `func(*args)` to rebuild the shared GPU buffer on this device.
                buffer = rebuild_ipc(payload, self.device.index)
                assert buffer.dtype == torch.uint8
                socket.send(b"")
                continue
            # otherwise, a bucket of tensor metadata: reconstruct each tensor
            # as a view into the shared buffer and load it into the model.
            assert isinstance(payload, list)
            assert buffer is not None
            weights = []
            for item in payload:
                shape = item["shape"]
                if isinstance(shape, (list, tuple)):
                    shape = torch.Size(shape)
                assert isinstance(shape, torch.Size)
                dtype, offset = item["dtype"], item["offset"]
                size = dtype.itemsize * shape.numel()
                tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
                weights.append((item["name"], tensor))
            self.model_runner.model.load_weights(weights=weights)
            del weights
            torch.cuda.synchronize()
            socket.send(b"")

        socket.close()
        del buffer
        gc.collect()
        torch.cuda.empty_cache()

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform

        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
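For reference, the CUDA IPC piece on its own: a minimal sketch, assuming a single machine with at least one visible CUDA device and the spawn start method, of sharing a buffer created with reduce_tensor with a second process and rebuilding it with the same device-index remapping used by rebuild_ipc above. The process/queue scaffolding is illustrative, not part of the example.

from typing import Callable, Optional

import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor


def rebuild_ipc(
    handle: tuple[Callable, tuple], device_id: Optional[int] = None
) -> torch.Tensor:
    # same remapping trick as in rlhf_utils.py: args[6] is the device index
    func, args = handle
    list_args = list(args)
    if device_id is not None:
        list_args[6] = device_id
    return func(*list_args)


def consumer(q) -> None:
    shared = rebuild_ipc(q.get(), device_id=0)
    # writes land in the producer's allocation: it is the same GPU memory
    shared.fill_(42)
    torch.cuda.synchronize()


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    buf = torch.zeros(1024, dtype=torch.uint8, device="cuda:0")
    q = mp.Queue()
    p = mp.Process(target=consumer, args=(q,))
    p.start()
    q.put(reduce_tensor(buf))
    p.join()
    torch.cuda.synchronize()
    assert int(buf[0]) == 42

In the example above, the training actor plays the producer role (it owns the staging buffer) and each vLLM worker rebuilds the buffer once per weight update.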