
AdaLora rank_pattern is None, Docs and Warnings and Integrations should be improved upon. #1327

@kuronekosaiko

Description

System Info

docker.io/pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
Python 3.10.13
pytorch 2.1.2
peft 0.7.1
transformers 4.36.2
bitsandbytes 0.41.3.post2

Who can help?

@pacman100 @younesbelkada @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from mylib import my_custom_dataset
from transformers import GPTQConfig
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)
from peft import (
    prepare_model_for_kbit_training,
    PeftModelForCausalLM,
    AdaLoraConfig
)


# Yi is a llama family model.
model_path = "models/Yi-6B-GPTQ"
dataset = my_custom_dataset()

quantization_config = GPTQConfig(bits=4, use_exllama=False)
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=quantization_config)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

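# AdaLoRA rank schedule: every target module starts at rank init_r; between
# step tinit and step (total_step - tfinal) the rank budget is pruned down
# towards target_r, with the rank mask (rank_pattern) recomputed every deltaT
# steps; the last tfinal steps train with the final, fixed budget.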
config = AdaLoraConfig(
    lora_alpha=8,
    target_r=4,
    init_r=24,
    total_step=20,
    tinit=5,
    tfinal=2,
    deltaT=1,
    lora_dropout=0,
    task_type="CAUSAL_LM",
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "down_proj",
        "up_proj",
        "gate_proj"
    ]
)
model = PeftModelForCausalLM(model, config, "default")

model.config.use_cache = False
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=20,
        warmup_steps=5,
        logging_steps=1,
        save_steps=2,
        learning_rate=2**(-12),
        gradient_checkpointing=True,
        fp16=True,
        fp16_opt_level="O2",
        output_dir="adalora_test5_outputs",
        logging_strategy="steps",
        lr_scheduler_type="constant_with_warmup",
        save_strategy="steps",
        save_total_limit=2,
        optim="paged_adamw_32bit"
    )
)

with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=False,
                                    enable_mem_efficient=True):
    trainer.train()

Expected behavior

rank_pattern in adapter_config.json should not be None/null. For a trained AdaLoRA adapter, something like this is expected: a non-empty rank_pattern.
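The symptom can be checked directly on a saved checkpoint; a minimal sketch (the checkpoint path below is only an example taken from the run above, adjust it to the actual save directory):

import json

# hypothetical checkpoint directory produced by the run above
with open("adalora_test5_outputs/checkpoint-20/adapter_config.json") as f:
    adapter_config = json.load(f)

# expected: a non-empty per-layer rank mapping; observed: null/empty instead
print(adapter_config["rank_pattern"])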

rank_pattern should be updated every deltaT steps, and the code should update the "Adaptive Budget Allocation" mask from this rank_pattern at the same interval.
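As far as I understand, these updates are driven by AdaLoraModel.update_and_allocate, which has to be called explicitly from the training loop; a minimal sketch of a manual loop (optimizer and dataloader are placeholders, and the exact call site follows my reading of the method's docstring):

# manual training loop sketch; `optimizer` and `dataloader` are placeholders
for step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    # updates the importance scores and, every deltaT steps, the rank mask / rank_pattern
    model.base_model.update_and_allocate(step)
    optimizer.zero_grad()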

Since a change ~8 months ago (as shown in this), rank_pattern is no longer correct.

From my testing (code shown below), neither AdaLoraModel.update_and_allocate nor RankAllocator.update_and_allocate is ever called. This is likely the reason for the behaviour above.

......

from peft import AdaLoraModel
from peft.tuners.adalora.layer import RankAllocator

# a function that raises if it is ever called
def raise_anyway(self, *args, **kwargs):
    raise RuntimeError("update_and_allocate was called")

# monkey-patch
AdaLoraModel.update_and_allocate = raise_anyway
RankAllocator.update_and_allocate = raise_anyway

# load the model and train as usual without any errors.
......
