
Commit a1dd922

fix: set torch dtype to auto (#749)
### Issue
As transformers' `from_pretrained()` automatically loads checkpoints in the default `torch.dtype` (here, fp32) rather than following the dtype of the weights themselves, it often causes high CPU memory usage or even OOM, which crashes the training.

### Fix
Following huggingface/transformers#34743; tested on loading Llama-3.1-405B-Instruct.

```
- model = AutoModelForCausalLM.from_pretrained(model_file)
+ model = AutoModelForCausalLM.from_pretrained(model_file, torch_dtype="auto")
```

### Potential problem
Some CPUs do not support `bfloat16`, and converting back to `float32` may take some time.

Co-authored-by: Changlong Yu <[email protected]>
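For context, a minimal sketch (not part of this commit) of what `torch_dtype="auto"` changes when loading; the model name below is a hypothetical placeholder for any checkpoint saved in bf16:

```
import torch
from transformers import AutoModelForCausalLM

# Default behavior: weights are upcast to the framework default (float32),
# roughly doubling CPU memory for a bf16 checkpoint.
model_fp32 = AutoModelForCausalLM.from_pretrained("some-org/some-bf16-model")
print(model_fp32.dtype)  # torch.float32

# With torch_dtype="auto", the dtype recorded in the checkpoint is respected,
# so a bf16 checkpoint stays in bf16 and CPU memory stays close to the file size.
model_auto = AutoModelForCausalLM.from_pretrained("some-org/some-bf16-model", torch_dtype="auto")
print(model_auto.dtype)  # torch.bfloat16 (if the checkpoint was saved in bf16)
```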
1 parent fc27caf commit a1dd922

File tree

2 files changed (+2, -2 lines)


verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 1 addition & 1 deletion

```
@@ -267,7 +267,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                 state_dict['score.weight'] = state_dict['score.weight']
             else:
                 from transformers import AutoModelForCausalLM
-                model = AutoModelForCausalLM.from_pretrained(self.config.model.path)
+                model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
             model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
             if hdfs_path is not None:
                 print(f'Uploading checkpoint to {hdfs_path}')
```

verl/utils/model.py

Lines changed: 1 addition & 1 deletion

```
@@ -316,7 +316,7 @@ def load_megatron_model_weights(config,
                     'model.embed_tokens.weight'][:32000]  # workaround, 32001 -> 32000
                 is_value_model = True
             else:
-                model = AutoModelForCausalLM.from_pretrained(local_model_path)
+                model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype="auto")
             state_dict = model.state_dict()

             from verl.models.weight_loader_registry import get_weight_loader
```
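Regarding the potential problem noted in the commit message, a minimal sketch (an assumption, not something this commit implements) of falling back to `float32` when the host CPU handles bf16 poorly; `local_model_path` here is a hypothetical local path:

```
import torch
from transformers import AutoModelForCausalLM

local_model_path = "/path/to/local/checkpoint"  # hypothetical path, not from the commit

# Load in the dtype stored in the checkpoint (e.g. bf16 for Llama-3.1 checkpoints).
model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype="auto")

# If this CPU does not support bfloat16 well, cast back to float32; this roughly
# doubles memory and can take noticeable time for very large models.
if model.dtype == torch.bfloat16:
    model = model.to(torch.float32)
```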
