@@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list


-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


 def _split_tensor_dict(
@@ -152,15 +152,13 @@ def _split_tensor_dict(
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
-            # Note(youkaichao): currently this only supports broadcasting
-            # tensors on cuda. In the future, we can add device as a field in
-            # TensorMetadata to support broadcasting tensors on different
-            # devices.
-            assert value.is_cuda, (
-                f"Tensor {key}: {value} is not on cuda. Currently we only "
-                f"support broadcasting tensors on cuda.")
-            metadata_list.append((key, TensorMetadata(value.dtype,
-                                                      value.size())))
+            # Note: we cannot use `value.device` here, because it contains
+            # not only the device type but also the device index (e.g.
+            # "cuda:0"). We only need the device type; the receiving side
+            # will set the device index.
+            device = "cpu" if value.is_cpu else "cuda"
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
@@ -206,11 +204,19 @@ def broadcast_tensor_dict(
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
-            async_handles.append(
-                torch.distributed.broadcast(tensor,
-                                            src=src,
-                                            group=group,
-                                            async_op=True))
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()

@@ -226,16 +232,24 @@ def broadcast_tensor_dict(
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
-                                     device="cuda")
+                                     device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
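
The diff assumes two process groups are available inside `broadcast_tensor_dict`: a CPU-capable `metadata_group` (typically gloo, since NCCL cannot broadcast CPU tensors) and the device `group` for GPU tensors, with each tensor routed by its device type. Below is a minimal, self-contained sketch of that routing pattern; the group setup, the `payload` dict, and the script name are illustrative assumptions, not vLLM's actual initialization code.

```python
# Minimal sketch (assumption, not vLLM's implementation): a gloo group
# carries CPU tensors and an NCCL group carries GPU tensors, mirroring
# `metadata_group` vs. `group` in the diff above.
# Run with: torchrun --nproc-per-node=2 sketch.py  (needs >= 2 GPUs)
import os

import torch
import torch.distributed as dist


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Device group for GPU tensors (NCCL) ...
    dist.init_process_group(backend="nccl")
    group = dist.group.WORLD
    # ... and a CPU-capable group (gloo) for CPU tensors / metadata.
    metadata_group = dist.new_group(backend="gloo")
    rank = dist.get_rank()

    # A mixed-device payload (illustrative). Non-source ranks allocate empty
    # buffers of the right dtype/size/device type before receiving, which is
    # what TensorMetadata(device, dtype, size) enables in the diff.
    if rank == 0:
        payload = {
            "cpu_tensor": torch.arange(4, dtype=torch.float32),
            "gpu_tensor": torch.ones(2, 2, device="cuda"),
        }
    else:
        payload = {
            "cpu_tensor": torch.empty(4, dtype=torch.float32),
            "gpu_tensor": torch.empty(2, 2, device="cuda"),
        }

    handles = []
    for tensor in payload.values():
        # Route each broadcast by device type, as the diff does with
        # `tensor.is_cpu`.
        target_group = metadata_group if tensor.is_cpu else group
        handles.append(
            dist.broadcast(tensor, src=0, group=target_group, async_op=True))
    for handle in handles:
        handle.wait()

    print(f"rank {rank}: cpu_tensor={payload['cpu_tensor'].tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Keeping the CPU traffic on the gloo group also matches how the receiving side works in the diff: the tensor is first allocated on `value.device` and only then broadcast over the group that supports that device type.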