
Commit cc466a3

[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (vllm-project#4660)
1 parent 8344f77 commit cc466a3

File tree: 2 files changed, +41 -22 lines


tests/distributed/test_comm_ops.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     test_dict = {
+        # device tensor
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
-        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
         "c": "test",
         "d": [1, 2, 3],
         "e": {
             "a": 1,
             "b": 2
         },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
     }
 
     if rank == 0:
@@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     assert recv_dict["c"] == test_dict["c"]
     assert recv_dict["d"] == test_dict["d"]
     assert recv_dict["e"] == test_dict["e"]
+    assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
```
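For context, this is the call pattern the worker exercises: the source rank passes the dict, the other ranks call with no dict and receive a reconstructed copy. A minimal sketch, not part of the commit; the import path matches the file changed below, and it assumes a vLLM distributed environment is already initialized on every rank:

```python
# Illustrative sketch only: source rank sends, other ranks receive.
import torch
from vllm.distributed.communication_op import broadcast_tensor_dict

def exchange(rank: int) -> dict:
    if rank == 0:
        test_dict = {"a": torch.arange(8, dtype=torch.float32, device="cuda")}
        return broadcast_tensor_dict(test_dict, src=0)
    # Non-source ranks get a dict with the same keys and tensor contents.
    return broadcast_tensor_dict(src=0)
```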

vllm/distributed/communication_op.py

Lines changed: 35 additions & 21 deletions
```diff
@@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list
 
 
-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
```
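The widened namedtuple is what rides along in the pickled object broadcast. A quick illustration with hypothetical values of what it now carries, and why that is enough for the receiver to allocate a matching buffer:

```python
from collections import namedtuple

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

# Hypothetical entries: the device field stores only the device *type*,
# so each receiving rank can pick its own device index.
meta_gpu = TensorMetadata("cuda", torch.float32, torch.Size([8]))
meta_cpu = TensorMetadata("cpu", torch.int8, torch.Size([16]))

# Enough information to allocate a matching empty buffer on the receiver.
buf = torch.empty(meta_cpu.size, dtype=meta_cpu.dtype, device=meta_cpu.device)
```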
```diff
@@ -152,15 +152,13 @@ def _split_tensor_dict(
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
-            # Note(youkaichao): currently this only supports broadcasting
-            # tensors on cuda. In the future, we can add device as a field in
-            # TensorMetadata to support broadcasting tensors on different
-            # devices.
-            assert value.is_cuda, (
-                f"Tensor {key}: {value} is not on cuda. Currently we only "
-                f"support broadcasting tensors on cuda.")
-            metadata_list.append((key, TensorMetadata(value.dtype,
-                                                      value.size())))
+            # Note: we cannot use `value.device` here, because it contains
+            # not only the device type but also the device index (e.g.
+            # "cuda:0"). We only need the device type; the receiving side
+            # will set the device index.
+            device = "cpu" if value.is_cpu else "cuda"
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
```
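The new comment is the crux of the change. A small torch demonstration of why the raw `value.device` is not sent (the second half assumes a GPU is present):

```python
import torch

t_cpu = torch.zeros(4)
print(t_cpu.device)        # device(type='cpu') -- no index to worry about

if torch.cuda.is_available():
    t_gpu = torch.zeros(4, device="cuda:0")
    print(t_gpu.device)    # device(type='cuda', index=0) -- index is rank-local
    # Broadcasting "cuda:0" verbatim would pin every rank to GPU 0, so only
    # the type is sent and each receiver resolves its own current device.
    device = "cpu" if t_gpu.is_cpu else "cuda"
    print(device)          # 'cuda'
```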
```diff
@@ -206,11 +204,19 @@ def broadcast_tensor_dict(
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
-            async_handles.append(
-                torch.distributed.broadcast(tensor,
-                                            src=src,
-                                            group=group,
-                                            async_op=True))
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()
 
```
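The reason for the two groups: a NCCL-backed group can only move CUDA tensors, while the gloo-backed `metadata_group` can move CPU tensors. A standalone sketch of the same dispatch, with assumed group setup rather than vLLM's real globals:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run on every rank.
device_group = dist.group.WORLD                  # NCCL: CUDA tensors only
metadata_group = dist.new_group(backend="gloo")  # gloo: CPU tensors, objects

def broadcast_one(tensor: torch.Tensor, src: int = 0):
    # Route by tensor residency, mirroring the branch added in this commit.
    group = metadata_group if tensor.is_cpu else device_group
    return dist.broadcast(tensor, src=src, group=group, async_op=True)
```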
```diff
@@ -226,16 +232,24 @@ def broadcast_tensor_dict(
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
-                                     device="cuda")
+                                     device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
```
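On the receiving side the flow is the mirror image: allocate an empty tensor from each `TensorMetadata`, broadcast into it in place, and wait on all handles at the end so the gloo and NCCL transfers can overlap. A hedged sketch with a hypothetical helper name, reusing the namedtuple from the diff above and assuming both process groups are already initialized:

```python
from collections import namedtuple

import torch
import torch.distributed as dist

# Same namedtuple as in the diff above.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

def rebuild_and_receive(metadata_list, src, group, metadata_group):
    """Hypothetical helper mirroring the receiver branch of this commit."""
    tensor_dict, handles = {}, []
    for key, value in metadata_list:
        if isinstance(value, TensorMetadata):
            # Allocate an empty buffer on the device type the sender recorded.
            tensor = torch.empty(value.size, dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() != 0:
                g = metadata_group if tensor.is_cpu else group
                handles.append(
                    dist.broadcast(tensor, src=src, group=g, async_op=True))
            tensor_dict[key] = tensor
        else:
            tensor_dict[key] = value
    # Let the gloo (CPU) and NCCL (GPU) transfers overlap, then wait for all.
    for h in handles:
        h.wait()
    return tensor_dict
```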
