-
Notifications
You must be signed in to change notification settings - Fork 2.7k
save cpu mem by leveraging FSDP rank0 broadcasting #77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
d8a81bb
save cpu mem by leveraging FSDP rank0 broadcasting
lchu6 c8d4f38
replace init_empty_weights with torch.device(meta)
lchu6 101391f
Revert "replace init_empty_weights with torch.device(meta)"
lchu6 1e64fc9
switch to simpler param_init_fn and meta device init
lchu6 895dfce
add nightly check for using low_cpu_fsdp mode
lchu6 e216c6f
address #87
lchu6 c19c5c6
fix fsdp construction on low_cpu_fsdp
lchu6 0c51b47
fix #90
lchu6 80a4c36
further fix #90
lchu6 c453b66
add doc example about using low_cpu_fsdp
lchu6 1cc9df1
remove unused import
lchu6 41ffbca
code cleanup to remove all unused imports
lchu6 3d1e9cd
minor code optimization
lchu6 feaa344
resolve conflicts
lchu6 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| from transformers import ( | ||
| LlamaForCausalLM, | ||
| LlamaTokenizer, | ||
| LlamaConfig, | ||
| AutoModelForCausalLM, | ||
| AutoModelForSeq2SeqLM, | ||
| AutoTokenizer, | ||
|
|
@@ -41,6 +42,8 @@ | |
| get_policies | ||
| ) | ||
|
|
||
| from accelerate import init_empty_weights | ||
|
|
||
| from utils.dataset_utils import get_preprocessed_dataset | ||
|
|
||
| from utils.config_utils import ( | ||
|
|
@@ -62,8 +65,10 @@ | |
| from torch.optim.lr_scheduler import StepLR | ||
| from pkg_resources import packaging | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.cuda.nccl as nccl | ||
| import torch.distributed as dist | ||
| from torch.distributed.fsdp._common_utils import _is_fsdp_flattened | ||
| from transformers.models.llama.modeling_llama import LlamaDecoderLayer | ||
|
|
||
|
|
||
|
|
@@ -90,11 +95,26 @@ def main(**kwargs): | |
| gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size | ||
|
|
||
| # Load the pre-trained model and setup its configuration | ||
| model = LlamaForCausalLM.from_pretrained( | ||
| train_config.model_name, | ||
| load_in_8bit=True if train_config.quantization else None, | ||
| device_map="auto" if train_config.quantization else None, | ||
| ) | ||
| if train_config.enable_fsdp: | ||
lchu6 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # for FSDP, we save cpu memory by loading pretrained model on rank0 only. | ||
| # this avoids cpu oom when loading large models like llama 70B, in which case | ||
| # model alone would consume 2+TB cpu mem (70 * 4 * 8) | ||
| if rank == 0: | ||
| model = LlamaForCausalLM.from_pretrained( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we figure out why torch.device("meta") init doesn't work here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rohan-varma for non-0 ranks, we are using |
||
| train_config.model_name, | ||
| load_in_8bit=True if train_config.quantization else None, | ||
| device_map="auto" if train_config.quantization else None, | ||
| ) | ||
| else: | ||
| llama_config = LlamaConfig.from_pretrained(train_config.model_name) | ||
| with init_empty_weights(): | ||
lchu6 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| model = LlamaForCausalLM(llama_config) | ||
| else: | ||
| model = LlamaForCausalLM.from_pretrained( | ||
| train_config.model_name, | ||
| load_in_8bit=True if train_config.quantization else None, | ||
| device_map="auto" if train_config.quantization else None, | ||
| ) | ||
|
|
||
| print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) | ||
|
|
||
|
|
@@ -127,14 +147,29 @@ def main(**kwargs): | |
|
|
||
| mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) | ||
| my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) | ||
|
|
||
|
|
||
| # given the fast evolving PRs around meta device init, I am not sure | ||
| # what is the best param_init_fn atm, maybe we can switch to simple to_emtpy. | ||
lchu6 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def _param_init_fn(module: nn.Module): | ||
| torch.manual_seed(0) | ||
| for submodule in module.modules(): | ||
| for param_name, param in submodule.named_parameters(recurse=False): | ||
| if not _is_fsdp_flattened(param) and param.is_meta: | ||
| materialized_param = nn.Parameter( | ||
| torch.empty_like(param, device=torch.device("cuda")) | ||
| ) | ||
| nn.init.uniform_(materialized_param) | ||
| setattr(submodule, param_name, materialized_param) | ||
|
|
||
| model = FSDP( | ||
| model, | ||
| auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, | ||
| mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, | ||
| sharding_strategy=fsdp_config.sharding_strategy, | ||
| device_id=torch.cuda.current_device(), | ||
| limit_all_gathers=True, | ||
| sync_module_states=True, | ||
lchu6 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| param_init_fn=None if rank == 0 else _param_init_fn, | ||
| ) | ||
| if fsdp_config.fsdp_activation_checkpointing: | ||
| policies.apply_fsdp_checkpointing(model) | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.