+import importlib.util
 from typing import Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
@@ -376,9 +377,9 @@ class ParallelConfig:
     Args:
         pipeline_parallel_size: Number of pipeline parallel groups.
         tensor_parallel_size: Number of tensor parallel groups.
-        worker_use_ray: Whether to use Ray for model workers. Will be set to
+        worker_use_ray: Whether to use Ray for model workers. Will default to
             True if either pipeline_parallel_size or tensor_parallel_size is
-            greater than 1.
+            greater than 1 and Ray is installed.
         max_parallel_loading_workers: Maximum number of multiple batches
             when load model sequentially. To avoid RAM OOM when using tensor
             parallel and large models.
@@ -392,7 +393,7 @@ def __init__(
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        worker_use_ray: bool,
+        worker_use_ray: Optional[bool] = None,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
@@ -412,9 +413,10 @@ def __init__(
         self.ray_workers_use_nsight = ray_workers_use_nsight

         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-        # Ray worker is not supported for Neuron backend.
-        if self.world_size > 1 and not is_neuron():
-            self.worker_use_ray = True
+        if self.worker_use_ray is None:
+            ray_found = importlib.util.find_spec("ray") is not None
+            self.worker_use_ray = ray_found and self.world_size > 1
+
         self._verify_args()

     def _verify_args(self) -> None:
@@ -498,12 +500,12 @@ class DeviceConfig:
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
             # Automated device type detection
-            if torch.cuda.is_available():
-                self.device_type = "cuda"
-            elif is_neuron():
+            if is_neuron():
                 self.device_type = "neuron"
             else:
-                raise RuntimeError("No supported device detected.")
+                # We don't call torch.cuda.is_available() here to
+                # avoid initializing CUDA before workers are forked
+                self.device_type = "cuda"
         else:
             # Device type is assigned explicitly
             self.device_type = device
0 commit comments