Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,14 @@ def send_object(self, obj: Any, dst: int) -> None:
)

# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
device=self.device
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=torch.cuda.current_device()
Comment on lines +702 to +703
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider caching the result of torch.cuda.current_device() to a local variable to avoid redundant calls. This can improve readability and potentially performance.

Suggested change
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=torch.cuda.current_device()
current_device = torch.cuda.current_device()
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=current_device
)

)

size_tensor = torch.tensor(
[object_tensor.numel()],
dtype=torch.long,
device=self.device,
device=torch.cuda.current_device(),
)

# Send object size
Expand All @@ -731,7 +731,9 @@ def recv_object(self, src: int) -> Any:
src != self.rank_in_group
), "Invalid source rank. Source rank is the same as the current rank."

size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
size_tensor = torch.empty(
1, dtype=torch.long, device=torch.cuda.current_device()
Comment on lines +734 to +735
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider caching the result of torch.cuda.current_device() to a local variable to avoid redundant calls. This can improve readability and potentially performance.

Suggested change
size_tensor = torch.empty(
1, dtype=torch.long, device=torch.cuda.current_device()
current_device = torch.cuda.current_device()
size_tensor = torch.empty(
1, dtype=torch.long, device=current_device
)

)

# Receive object size
rank_size = torch.distributed.recv(
Expand All @@ -742,7 +744,7 @@ def recv_object(self, src: int) -> Any:
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=self.device,
device=torch.cuda.current_device(),
)

rank_object = torch.distributed.recv(
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,6 @@ def event_loop_pp(self):
self.world_group.device_group,
self.pp_rank * self.tp_size + dp_offset,
(self.pp_rank + 1) * self.tp_size + dp_offset,
device=self.device,
)

# send out proxy tensors to the next stage
Expand Down Expand Up @@ -1024,7 +1023,6 @@ def recv_requests(self) -> List[Req]:
self.world_group.device_group,
(self.pp_rank - 1) * self.tp_size + dp_offset,
self.pp_rank * self.tp_size + dp_offset,
device=self.device,
)
else:
recv_reqs = None
Expand Down Expand Up @@ -1055,15 +1053,13 @@ def recv_requests(self) -> List[Req]:
self.attn_tp_group.rank,
self.attn_tp_cpu_group,
src=self.attn_tp_group.ranks[0],
device=self.device,
)
if self.tp_size != 1:
control_reqs = broadcast_pyobj(
control_reqs,
self.tp_group.rank,
self.tp_cpu_group,
src=self.tp_group.ranks[0],
device=self.device,
)
recv_reqs = work_reqs + control_reqs
elif self.tp_size != 1:
Expand All @@ -1072,7 +1068,6 @@ def recv_requests(self) -> List[Req]:
self.tp_group.rank,
self.tp_cpu_group,
src=self.tp_group.ranks[0],
device=self.device,
)
return recv_reqs

Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def __init__(
self.tp_size * self.pp_rank + tp_rank,
self.world_group.cpu_group,
src=self.world_group.ranks[0],
device=self.device,
)[0]
set_random_seed(self.random_seed)

Expand Down
34 changes: 20 additions & 14 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,15 +1100,15 @@ def broadcast_pyobj(
rank: int,
dist_group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
device: Optional[str] = None,
force_cpu_device: bool = True,
):
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
The `rank` here refers to the source rank on the global process group (regardless
of the dist_group argument).
"""

if device is None:
device = get_device()
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)

if rank == src:
if len(data) == 0:
Expand Down Expand Up @@ -1148,38 +1148,44 @@ def point_to_point_pyobj(
group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
dst: int = 1,
device: Optional[str] = None,
):
"""Send data from src to dst in group using DeviceToDevice communication."""
if device is None:
device = get_device()

if rank == src:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
dist.send(tensor_size, dst=dst, group=group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
).to(
device=device
) # Move to Device
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
).cuda(
device=torch.cuda.current_device()
Comment on lines +1165 to +1166
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider caching the result of torch.cuda.current_device() to a local variable to avoid redundant calls. This can improve readability and potentially performance.

            current_device = torch.cuda.current_device()
            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            ).cuda(
                device=current_device
            )  # Move to GPU

) # Move to GPU
tensor_size = torch.tensor(
[size], dtype=torch.long, device=torch.cuda.current_device()
Comment on lines +1168 to +1169
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider caching the result of torch.cuda.current_device() to a local variable to avoid redundant calls. This can improve readability and potentially performance.

            tensor_size = torch.tensor(
                [size], dtype=torch.long, device=current_device
            )

)

dist.send(tensor_size, dst=dst, group=group)
dist.send(tensor_data, dst=dst, group=group)
return data

elif rank == dst:
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
Comment on lines +1177 to +1178
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider caching the result of torch.cuda.current_device() to a local variable to avoid redundant calls. This can improve readability and potentially performance.

        tensor_size = torch.tensor(
            [0], dtype=torch.long, device=current_device
        )

)
dist.recv(tensor_size, src=src, group=group)
size = tensor_size.item()

if size == 0:
return []

tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
tensor_data = torch.empty(
size, dtype=torch.uint8, device=torch.cuda.current_device()
)
dist.recv(tensor_data, src=src, group=group)

serialized_data = bytes(
Expand Down
Loading