File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -204,7 +204,8 @@ def _build_model_optimizer(self):
204204 apply_monkey_patch (config , verbose = True )
205205
206206 # This may be very large
207- init_context = get_init_weight_context_manager (use_meta_tensor = not config .tie_word_embeddings )
207+ init_context = get_init_weight_context_manager (use_meta_tensor = not config .tie_word_embeddings ,
208+ mesh = self .device_mesh )
208209
209210 with init_context ():
210211 self .model : PreTrainedModel = AutoModelForCausalLM .from_pretrained (local_model_path ,
Original file line number Diff line number Diff line change @@ -685,7 +685,7 @@ def _build_critic_model_optimizer(self, config):
685685 from verl .models .transformers .monkey_patch import apply_monkey_patch
686686 apply_monkey_patch (critic_model_config , verbose = True )
687687
688- init_context = get_init_weight_context_manager ()
688+ init_context = get_init_weight_context_manager (True , mesh = self . device_mesh )
689689 with init_context (), warnings .catch_warnings ():
690690 warnings .simplefilter ("ignore" )
691691 setattr (critic_model_config , 'classifier_dropout' , 0. )
@@ -944,7 +944,8 @@ def _build_model(self, config):
944944 apply_monkey_patch (model_config , verbose = True )
945945
946946 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
947- init_context = get_init_weight_context_manager (use_meta_tensor = not model_config .tie_word_embeddings )
947+ init_context = get_init_weight_context_manager (use_meta_tensor = not model_config .tie_word_embeddings ,
948+ mesh = self .device_mesh )
948949
949950 with init_context (), warnings .catch_warnings ():
950951 warnings .simplefilter ("ignore" )
You can’t perform that action at this time.
0 commit comments