System Info
docker.io/pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
Python 3.10.13
pytorch 2.1.2
peft 0.7.1
transformers 4.36.2
bitsandbytes 0.41.3.post2
Who can help?
@pacman100 @younesbelkada @sayakpaul
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
from mylib import my_custom_dataset
import torch
from transformers import GPTQConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)
from peft import (
    prepare_model_for_kbit_training,
    PeftModelForCausalLM,
    AdaLoraConfig
)

# Yi is a llama-family model.
model_path = "models/Yi-6B-GPTQ"
dataset = my_custom_dataset()

# Load the GPTQ-quantized model (exllama kernels disabled so the model is trainable).
quantization_config = GPTQConfig(bits=4, use_exllama=False)
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=quantization_config)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

config = AdaLoraConfig(
    lora_alpha=8,
    target_r=4,
    init_r=24,
    total_step=20,
    tinit=5,
    tfinal=2,
    deltaT=1,
    lora_dropout=0,
    task_type="CAUSAL_LM",
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "down_proj",
        "up_proj",
        "gate_proj"
    ]
)
model = PeftModelForCausalLM(model, config, "default")
model.config.use_cache = False

trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=20,
        warmup_steps=5,
        logging_steps=1,
        save_steps=2,
        learning_rate=2**(-12),
        gradient_checkpointing=True,
        fp16=True,
        fp16_opt_level="O2",
        output_dir="adalora_test5_outputs",
        logging_strategy="steps",
        lr_scheduler_type="constant_with_warmup",
        save_strategy="steps",
        save_total_limit=2,
        optim="paged_adamw_32bit"
    )
)

# Restrict SDPA to the memory-efficient backend.
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=False,
                                    enable_mem_efficient=True):
    trainer.train()
Expected behavior
rank_pattern in adapter_config.json should not be None/null: for a trained AdaLoRA adapter, a non-empty rank_pattern is expected, something like this.
rank_pattern should be updated every deltaT steps, and the code should use it to update the "Adaptive Budget Allocation" mask at the same cadence; a minimal sketch of the call pattern I expected follows.
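For reference, this is how I understand update_and_allocate is meant to be driven, per the AdaLoRA docstrings. It is only a sketch with a hand-written loop; dataloader and optimizer are placeholder names, not part of my reproduction script:
# Minimal sketch, not my actual script: `model` is the PeftModelForCausalLM
# built above; `dataloader` and `optimizer` are placeholders.
for global_step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    # AdaLoraModel.update_and_allocate updates the sensitivity-based
    # importance scores each step and, between tinit and
    # total_step - tfinal, reallocates the rank budget every deltaT
    # steps -- which is what should populate rank_pattern.
    model.base_model.update_and_allocate(global_step)
    optimizer.zero_grad()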
Since ~8 months ago, as shown in this, rank_pattern is no longer correct.
From my testing (code shown below), neither AdaLoraModel.update_and_allocate nor RankAllocator.update_and_allocate is ever called. This is likely the cause of the behaviour above.
......
from peft import AdaLoraModel
from peft.tuners.adalora.layer import RankAllocator

# A function that raises if it is ever called.
def raise_anyway(self, *args, **kwargs):
    raise RuntimeError("update_and_allocate was called")

# Monkey-patch both implementations.
AdaLoraModel.update_and_allocate = raise_anyway
RankAllocator.update_and_allocate = raise_anyway

# Load the model and train as usual: no exception is ever raised.
......
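In case it helps anyone hitting this: a possible workaround (my own sketch, not a confirmed fix; the AdaLoraTrainer name is hypothetical) is to subclass Trainer so the allocator is driven manually:
from transformers import Trainer

class AdaLoraTrainer(Trainer):
    # Hypothetical subclass: drive update_and_allocate after each backward
    # pass. Caveat: training_step runs once per micro-batch, so with
    # gradient_accumulation_steps > 1 this fires more often than once per
    # optimizer step; a proper fix would respect the accumulation boundary.
    def training_step(self, model, inputs):
        loss = super().training_step(model, inputs)
        model.base_model.update_and_allocate(self.state.global_step)
        return loss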