Commit a1dd922
fix: set torch dtype to auto (#749)
### Issue
As transformers' `from_pretrained()` loads checkpoints in the default torch dtype (here, fp32) rather than following the dtype of the weights themselves, it often causes high CPU memory usage or even OOM, which crashes the training.
### Fix
Following huggingface/transformers#34743; tested by loading Llama-3.1-405B-Instruct.
```diff
- model = AutoModelForCausalLM.from_pretrained(model_file)
+ model = AutoModelForCausalLM.from_pretrained(model_file, torch_dtype="auto")
```
### Potential problem
Some CPUs do not support `bfloat16`, and converting the weights back to `float32` may take some time.
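On such CPUs, one possible fallback (an assumption, not part of this commit) is to upcast the loaded model explicitly:

```python
import torch
import torch.nn as nn

# Hypothetical fallback: upcast a bf16 model to float32 on CPUs without
# native bfloat16 kernels. nn.Linear stands in for the loaded model.
model = nn.Linear(4, 4).to(torch.bfloat16)
model = model.float()  # converts all parameters back to float32
```

The conversion walks every parameter tensor, so for very large models it adds noticeable load time, which is the cost the note above warns about.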
Co-authored-by: Changlong Yu <[email protected]>
1 parent fc27caf
2 files changed, +2 −2 lines changed