From d8a81bb531a9bb5d72f1b9bdb5dbbe830ea38759 Mon Sep 17 00:00:00 2001 From: lchu Date: Tue, 1 Aug 2023 00:13:03 -0400 Subject: [PATCH 01/13] save cpu mem by leveraging FSDP rank0 broadcasting --- llama_finetuning.py | 47 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index ccf8c6845..fcf08ee2f 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -17,6 +17,7 @@ from transformers import ( LlamaForCausalLM, LlamaTokenizer, + LlamaConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, @@ -41,6 +42,8 @@ get_policies ) +from accelerate import init_empty_weights + from utils.dataset_utils import get_preprocessed_dataset from utils.config_utils import ( @@ -62,8 +65,10 @@ from torch.optim.lr_scheduler import StepLR from pkg_resources import packaging import torch +import torch.nn as nn import torch.cuda.nccl as nccl import torch.distributed as dist +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from transformers.models.llama.modeling_llama import LlamaDecoderLayer @@ -90,11 +95,26 @@ def main(**kwargs): gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size # Load the pre-trained model and setup its configuration - model = LlamaForCausalLM.from_pretrained( - train_config.model_name, - load_in_8bit=True if train_config.quantization else None, - device_map="auto" if train_config.quantization else None, - ) + if train_config.enable_fsdp: + # for FSDP, we save cpu memory by loading pretrained model on rank0 only. + # this avoids cpu oom when loading large models like llama 70B, in which case + # model alone would consume 2+TB cpu mem (70 * 4 * 8) + if rank == 0: + model = LlamaForCausalLM.from_pretrained( + train_config.model_name, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + ) + else: + llama_config = LlamaConfig.from_pretrained(train_config.model_name) + with init_empty_weights(): + model = LlamaForCausalLM(llama_config) + else: + model = LlamaForCausalLM.from_pretrained( + train_config.model_name, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + ) print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) @@ -127,7 +147,20 @@ def main(**kwargs): mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) - + + # given the fast evolving PRs around meta device init, I am not sure + # what is the best param_init_fn atm, maybe we can switch to simple to_emtpy. + def _param_init_fn(module: nn.Module): + torch.manual_seed(0) + for submodule in module.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if not _is_fsdp_flattened(param) and param.is_meta: + materialized_param = nn.Parameter( + torch.empty_like(param, device=torch.device("cuda")) + ) + nn.init.uniform_(materialized_param) + setattr(submodule, param_name, materialized_param) + model = FSDP( model, auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, @@ -135,6 +168,8 @@ def main(**kwargs): sharding_strategy=fsdp_config.sharding_strategy, device_id=torch.cuda.current_device(), limit_all_gathers=True, + sync_module_states=True, + param_init_fn=None if rank == 0 else _param_init_fn, ) if fsdp_config.fsdp_activation_checkpointing: policies.apply_fsdp_checkpointing(model) From c8d4f38d2330e14288b5dd882d0a275d01daa86c Mon Sep 17 00:00:00 2001 From: lchu Date: Tue, 1 Aug 2023 01:24:44 -0400 Subject: [PATCH 02/13] replace init_empty_weights with torch.device(meta) --- llama_finetuning.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index fcf08ee2f..d2fc95af8 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -42,8 +42,6 @@ get_policies ) -from accelerate import init_empty_weights - from utils.dataset_utils import get_preprocessed_dataset from utils.config_utils import ( @@ -107,7 +105,7 @@ def main(**kwargs): ) else: llama_config = LlamaConfig.from_pretrained(train_config.model_name) - with init_empty_weights(): + with torch.device("meta"): model = LlamaForCausalLM(llama_config) else: model = LlamaForCausalLM.from_pretrained( From 101391f46a05bdaf03a4dc696af224c994b5bad3 Mon Sep 17 00:00:00 2001 From: lchu Date: Tue, 1 Aug 2023 01:35:49 -0400 Subject: [PATCH 03/13] Revert "replace init_empty_weights with torch.device(meta)" This reverts commit c8d4f38d2330e14288b5dd882d0a275d01daa86c. --- llama_finetuning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index d2fc95af8..fcf08ee2f 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -42,6 +42,8 @@ get_policies ) +from accelerate import init_empty_weights + from utils.dataset_utils import get_preprocessed_dataset from utils.config_utils import ( @@ -105,7 +107,7 @@ def main(**kwargs): ) else: llama_config = LlamaConfig.from_pretrained(train_config.model_name) - with torch.device("meta"): + with init_empty_weights(): model = LlamaForCausalLM(llama_config) else: model = LlamaForCausalLM.from_pretrained( From 1e64fc98d9296b28b7d34b11c28498f50af69e8b Mon Sep 17 00:00:00 2001 From: lchu Date: Tue, 1 Aug 2023 12:33:24 -0400 Subject: [PATCH 04/13] switch to simpler param_init_fn and meta device init --- llama_finetuning.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index fcf08ee2f..0ac1bd572 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -42,8 +42,6 @@ get_policies ) -from accelerate import init_empty_weights - from utils.dataset_utils import get_preprocessed_dataset from utils.config_utils import ( @@ -107,7 +105,7 @@ def main(**kwargs): ) else: llama_config = LlamaConfig.from_pretrained(train_config.model_name) - with init_empty_weights(): + with torch.device("meta"): model = LlamaForCausalLM(llama_config) else: model = LlamaForCausalLM.from_pretrained( @@ -148,19 +146,6 @@ def main(**kwargs): mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) - # given the fast evolving PRs around meta device init, I am not sure - # what is the best param_init_fn atm, maybe we can switch to simple to_emtpy. - def _param_init_fn(module: nn.Module): - torch.manual_seed(0) - for submodule in module.modules(): - for param_name, param in submodule.named_parameters(recurse=False): - if not _is_fsdp_flattened(param) and param.is_meta: - materialized_param = nn.Parameter( - torch.empty_like(param, device=torch.device("cuda")) - ) - nn.init.uniform_(materialized_param) - setattr(submodule, param_name, materialized_param) - model = FSDP( model, auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, @@ -169,7 +154,7 @@ def _param_init_fn(module: nn.Module): device_id=torch.cuda.current_device(), limit_all_gathers=True, sync_module_states=True, - param_init_fn=None if rank == 0 else _param_init_fn, + param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False), ) if fsdp_config.fsdp_activation_checkpointing: policies.apply_fsdp_checkpointing(model) From 895dfcea30570bf683c3fb34bfb2ddf640156e46 Mon Sep 17 00:00:00 2001 From: lchu Date: Wed, 2 Aug 2023 18:32:31 -0400 Subject: [PATCH 05/13] add nightly check for using low_cpu_fsdp mode --- configs/training.py | 3 ++- llama_finetuning.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/configs/training.py b/configs/training.py index 4c50372da..8b3bfdb54 100644 --- a/configs/training.py +++ b/configs/training.py @@ -7,7 +7,8 @@ @dataclass class train_config: model_name: str="PATH/to/LLAMA/7B" - enable_fsdp: bool= False + enable_fsdp: bool=False + low_cpu_fsdp: bool=False run_validation: bool=True batch_size_training: int=4 num_epochs: int=3 diff --git a/llama_finetuning.py b/llama_finetuning.py index 0ac1bd572..d79c7ab0c 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -93,10 +93,16 @@ def main(**kwargs): gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size # Load the pre-trained model and setup its configuration - if train_config.enable_fsdp: - # for FSDP, we save cpu memory by loading pretrained model on rank0 only. + if train_config.enable_fsdp and train_config.low_cpu_fsdp: + # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8) + # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms + # overhead and currently requires latest nightly. + v = packaging.version.parse(torch.__version__) + verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 + if not verify_latest_nightly: + raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " + "please install latest nightly.") if rank == 0: model = LlamaForCausalLM.from_pretrained( train_config.model_name, From e216c6f1f3ba5417d06587bc29ac34272932f04d Mon Sep 17 00:00:00 2001 From: lchu Date: Wed, 2 Aug 2023 21:09:15 -0400 Subject: [PATCH 06/13] address #87 --- model_checkpointing/checkpoint_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py index 5bb8ff111..b097df97d 100644 --- a/model_checkpointing/checkpoint_handler.py +++ b/model_checkpointing/checkpoint_handler.py @@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg): reader = FileSystemReader(load_dir) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - checkpoint = model.state_dict() + checkpoint = {"model": model.state_dict()} if rank == 0: ck = checkpoint.keys() print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") @@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg): print(f"checkpoint after load_state_dict()") ck = checkpoint.keys() print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - model.load_state_dict(checkpoint) + model.load_state_dict(checkpoint["model"]) if rank == 0: print(f"Sharded state checkpoint loaded from {load_dir}") From c19c5c69aa4f2b0eb03119ef64c3aec164595ab9 Mon Sep 17 00:00:00 2001 From: lchu Date: Thu, 3 Aug 2023 10:38:31 -0400 Subject: [PATCH 07/13] fix fsdp construction on low_cpu_fsdp --- llama_finetuning.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index d79c7ab0c..3c7ee773f 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -39,7 +39,7 @@ clear_gpu_cache, get_parameter_dtypes, print_model_size, - get_policies + get_policies ) from utils.dataset_utils import get_preprocessed_dataset @@ -88,10 +88,10 @@ def main(**kwargs): if torch.distributed.is_initialized(): torch.cuda.set_device(rank) setup_environ_flags(rank) - + # Calculate gradient accumulation steps gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size - + # Load the pre-trained model and setup its configuration if train_config.enable_fsdp and train_config.low_cpu_fsdp: # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. @@ -113,19 +113,20 @@ def main(**kwargs): llama_config = LlamaConfig.from_pretrained(train_config.model_name) with torch.device("meta"): model = LlamaForCausalLM(llama_config) + else: model = LlamaForCausalLM.from_pretrained( train_config.model_name, load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, ) - + print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) - + # Prepare the model for int8 training if quantization is enabled if train_config.quantization: model = prepare_model_for_int8_training(model) - + # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled if train_config.enable_fsdp and fsdp_config.pure_bf16: model.to(torch.bfloat16) @@ -134,7 +135,7 @@ def main(**kwargs): tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name) tokenizer.add_special_tokens( { - + "pad_token": "", } ) @@ -142,11 +143,11 @@ def main(**kwargs): peft_config = generate_peft_config(train_config, kwargs) model = get_peft_model(model, peft_config) model.print_trainable_parameters() - + #setting up FSDP if enable_fsdp is enabled if train_config.enable_fsdp: if not train_config.use_peft and train_config.freeze_layers: - + freeze_transformer_layers(train_config.num_freeze_layers) mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) @@ -159,8 +160,9 @@ def main(**kwargs): sharding_strategy=fsdp_config.sharding_strategy, device_id=torch.cuda.current_device(), limit_all_gathers=True, - sync_module_states=True, - param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False), + sync_module_states=True if train_config.low_cpu_fsdp else False, + param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) + if train_config.low_cpu_fsdp and rank != 0 else None, ) if fsdp_config.fsdp_activation_checkpointing: policies.apply_fsdp_checkpointing(model) @@ -168,14 +170,14 @@ def main(**kwargs): model.to("cuda") dataset_config = generate_dataset_config(train_config, kwargs) - + # Load and preprocess the dataset for training and validation dataset_train = get_preprocessed_dataset( tokenizer, dataset_config, split="train", ) - + if not train_config.enable_fsdp or rank == 0: print(f"--> Training Set Length = {len(dataset_train)}") @@ -202,7 +204,7 @@ def main(**kwargs): rank=dist.get_rank(), num_replicas=dist.get_world_size(), ) - + # Create DataLoaders for the training and validation dataset train_dataloader = torch.utils.data.DataLoader( dataset_train, @@ -224,7 +226,7 @@ def main(**kwargs): drop_last=True, collate_fn=default_data_collator, ) - + # Initialize the optimizer and learning rate scheduler if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": optimizer = AnyPrecisionAdamW( @@ -246,7 +248,7 @@ def main(**kwargs): results = train( model, train_dataloader, - eval_dataloader, + eval_dataloader, tokenizer, optimizer, scheduler, From 0c51b472627d0f310cbf4e75488fa9d13b11792c Mon Sep 17 00:00:00 2001 From: lchu Date: Thu, 3 Aug 2023 11:06:43 -0400 Subject: [PATCH 08/13] fix #90 --- llama_finetuning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index 3c7ee773f..f5cc9c9f6 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -86,7 +86,8 @@ def main(**kwargs): world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(rank) + torch.cuda.set_device(local_rank) + clear_gpu_cache(local_rank) setup_environ_flags(rank) # Calculate gradient accumulation steps From 80a4c367070501c3fba6ffe6a6f918ca988ddec7 Mon Sep 17 00:00:00 2001 From: lchu Date: Thu, 3 Aug 2023 14:28:20 -0400 Subject: [PATCH 09/13] further fix #90 --- utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/train_utils.py b/utils/train_utils.py index 3fa4c0cf1..a21e799ce 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -135,7 +135,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche lr_scheduler.step() if train_config.run_validation: - eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer) + eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) if train_config.save_model and eval_epoch_loss < best_val_loss: if train_config.enable_fsdp: dist.barrier() From c453b668faac8e9ecae3449fcc1f7d56d6ae54c1 Mon Sep 17 00:00:00 2001 From: lchu Date: Thu, 3 Aug 2023 16:28:38 -0400 Subject: [PATCH 10/13] add doc example about using low_cpu_fsdp --- README.md | 10 ++++++++++ docs/mutli_gpu.md | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/README.md b/README.md index ab122be26..6c1a50427 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,16 @@ torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --mode ``` +### Fine-tuning using FSDP on 70B Model + +If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs. + +```bash + +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned + +``` + ### Multi GPU Multi Node: ```bash diff --git a/docs/mutli_gpu.md b/docs/mutli_gpu.md index a4396deea..49bbe72a1 100644 --- a/docs/mutli_gpu.md +++ b/docs/mutli_gpu.md @@ -55,6 +55,16 @@ torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --mode ``` +### Fine-tuning using FSDP on 70B Model + +If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs. + +```bash + +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned + +``` + **Multi GPU multi node**: Here we use a slurm script to schedule a job with slurm over multiple nodes. From 1cc9df19e6bdc19d272fe4a938b8d224789a81c1 Mon Sep 17 00:00:00 2001 From: lchu Date: Sun, 6 Aug 2023 09:44:55 -0400 Subject: [PATCH 11/13] remove unused import --- llama_finetuning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index f5cc9c9f6..defe59030 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -66,7 +66,6 @@ import torch.nn as nn import torch.cuda.nccl as nccl import torch.distributed as dist -from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from transformers.models.llama.modeling_llama import LlamaDecoderLayer From 41ffbcab52e385a38732118ec19d2bdf59c26192 Mon Sep 17 00:00:00 2001 From: lchu Date: Sun, 6 Aug 2023 09:46:53 -0400 Subject: [PATCH 12/13] code cleanup to remove all unused imports --- llama_finetuning.py | 69 +++++++++++++++------------------------------ 1 file changed, 23 insertions(+), 46 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index defe59030..ad33de383 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -2,72 +2,49 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os -import sys -from typing import List, Union import fire import torch -import transformers -from datasets import load_dataset -import os.path as osp -from tqdm import tqdm - -# Unused imports removed -from utils import fsdp_auto_wrap_policy +import torch.distributed as dist +import torch.distributed as dist +import torch.optim as optim +from peft import get_peft_model, prepare_model_for_int8_training +from pkg_resources import packaging +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DistributedSampler from transformers import ( LlamaForCausalLM, LlamaTokenizer, LlamaConfig, - AutoModelForCausalLM, - AutoModelForSeq2SeqLM, - AutoTokenizer, default_data_collator, - BitsAndBytesConfig ) -import torch.distributed as dist +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +import policies +from configs import fsdp_config, train_config +from policies import AnyPrecisionAdamW + +from utils import fsdp_auto_wrap_policy +from utils.config_utils import ( + update_config, + generate_peft_config, + generate_dataset_config, +) +from utils.dataset_utils import get_preprocessed_dataset -# Unused imports removed from utils.train_utils import ( - set_tokenizer_params, train, - evaluation, freeze_transformer_layers, - check_frozen_layers_peft_model, setup, setup_environ_flags, - cleanup, clear_gpu_cache, - get_parameter_dtypes, print_model_size, get_policies ) -from utils.dataset_utils import get_preprocessed_dataset - -from utils.config_utils import ( - update_config, - generate_peft_config, - generate_dataset_config, -) -from peft import get_peft_model, TaskType, prepare_model_for_int8_training -import configs -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - MixedPrecision, -) -from torch.utils.data import DistributedSampler -import policies -from policies import AnyPrecisionAdamW -from configs import fsdp_config, train_config -import torch.optim as optim -from torch.optim.lr_scheduler import StepLR -from pkg_resources import packaging -import torch -import torch.nn as nn -import torch.cuda.nccl as nccl -import torch.distributed as dist -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - def main(**kwargs): # Update the configuration for the training and sharding process From 3d1e9cd58caafe11de8fa26f68961eeb835fad7c Mon Sep 17 00:00:00 2001 From: lchu Date: Tue, 8 Aug 2023 10:39:50 -0400 Subject: [PATCH 13/13] minor code optimization --- llama_finetuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index ad33de383..cf573b020 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -137,7 +137,7 @@ def main(**kwargs): sharding_strategy=fsdp_config.sharding_strategy, device_id=torch.cuda.current_device(), limit_all_gathers=True, - sync_module_states=True if train_config.low_cpu_fsdp else False, + sync_module_states=train_config.low_cpu_fsdp, param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) if train_config.low_cpu_fsdp and rank != 0 else None, )