Merged
47 changes: 41 additions & 6 deletions llama_finetuning.py
@@ -17,6 +17,7 @@
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
LlamaConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -41,6 +42,8 @@
get_policies
)

from accelerate import init_empty_weights

from utils.dataset_utils import get_preprocessed_dataset

from utils.config_utils import (
@@ -62,8 +65,10 @@
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch
import torch.nn as nn
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


@@ -90,11 +95,26 @@ def main(**kwargs):
gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

# Load the pre-trained model and set up its configuration
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
if train_config.enable_fsdp:
# For FSDP, we save CPU memory by loading the pretrained model on rank 0 only.
# This avoids CPU OOM when loading large models such as Llama 70B, where the
# model alone would consume 2+ TB of CPU memory (70B params * 4 bytes * 8 ranks).
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
Contributor:
can we figure out why torch.device("meta") init doesn't work here?

Contributor Author:
@rohan-varma for non-0 ranks, we are using torch.device("meta") init.

train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
else:
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
with init_empty_weights():
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
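
For reference, here is a minimal standalone sketch (not part of the diff) of what the rank-0-only loading buys: under accelerate's init_empty_weights, non-zero ranks build the module graph with parameters on the meta device and therefore allocate essentially no CPU memory. The checkpoint path below is a placeholder.

from accelerate import init_empty_weights
from transformers import LlamaConfig, LlamaForCausalLM

model_name = "path/to/llama-checkpoint"  # placeholder path, not from the PR

# Build the model "shell" without allocating parameter storage.
llama_config = LlamaConfig.from_pretrained(model_name)
with init_empty_weights():
    shell = LlamaForCausalLM(llama_config)

# Parameters have the right shapes and dtypes but live on the meta device,
# so a 70B model costs ~0 bytes here instead of ~280 GB per rank in fp32.
assert all(p.is_meta for p in shell.parameters())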

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

@@ -127,14 +147,29 @@ def main(**kwargs):

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)


# Given the fast-evolving PRs around meta-device init, I am not sure what the best
# param_init_fn is at the moment; maybe we can switch to a simple to_empty().
def _param_init_fn(module: nn.Module):
torch.manual_seed(0)
for submodule in module.modules():
for param_name, param in submodule.named_parameters(recurse=False):
if not _is_fsdp_flattened(param) and param.is_meta:
materialized_param = nn.Parameter(
torch.empty_like(param, device=torch.device("cuda"))
)
nn.init.uniform_(materialized_param)
setattr(submodule, param_name, materialized_param)
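
As the comment above suggests, a simpler param_init_fn could lean on nn.Module.to_empty(). The following is a hypothetical alternative sketch, not what the PR ships, and it assumes a PyTorch version whose to_empty() accepts the recurse keyword:

# Hypothetical alternative to _param_init_fn: materialize meta parameters on the GPU
# without giving them meaningful values, relying on sync_module_states=True below to
# overwrite them with rank 0's pretrained weights.
def _to_empty_init_fn(module: nn.Module):
    # to_empty() swaps meta tensors for uninitialized storage on the target device.
    module.to_empty(device=torch.device("cuda"), recurse=False)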

model = FSDP(
model,
auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=True,
param_init_fn=None if rank == 0 else _param_init_fn,
)
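# Note: sync_module_states=True makes FSDP broadcast rank 0's parameters and buffers
# to the other ranks during wrapping, so the placeholder values that _param_init_fn
# materializes on non-zero ranks are replaced by the pretrained weights before training.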
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)