
google/flan-t5-xl - A100 80GB out of memory - LoRA #945

@tomekrut

Description

Hi,
I have a model based on T5, but wrapped in another class, since there are many customizations at the header level for a text classification use case. Within the model class I call

self.t5 = T5Model.from_pretrained(
    'google/flan-t5-xl',
    from_tf=False,
    config=self.config,
    cache_dir=self.cache_dir,
)

Another object within the main model class is

self.header = MyCustomClass()
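
For context, the overall structure looks roughly like this (a simplified sketch only; the Linear head and the forward body stand in for my actual customizations):

import torch.nn as nn
from transformers import T5Model

class MyModel(nn.Module):
    # Simplified sketch of the wrapper; the real custom head (MyCustomClass)
    # is more involved, a single Linear stands in for it here.
    def __init__(self, config, cache_dir=None, num_labels=2):
        super().__init__()
        self.config = config
        self.cache_dir = cache_dir
        self.t5 = T5Model.from_pretrained(
            'google/flan-t5-xl',
            from_tf=False,
            config=self.config,
            cache_dir=self.cache_dir,
        )
        self.header = nn.Linear(self.config.d_model, num_labels)  # placeholder for MyCustomClass()

    def forward(self, batch):
        # Encoder pass plus the custom head; the real forward does more than this.
        enc = self.t5.encoder(input_ids=batch["input_ids"],
                              attention_mask=batch["attention_mask"])
        logits = self.header(enc.last_hidden_state[:, 0, :])
        loss = nn.functional.cross_entropy(logits, batch["labels"])
        return loss, logits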

As it is a custom model I looked in the following places

  1. https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb
  2. https://www.philschmid.de/fine-tune-flan-t5-peft
  3. https://github.com/huggingface/peft/blob/main/examples/multilayer_perceptron/multilayer_perceptron_lora.ipynb

When I call the following:

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    modules_to_save=["model.header.____", "model.header.____", ...],
    bias="none",
)

With this I get the expected reduction in trainable parameters, from 2.7B down to 17M. So far so good, but GPU memory is the problem: loading the model takes an initial ~20GB, and as soon as training starts it goes OOM. My reasoning is that if the weights take around 20GB, and training any ~20M-parameter model usually needs on the order of 2GB of GPU memory, then the whole training run should stay well below 30GB in total.
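
My back-of-the-envelope estimate looks like this (assuming fp32 weights, gradients and AdamW states only for the trainable parameters, and ignoring activation memory entirely):

# Rough estimate, assuming fp32 everywhere, gradients and AdamW moment
# buffers only for the trainable (LoRA + head) parameters, and no
# activation memory accounted for.
total_params     = 2_700_000_000   # frozen base model
trainable_params = 17_000_000      # LoRA adapters + custom head

GB = 1024 ** 3
weights = total_params * 4 / GB                 # ~10 GB (I measure ~20 GB after loading)
grads   = trainable_params * 4 / GB             # ~0.06 GB
adamw   = trainable_params * 2 * 4 / GB         # ~0.13 GB (exp_avg + exp_avg_sq)
print(f"{weights + grads + adamw:.1f} GB")      # nowhere near 80 GB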

Here are the key lines of what I do with the model:

model = MyModel()  # inside there is the self.t5 object
model = get_peft_model(model, lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
for i, batch in enumerate(train_dataloader):
    outputs = model(batch)
    loss = outputs[0]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
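
To see where the memory actually goes I can instrument the loop like this (just a diagnostic sketch, not part of my training code):

# Diagnostic sketch: track peak GPU memory around the forward and backward
# passes to see which step blows up.
import torch

torch.cuda.reset_peak_memory_stats()
outputs = model(batch)
print(f"after forward:  {torch.cuda.max_memory_allocated() / 1024**3:.1f} GB")

loss = outputs[0]
loss.backward()
print(f"after backward: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GB")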

What am I doing wrong? I'm not even mentioning the XXL model, which, based on tutorial 2 above, should also fit on a single A100.

BTW, in the third tutorial above the reported count of trainable parameters does not match what I expect:

trainable params: 56,164 || all params: 4,100,164 || trainable%: 1.369798866581922

but adding up the updated parameters from cells 16 and 17 gives 16,000 x 3 + 160 + 4,000 + 2 = 52,162 instead of 56,164. Does anyone know why?
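
For reference, this is how I count them (a small sketch iterating over the PEFT model's parameters):

# List every trainable tensor in the PEFT-wrapped model so the per-layer
# counts can be reconciled with the reported total.
total = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape), param.numel())
        total += param.numel()
print("trainable params:", total)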

The attached LoRA.py runs easily on a single A100 GPU, though. Any thoughts?
lora.txt
