Skip to content

Commit a0fff5a

Browse files
committed
refactor: pass device mesh to get_init_weight_context_manager in SFT trainer and FSDP workers
1 parent 040f3a0 commit a0fff5a

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

verl/trainer/fsdp_sft_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ def _build_model_optimizer(self):
204204
apply_monkey_patch(config, verbose=True)
205205

206206
# This may be very large
207-
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
207+
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings,
208+
mesh=self.device_mesh)
208209

209210
with init_context():
210211
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,

verl/workers/fsdp_workers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def _build_critic_model_optimizer(self, config):
685685
from verl.models.transformers.monkey_patch import apply_monkey_patch
686686
apply_monkey_patch(critic_model_config, verbose=True)
687687

688-
init_context = get_init_weight_context_manager()
688+
init_context = get_init_weight_context_manager(True, mesh=self.device_mesh)
689689
with init_context(), warnings.catch_warnings():
690690
warnings.simplefilter("ignore")
691691
setattr(critic_model_config, 'classifier_dropout', 0.)
@@ -944,7 +944,8 @@ def _build_model(self, config):
944944
apply_monkey_patch(model_config, verbose=True)
945945

946946
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
947-
init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
947+
init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings,
948+
mesh=self.device_mesh)
948949

949950
with init_context(), warnings.catch_warnings():
950951
warnings.simplefilter("ignore")

0 commit comments

Comments
 (0)