MPS memory usage support #2406
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2406
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d818a7e with merge base 504cbea.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
joecummings
left a comment
This looks great - can you include some screenshots or logging output from your test run to confirm it works as expected?
Yes, sure.
LoRA: Terminal Logs

(venv) musab@musab torchtune % tune run lora_finetune_single_device --config recipes/configs/llama3_1/8B_lora_single_device.yaml checkpointer.checkpoint_dir="Meta-Llama-3.1-8B-Instruct" tokenizer.path="Meta-Llama-3.1-8B-Instruct/original/tokenizer.model" checkpointer.output_dir="Llama-3.1-8B-Instruct-Tuned/" optimizer._component_=torch.optim.AdamW optimizer.fused=True device=mps log_peak_memory_stats=True enable_activation_checkpointing=True tokenizer.max_seq_len=2048 batch_size=1 max_steps_per_epoch=10 output_dir=output
import error: No module named 'triton'
W0218 16:24:26.524000 56073 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: Llama-3.1-8B-Instruct-Tuned/
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
device: mps
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: 10
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: output/logs
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: output
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: output/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 2048
  path: Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 182672124. Local seed is seed + rank = 182672124 + 0
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to output/logs/log_1739885066.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
CPU peak memory allocation: 16.01 GiB
CPU peak memory reserved: 16.01 GiB
CPU peak memory active: 15.10 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|10|Loss: 1.6531014442443848: 100%|█████████████████████████████████████████████████████████████| 10/10 [01:37<00:00, 10.07s/it]
INFO:torchtune.utils._logging:Starting checkpoint save...
INFO:torchtune.utils._logging:Model checkpoint of size 4.63 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00001-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00002-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 4.58 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00003-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 1.09 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00004-of-00004.safetensors
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.04 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/adapter_model.pt
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.04 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/adapter_model.safetensors
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/adapter_config.json
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Checkpoint saved in 44.84 seconds.
1|10|Loss: 1.6531014442443848: 100%|█████████████████████████████████████████████████████████████| 10/10 [02:22<00:00, 14.23s/it]

Logs

Step 1 | loss:1.8555225133895874 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:152.13568115234375 peak_memory_active:15.219014883041382 peak_memory_alloc:16.401046752929688 peak_memory_reserved:16.401046752929688
Step 2 | loss:1.5935295820236206 lr:5.999999999999999e-06 tokens_per_second_per_gpu:166.90853881835938 peak_memory_active:15.219014883041382 peak_memory_alloc:16.419021606445312 peak_memory_reserved:16.419021606445312
Step 3 | loss:1.565019965171814 lr:8.999999999999999e-06 tokens_per_second_per_gpu:179.53611755371094 peak_memory_active:15.21902084350586 peak_memory_alloc:17.472152709960938 peak_memory_reserved:17.472152709960938
Step 4 | loss:1.6224077939987183 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:169.76544189453125 peak_memory_active:15.21902084350586 peak_memory_alloc:17.482131958007812 peak_memory_reserved:17.482131958007812
Step 5 | loss:1.6456714868545532 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:169.54025268554688 peak_memory_active:15.219014883041382 peak_memory_alloc:17.505889892578125 peak_memory_reserved:17.505889892578125
Step 6 | loss:1.6735244989395142 lr:1.7999999999999997e-05 tokens_per_second_per_gpu:167.10714721679688 peak_memory_active:15.21902084350586 peak_memory_alloc:17.484664916992188 peak_memory_reserved:17.484664916992188
Step 7 | loss:2.0274529457092285 lr:2.1e-05 tokens_per_second_per_gpu:156.9060516357422 peak_memory_active:15.219014883041382 peak_memory_alloc:17.517990112304688 peak_memory_reserved:17.517990112304688
Step 8 | loss:1.623584270477295 lr:2.3999999999999997e-05 tokens_per_second_per_gpu:160.64462280273438 peak_memory_active:15.21902084350586 peak_memory_alloc:17.541717529296875 peak_memory_reserved:17.541717529296875
Step 9 | loss:1.7031704187393188 lr:2.6999999999999996e-05 tokens_per_second_per_gpu:153.1827392578125 peak_memory_active:15.219014883041382 peak_memory_alloc:17.541839599609375 peak_memory_reserved:17.541839599609375
Step 10 | loss:1.6531014442443848 lr:2.9999999999999997e-05 tokens_per_second_per_gpu:165.5233917236328 peak_memory_active:15.21902084350586 peak_memory_alloc:17.536239624023438 peak_memory_reserved:17.536239624023438
Full-Weight: Terminal Logs

(venv) musab@musab torchtune % tune run full_finetune_single_device --config recipes/configs/llama3_1/8B_full_single_device.yaml checkpointer.checkpoint_dir="Meta-Llama-3.1-8B-Instruct" tokenizer.path="Meta-Llama-3.1-8B-Instruct/original/tokenizer.model" checkpointer.output_dir="Llama-3.1-8B-Instruct-Tuned/" optimizer._component_=torch.optim.AdamW optimizer.fused=True device=mps log_peak_memory_stats=True enable_activation_checkpointing=True tokenizer.max_seq_len=2048 batch_size=1 max_steps_per_epoch=10 output_dir=output
import error: No module named 'triton'
W0218 16:38:14.067000 59075 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils._logging:Running FullFinetuneRecipeSingleDevice with resolved config:
batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: Llama-3.1-8B-Instruct-Tuned/
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: false
device: mps
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 10
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: output/logs
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
optimizer_in_bwd: true
output_dir: output
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: output/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 2048
  path: Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3054736935. Local seed is seed + rank = 3054736935 + 0
Writing logs to output/logs/log_1739885894.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
CPU peak memory allocation: 15.98 GiB
CPU peak memory reserved: 15.98 GiB
CPU peak memory active: 15.07 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:In-backward optimizers are set up.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:No learning rate scheduler configured. Using constant learning rate.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|10|Loss: 0.6701164841651917: 100%|█████████████████████████████████████████████████████████████| 10/10 [00:21<00:00, 2.09s/it]
INFO:torchtune.utils._logging:Model checkpoint of size 4.63 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00001-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00002-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 4.58 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00003-of-00004.safetensors
INFO:torchtune.utils._logging:Model checkpoint of size 1.09 GiB saved to Llama-3.1-8B-Instruct-Tuned/epoch_0/model-00004-of-00004.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.
1|10|Loss: 0.6701164841651917: 100%|█████████████████████████████████████████████████████████████| 10/10 [00:27<00:00, 2.71s/it]

Logs

Step 1 | loss:2.4555773735046387 lr:2e-05 tokens_per_second_per_gpu:46.73149108886719 peak_memory_active:47.32469987869263 peak_memory_alloc:52.001708984375 peak_memory_reserved:52.001708984375
Step 2 | loss:1.4622046947479248 lr:2e-05 tokens_per_second_per_gpu:45.71133041381836 peak_memory_active:47.32469987869263 peak_memory_alloc:53.99104309082031 peak_memory_reserved:53.99104309082031
Step 3 | loss:1.7121436595916748 lr:2e-05 tokens_per_second_per_gpu:45.4216423034668 peak_memory_active:47.32469987869263 peak_memory_alloc:54.96961975097656 peak_memory_reserved:54.96961975097656
Step 4 | loss:1.1561416387557983 lr:2e-05 tokens_per_second_per_gpu:40.134742736816406 peak_memory_active:47.32469987869263 peak_memory_alloc:55.948211669921875 peak_memory_reserved:55.948211669921875
Step 5 | loss:1.0250656604766846 lr:2e-05 tokens_per_second_per_gpu:23.665254592895508 peak_memory_active:47.32469940185547 peak_memory_alloc:56.93070983886719 peak_memory_reserved:56.93070983886719
Step 6 | loss:1.046566128730774 lr:2e-05 tokens_per_second_per_gpu:55.77537155151367 peak_memory_active:47.32488775253296 peak_memory_alloc:57.90928649902344 peak_memory_reserved:57.90928649902344
Step 7 | loss:0.554315447807312 lr:2e-05 tokens_per_second_per_gpu:38.398162841796875 peak_memory_active:47.32469987869263 peak_memory_alloc:58.88798522949219 peak_memory_reserved:58.88798522949219
Step 8 | loss:1.5436396598815918 lr:2e-05 tokens_per_second_per_gpu:39.18619918823242 peak_memory_active:47.32469987869263 peak_memory_alloc:59.866546630859375 peak_memory_reserved:59.866546630859375
Step 9 | loss:0.8239040970802307 lr:2e-05 tokens_per_second_per_gpu:47.04241943359375 peak_memory_active:47.32469987869263 peak_memory_alloc:59.862640380859375 peak_memory_reserved:59.862640380859375
Step 10 | loss:0.6701164841651917 lr:2e-05 tokens_per_second_per_gpu:41.19228744506836 peak_memory_active:47.32469987869263 peak_memory_alloc:60.84120178222656 peak_memory_reserved:60.84120178222656

M3 Max 128GB
joecummings
left a comment
🚀
        Defaults to "Memory stats after model init:"
    """
    device_support = get_device_support()
    _log.info(
Let's loop over the dictionary items so we log the key and the value. This way, we can omit the peak_memory_reserved if it doesn't exist at all.
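For illustration, a minimal, self-contained sketch of the suggested loop (the real torchtune helper also uses get_device_support() from the diff above, which is omitted here; the stats dict contents are hypothetical):

```python
import logging

_log = logging.getLogger(__name__)


def log_memory_stats(
    stats: dict, message: str = "Memory stats after model init:"
) -> None:
    # Iterate over whatever keys the backend actually reported, so a stat a
    # device never produces (e.g. peak_memory_reserved on MPS) is simply
    # never printed, instead of raising a KeyError or logging a dummy value.
    lines = [f"\t{name}: {value:.2f} GiB" for name, value in stats.items()]
    _log.info(message + "\n" + "\n".join(lines))


# Hypothetical MPS stats dict with no "reserved" entry:
log_memory_stats({"CPU peak memory active": 2.45, "CPU peak memory alloc": 3.10})
```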
Done. New output for MPS:
INFO:torchtune.utils._logging:Memory stats after model init:
CPU peak memory active: 2.45 GiB
CPU peak memory alloc: 3.10 GiB
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2406       +/-   ##
===========================================
- Coverage   63.87%   23.47%    -40.41%
===========================================
  Files         368      373         +5
  Lines       21873    22406       +533
===========================================
- Hits        13971     5259      -8712
- Misses       7902    17147      +9245

☔ View full report in Codecov by Sentry.
Context
When device is mps, training.get_memory_stats was causing training to crash, since memory stats were only available for CUDA.

Changelog
Added MPS memory logging support via:
https://pytorch.org/docs/stable/generated/torch.mps.current_allocated_memory.html
https://pytorch.org/docs/stable/generated/torch.mps.driver_allocated_memory.html
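For reference, a hedged sketch of how those two APIs can populate a stats dict (the key names follow the metric logs above; the exact mapping used in this PR is an assumption):

```python
import torch

GIB = 1024**3


def get_mps_memory_stats() -> dict:
    # torch.mps.current_allocated_memory(): bytes occupied by live MPS tensors.
    # torch.mps.driver_allocated_memory(): total bytes the Metal driver has
    # allocated for this process, including cached blocks on top of live tensors.
    # Both report *current* values, so a recipe that wants peaks has to sample
    # them at the points of interest (e.g. right after model init).
    return {
        "peak_memory_active": torch.mps.current_allocated_memory() / GIB,
        "peak_memory_alloc": torch.mps.driver_allocated_memory() / GIB,
    }
```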
Test plan
Tested manually on a MacBook Pro (M3 Max, 128 GB) with device=mps; terminal and metric logs are included in the conversation above.
Torch 2.6.0
Python 3.13.2
UX
I did not change any public API.
Question
Should we omit peak_memory_reserved for MPS, or should we set it equal to peak_memory_alloc? For concreteness, the two options look roughly like the sketch below (hypothetical stats dict, key names as logged above):
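```python
# Hypothetical stats dict, using the key names from the metric logs above.
stats = {"peak_memory_active": 15.22, "peak_memory_alloc": 17.54}

# Option 1: omit peak_memory_reserved on MPS entirely; the loop-based
# logging discussed above would then simply never print it.

# Option 2: alias it to peak_memory_alloc so downstream tooling sees a
# stable set of keys across CUDA and MPS.
stats["peak_memory_reserved"] = stats["peak_memory_alloc"]
```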