Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,18 +699,25 @@ def send_object(self, obj: Any, dst: int) -> None:
)

# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=torch.cuda.current_device()
)

size_tensor = torch.tensor(
[object_tensor.numel()], dtype=torch.long, device="cpu"
[object_tensor.numel()],
dtype=torch.long,
device=torch.cuda.current_device(),
)

# Send object size

torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
torch.distributed.send(
size_tensor, dst=self.ranks[dst], group=self.device_group
)

# Send object
torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
torch.distributed.send(
object_tensor, dst=self.ranks[dst], group=self.device_group
)

return None

Expand All @@ -724,29 +731,31 @@ def recv_object(self, src: int) -> Any:
src != self.rank_in_group
), "Invalid source rank. Source rank is the same as the current rank."

size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
size_tensor = torch.empty(
1, dtype=torch.long, device=torch.cuda.current_device()
)

# Receive object size
rank_size = torch.distributed.recv(
size_tensor, src=self.ranks[src], group=self.cpu_group
size_tensor, src=self.ranks[src], group=self.device_group
)

# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu",
device=torch.cuda.current_device(),
)

rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.cpu_group
object_tensor, src=self.ranks[src], group=self.device_group
)

assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."

obj = pickle.loads(object_tensor.numpy().tobytes())
obj = pickle.loads(object_tensor.cpu().numpy().tobytes())

return obj

Expand Down Expand Up @@ -857,14 +866,16 @@ def send_tensor_dict(
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"

metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
# Note: Sending over the device group (Device-to-Device, D2D) adds an extra
# Host-to-Device (H2D) copy on the sender after serialization and an extra
# Device-to-Host (D2H) copy on the receiver before deserialization. Even so,
# our benchmarks show better overall transmission performance with D2D due to:
# 1. Superior D2D transfer bandwidth
# 2. Ability to overlap send and recv operations
# Thus the net performance gain justifies this approach.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def event_loop_pp(self):
point_to_point_pyobj(
recv_reqs,
self.pp_rank * self.tp_size + dp_offset,
self.world_group.cpu_group,
self.world_group.device_group,
self.pp_rank * self.tp_size + dp_offset,
(self.pp_rank + 1) * self.tp_size + dp_offset,
)
Expand Down Expand Up @@ -975,7 +975,7 @@ def recv_requests(self) -> List[Req]:
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
self.world_group.cpu_group,
self.world_group.device_group,
(self.pp_rank - 1) * self.tp_size + dp_offset,
self.pp_rank * self.tp_size + dp_offset,
)
Expand Down
24 changes: 18 additions & 6 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,36 +1000,48 @@ def point_to_point_pyobj(
src: int = 0,
dst: int = 1,
):
"""Send data from src to dst in group."""
"""Send data from src to dst in group using DeviceToDevice communication."""

if rank == src:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
dist.send(tensor_size, dst=dst, group=group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
).cuda(
device=torch.cuda.current_device()
) # Move to GPU
tensor_size = torch.tensor(
[size], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([size], dtype=torch.long)

dist.send(tensor_size, dst=dst, group=group)
dist.send(tensor_data, dst=dst, group=group)
return data

elif rank == dst:
tensor_size = torch.tensor([0], dtype=torch.long)
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
dist.recv(tensor_size, src=src, group=group)
size = tensor_size.item()

if size == 0:
return []

tensor_data = torch.empty(size, dtype=torch.uint8)
tensor_data = torch.empty(
size, dtype=torch.uint8, device=torch.cuda.current_device()
)
dist.recv(tensor_data, src=src, group=group)

serialized_data = bytes(tensor_data.cpu().numpy())
serialized_data = bytes(
tensor_data.cpu().numpy()
) # Move back to host for deserialization
data = pickle.loads(serialized_data)
return data

Expand Down
Loading