2121
2222import numpy as np
2323import torch
24-
2524from megatron import fused_kernels
2625from megatron import get_adlr_autoresume
2726from megatron import get_args
3029from megatron .global_vars import set_global_variables
3130from megatron .mpu import (set_tensor_model_parallel_rank ,
3231 set_tensor_model_parallel_world_size )
33-
32+ from deepspeed . accelerator import get_accelerator
3433import deepspeed
3534import deepspeed .utils .groups as groups
3635
@@ -46,7 +45,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
4645 """
4746 if not allow_no_cuda :
4847 # Make sure cuda is available.
49- assert torch . cuda . is_available (), 'Megatron requires CUDA .'
48+ assert get_accelerator (). is_available (), 'Megatron requires accelerator .'
5049
5150 # Parse args, build tokenizer, and set adlr-autoresume,
5251 # tensorboard-writer, and timers.
@@ -107,7 +106,10 @@ def _compile_dependencies():
107106 compile_helper ()
108107 print ('>>> done with dataset index builder. Compilation time: {:.3f} '
109108 'seconds' .format (time .time () - start_time ), flush = True )
110-
109+
110+ if not get_accelerator ().device_name () == 'cuda' :
111+ print (">fused kernel is only supported in cuda, skip loading fused kernel" )
112+ return
111113 # ==================
112114 # Load fused kernels
113115 # ==================
@@ -134,7 +136,7 @@ def _compile_dependencies():
134136 if _is_rank_0 ():
135137 start_time = time .time ()
136138 print ('> compiling and loading fused kernels ...' , flush = True )
137- if torch . cuda .device_count () > 0 : # Skip when CPU-only
139+ if get_accelerator () .device_count () > 0 : # Skip when CPU-only
138140 fused_kernels .load (args )
139141 torch .distributed .barrier ()
140142 else :
@@ -185,7 +187,7 @@ def setup_deepspeed_random_and_activation_checkpointing(args):
185187def _initialize_distributed ():
186188 """Initialize torch.distributed and mpu."""
187189 args = get_args ()
188- device_count = torch . cuda .device_count ()
190+ device_count = get_accelerator () .device_count ()
189191 if torch .distributed .is_initialized ():
190192
191193 if args .rank == 0 :
@@ -206,7 +208,7 @@ def _initialize_distributed():
206208 else :
207209 args .local_rank = device
208210
209- torch . cuda .set_device (device ) # only do so when device_count > 0
211+ get_accelerator () .set_device (device ) # only do so when device_count > 0
210212
211213 # Call the init process
212214 init_method = 'tcp://'
@@ -249,14 +251,14 @@ def _set_random_seed(seed_):
249251 if seed_ is not None and seed_ > 0 :
250252 # Ensure that different pipeline MP stages get different seeds.
251253 # No need to do so for CPU-only case.
252- if torch . cuda .device_count () == 0 :
254+ if get_accelerator () .device_count () == 0 :
253255 seed = seed_
254256 else :
255257 seed = seed_ + (100 * mpu .get_pipeline_model_parallel_rank ())
256258 random .seed (seed )
257259 np .random .seed (seed )
258260 torch .manual_seed (seed )
259- if torch . cuda .device_count () > 0 :
261+ if get_accelerator () .device_count () > 0 :
260262 mpu .model_parallel_cuda_manual_seed (seed )
261263 else :
262264 raise ValueError ('Seed ({}) should be a positive integer.' .format (seed ))
@@ -284,7 +286,7 @@ def _is_rank_0():
284286 """Check whether it is rank 0. For AML, check if it is rank 0 of a node"""
285287 if torch .distributed .is_initialized ():
286288 if torch .distributed .get_rank () == 0 or (
287- 'AZUREML_EXPERIMENT_ID' in os .environ and torch .distributed .get_rank () % torch . cuda .device_count () == 0
289+ 'AZUREML_EXPERIMENT_ID' in os .environ and torch .distributed .get_rank () % get_accelerator () .device_count () == 0
288290 ):
289291 return True
290292 else :
0 commit comments