Description
Motivation.
In colocated RL, training and inference occupy GPU memory in turn. After each training step finishes, the trainer needs to send the updated model weights to the inference engine and update the model in place.
Currently, vLLM does not support an in-place model update method. SGLang serializes tensors into IPC handles in the training engine and rebuilds tensors from those handles in the inference engine. This can introduce significant overhead, because tensors are serialized and deserialized into IPC handles for every request and large pickled payloads are sent to the inference engines.
This RFC proposes creating only one GPU tensor from an IPC handle per device and sharing all weight data through it, which makes model updates much faster than serializing and deserializing tensors into IPC handles for each request.
Proposed Change.
Design
A model update flow consists of multiple update requests. vLLM will expose an HTTP interface called /v1/update-weights-from-ipc to handle each request.
The first model update request carries an extra field handles. When vLLM receives handles, it rebuilds them into a shared buffer tensor and saves it as a worker attribute. vLLM uses this shared buffer tensor to receive data from the trainer.
The trainer copies tensor data into this shared buffer and sends an update_weights_from_ipc request to vLLM. On receiving the request, vLLM reads the data from the shared buffer and updates the weights in place.
The last update request carries an extra field end=True, which tells vLLM to drop the shared buffer tensor rebuilt from IPC and release its GPU memory. This completes a full update flow.
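Below is a minimal trainer-side sketch of this flow. Only the request fields come from this RFC; the server URL, the pickle+base64 encoding of the IPC handle, the device-UUID lookup via torch.cuda.get_device_properties, and the JSON encoding of dtype/shape are illustrative assumptions.

import base64
import pickle

import requests
import torch
from torch.multiprocessing.reductions import reduce_tensor

URL = "http://localhost:8000/v1/update-weights-from-ipc"  # assumed server address


def push_weights(named_params: list[tuple[str, torch.Tensor]], buf_size: int = 1 << 30):
    device = torch.device("cuda", torch.cuda.current_device())
    # device UUID; its format must match vLLM's current_platform.get_device_uuid()
    device_uuid = str(torch.cuda.get_device_properties(device).uuid)

    # allocate the shared uint8 buffer once and export its IPC handle
    buffer = torch.empty(buf_size, dtype=torch.uint8, device=device)
    handle = reduce_tensor(buffer)  # (rebuild_func, args), picklable
    handles = {device_uuid: base64.b64encode(pickle.dumps(handle)).decode()}

    def flush(metadata, first, end):
        torch.cuda.synchronize()  # copies into the buffer must finish before vLLM reads it
        requests.post(URL, json={
            "named_tensors": metadata,          # [(name, dtype, shape), ...]
            "handles": handles if first else None,  # only the first request carries handles
            "offset": 0,
            "end": end,
        })

    offset, metadata, first = 0, [], True
    for name, param in named_params:
        data = param.detach().contiguous().view(-1).view(torch.uint8)
        size = data.numel()
        if offset + size > buf_size:
            # buffer is full: flush the tensors copied so far
            flush(metadata, first, end=False)
            first, offset, metadata = False, 0, []
        buffer[offset:offset + size].copy_(data)
        metadata.append((name, str(param.dtype), list(param.shape)))
        offset += size
    flush(metadata, first, end=True)  # last request releases the shared buffer in vLLM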
A single update request is defined as follows:
from collections.abc import Callable
from dataclasses import dataclass

import torch


@dataclass
class UpdateWeightFromIPCRequest:
    # a list of tuples describing tensor metadata: (name, dtype, shape)
    named_tensors: list[tuple[str, torch.dtype, torch.Size]]
    # dict key is a device_uuid; vLLM can get its own from
    # `current_platform.get_device_uuid(self.device.index)`
    # dict value is a serialized IPC `handle`; vLLM can rebuild the GPU tensor
    # with `func, args = handle` and `func(*args)`
    # if `handles` is not None, this is the first request of the current update
    # flow, and vLLM should rebuild and save this GPU tensor as a shared buffer
    handles: dict[str, tuple[Callable, tuple]] | None
    # start offset of named_tensors inside the shared IPC buffer tensor
    offset: int
    # whether this request is the last request of the current update flow
    end: bool

The update implementation in the worker can be sketched as follows:
class Worker(WorkerBase):
    ...

    def update_weights_from_ipc(
        self,
        named_tensors: list[tuple[str, torch.dtype, torch.Size]],
        handles: dict[str, tuple[Callable, tuple]] | None,
        offset: int,
        end: bool,
    ):
        device_id = self.device.index
        BUF_ATTR_NAME = '_shared_ipc_buffer'
        buffer: torch.Tensor
        if handles is not None:
            # first request of the flow: rebuild the shared buffer from the IPC handle
            buffer = rebuild_ipc(handles[self.device_uuid], device_id)
            assert buffer.dtype == torch.uint8
            setattr(self, BUF_ATTR_NAME, buffer)
        else:
            assert hasattr(self, BUF_ATTR_NAME)
            buffer = getattr(self, BUF_ATTR_NAME)
            assert buffer is not None
        weights = []
        for name, dtype, shape in named_tensors:
            if isinstance(shape, (list, tuple)):
                shape = torch.Size(shape)
            assert isinstance(shape, torch.Size)
            size = dtype.itemsize * shape.numel()
            # reinterpret the matching slice of the uint8 buffer as the target tensor
            tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape)
            weights.append((name, tensor))
            offset += size
        self.model_runner.model.load_weights(weights=weights)
        del weights
        if end:
            process_weights_after_loading(self.model_runner.model,
                                          self.model_config, self.device)
            # last request: drop the shared buffer to release GPU memory
            if hasattr(self, BUF_ATTR_NAME):
                delattr(self, BUF_ATTR_NAME)
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
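Note that rebuild_ipc above is not an existing vLLM helper. A minimal sketch, assuming the handle reaching the worker has already been deserialized back into the (rebuild_func, args) pair produced by torch.multiprocessing.reductions.reduce_tensor on the trainer side:

def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int) -> torch.Tensor:
    # handle is the (rebuild_func, args) pair exported by the trainer,
    # e.g. via torch.multiprocessing.reductions.reduce_tensor(buffer)
    func, args = handle
    # open the IPC handle with the worker's local device active
    with torch.cuda.device(device_id):
        return func(*args)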
Practice
Because weights are copied into a shared buffer, it is easy for training clients to build a pipeline that accelerates the weight update. In the Kimi-K2 report, we described a two-stage pipeline like the one below.
The trainer broadcasts data into one half of the shared IPC buffer while the weight update consumes the other half, so the two stages run in parallel.
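A hedged sketch of that ping-pong schedule is shown below. It reuses buffer and buf_size from the trainer sketch above; batches, copy_batch_into, and send_update_request are hypothetical helpers (split the weights into batches that each fit in half the buffer, copy one batch into a region of the shared buffer, and post one /v1/update-weights-from-ipc request covering that region, with the first request also carrying handles).

from concurrent.futures import ThreadPoolExecutor

import torch

half = buf_size // 2
pool = ThreadPoolExecutor(max_workers=1)
pending = None

for i, batch in enumerate(batches):       # batches of (name, tensor) pairs
    base = (i % 2) * half                 # alternate between the two halves
    copy_batch_into(buffer, base, batch)  # stage 1: broadcast/copy into this half
    torch.cuda.synchronize()              # copies must be visible before vLLM reads them
    if pending is not None:
        pending.result()                  # the other half has now been consumed by vLLM
    # stage 2: vLLM loads weights from this half in the background while the
    # next iteration fills the other half, overlapping the two stages
    pending = pool.submit(send_update_request, batch, offset=base,
                          end=(i == len(batches) - 1))
if pending is not None:
    pending.result()
pool.shutdown()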
Using the /v1/update-weights-from-ipc interface in vLLM together with the client-side pipeline, updating all 1T parameters of Kimi-K2 in vLLM takes less than 20 seconds when deployed across thousands of GPU devices.
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.