 logger = logging.getLogger(__name__)


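+# Gate the built-in print on `enable` so only the selected rank writes to
+# stdout; callers can still pass force=True to print from any rank.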
+def override_print(enable):
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if force or enable:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
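+# Apply the same gating to logger.info so every rank does not emit duplicate
+# log lines; force=True bypasses the gate here as well.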
+def override_logger(logger, enable):
+    logger_info = logger.info
+
+    def info(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if force or enable:
+            logger_info(*args, **kwargs)
+
+    logger.info = info
+
+
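+# Prepare the model for multi-card inference: silence non-zero ranks,
+# initialize the DeepSpeed distributed backend, and shard the model across
+# args.world_size devices.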
+def initialize_distributed_model(args, model, logger, model_dtype):
+    override_print(args.global_rank == 0)
+    override_logger(logger, args.global_rank == 0)
+
+    import deepspeed
+
+    logger.info(f"Initializing DeepSpeed with world size: {args.world_size}")
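+    # hccl is the collective-communications backend for Habana Gaudi (HPU)
+    # devices; only rank 0 logs verbose initialization output.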
+    deepspeed.init_distributed(
+        dist_backend="hccl",
+        verbose=args.global_rank == 0,
+    )
+    model.eval()
+
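+    # Build the DeepSpeed inference config: run in model_dtype, shard the
+    # model tensor-parallel across all ranks, and capture graphs when
+    # args.use_hpu_graphs is set.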
+    ds_inference_kwargs = {"dtype": model_dtype}
+    ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size}
+    ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
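+    # An empty injection policy leaves layer sharding to DeepSpeed's
+    # automatic tensor-parallel handling (no custom kernel injection).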
+    ds_inference_kwargs["injection_policy"] = {}
+
+    model = deepspeed.init_inference(model, **ds_inference_kwargs).module
+
+    return model
+
+
 def setup_quantization(model, args):
     from neural_compressor.torch.quantization import FP8Config, convert, prepare

@@ -129,6 +176,11 @@ def main():

     # set args.quant_config with env variable if it is set
     args.quant_config = os.getenv("QUANT_CONFIG", "")
+
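+    # Distributed launchers (e.g. the deepspeed or torchrun CLI) export these
+    # variables; single-process runs fall back to the defaults.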
+    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
+    args.world_size = int(os.getenv("WORLD_SIZE", "0"))
+    args.global_rank = int(os.getenv("RANK", "0"))
+
     os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
     adapt_transformers_to_gaudi()

@@ -187,6 +239,16 @@ def main():
         torch_dtype=model_dtype,
         device="hpu",
     )
+
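+    # With more than one card, hand the model to DeepSpeed, which handles HPU
+    # graph capture itself via enable_cuda_graph; otherwise wrap the model in
+    # an HPU graph directly.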
+    if args.world_size > 1:
+        generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)
+
+    else:
+        if args.use_hpu_graphs:
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+            generator.model = wrap_in_hpu_graph(generator.model)
+
     generate_kwargs = {
         "lazy_mode": True,
         "hpu_graphs": args.use_hpu_graphs,
@@ -198,11 +260,6 @@ def main():
     if args.use_kv_cache:
         generate_kwargs["use_cache"] = args.use_kv_cache

-    if args.use_hpu_graphs:
-        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-        generator.model = wrap_in_hpu_graph(generator.model)
-
     if args.quant_config:
         generator.model = setup_quantization(generator.model, args)
         htcore.hpu_initialize(generator.model)