Description
Hi
I have a model based on T5, but wrapped in another class because there are many customizations at the header level for a text classification use case. In short, within the model class I call:
```python
self.t5 = T5Model.from_pretrained('google/flan-t5-xl', from_tf=False, config=self.config, cache_dir=self.cache_dir)
```
Another object within the main model class is:
```python
self.header = MyCustomClass()
```
Since it is a custom model, I looked at the following examples:
- https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb
- https://www.philschmid.de/fine-tune-flan-t5-peft
- https://github.com/huggingface/peft/blob/main/examples/multilayer_perceptron/multilayer_perceptron_lora.ipynb
I then set up LoRA as follows:
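```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],  # adapt the query and value projections of T5 attention
    lora_dropout=0.05,
    modules_to_save=["model.header.____", "model.header.____", ...],  # header module names elided
    bias="none",
)
```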
With this config the number of trainable parameters drops from 2.7B to 17M. That part is great, but GPU memory is a problem: memory usage is about 20 GB right after the model is loaded, and as soon as training starts it goes OOM. My reasoning was that the weights take around 20 GB, and a standalone model with about 20M parameters usually needs around 2 GB of GPU memory to train, so in total training should stay well below 30 GB.
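One way to make this estimate explicit is to split it into weights, gradients, and optimizer state. The sketch below assumes fp32 everywhere, AdamW with two state tensors per trainable parameter, and 1 GB = 1e9 bytes; the function name is hypothetical, and activations are deliberately left out.

```python
# Back-of-the-envelope training memory for a LoRA setup (not a profiler).
# ASSUMPTIONS: fp32 weights (4 bytes/param), AdamW with two fp32 state tensors per
# trainable parameter, 1 GB = 1e9 bytes. Activations are deliberately NOT included.

BYTES_PER_PARAM = 4
GB = 1e9


def estimate_static_memory_gb(total_params: int, trainable_params: int) -> dict:
    weights = total_params * BYTES_PER_PARAM       # frozen + trainable weights
    grads = trainable_params * BYTES_PER_PARAM     # only trainable params get .grad
    adam = 2 * trainable_params * BYTES_PER_PARAM  # exp_avg and exp_avg_sq
    return {
        "weights": weights / GB,
        "grads": grads / GB,
        "adam_states": adam / GB,
        "total_without_activations": (weights + grads + adam) / GB,
    }


est = estimate_static_memory_gb(total_params=2_700_000_000, trainable_params=17_000_000)
for name, value in est.items():
    print(f"{name:>26}: {value:.3f} GB")
```

Under these assumptions, gradients and optimizer state for 17M parameters add only about 0.2 GB on top of the weights, so the term this kind of estimate misses is activation memory. Activations grow with batch size and sequence length, and LoRA does not remove most of them, because the backward pass still has to run through the frozen T5 layers to reach the adapters. The sketch does not try to reproduce the 20 GB observed after loading.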
Here are the key lines showing what happens to the model:
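```python
import torch
from peft import get_peft_model

model = MyModel()  # contains the self.t5 backbone and self.header
model = get_peft_model(model, lora_config)
# All parameters are passed; frozen ones never receive a .grad, so AdamW skips them
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for i, batch in enumerate(train_dataloader):
    outputs = model(batch)
    loss = outputs[0]  # the first output is the loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```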
What am I doing wrong? And that is without even mentioning the XXL model, which according to the second tutorial should fit on an A100.
BTW, in the third tutorial above, the reported number of trainable parameters does not match what I expect:
```
trainable params: 56,164 || all params: 4,100,164 || trainable%: 1.369798866581922
```
The number of parameters that actually get updated in cells 16 and 17 is 16000 x 3 + 160 + 4000 + 2 = 52,162, not 56,164. Does anyone know why?
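The gap can be checked with a few lines of arithmetic. This sketch assumes the notebook's MLP is Linear(20, 2000) -> Linear(2000, 2000) -> Linear(2000, 2), with r=8 LoRA on the first two layers and the last layer in `modules_to_save`; those shapes are consistent with the 16000 x 3 + 160 + 4000 + 2 breakdown, but they are an assumption here, as are the helper names.

```python
# Reconciling the MLP notebook's parameter counts.
# ASSUMED architecture: Linear(20, 2000) -> Linear(2000, 2000) -> Linear(2000, 2),
# LoRA r=8 on the first two layers, last layer listed in modules_to_save.


def linear_params(in_f: int, out_f: int) -> int:
    return in_f * out_f + out_f  # weight + bias


def lora_params(in_f: int, out_f: int, r: int) -> int:
    return r * in_f + out_f * r  # A is (r x in_f), B is (out_f x r)


r = 8
lora_total = lora_params(20, 2000, r) + lora_params(2000, 2000, r)  # 160 + 3 x 16000
head = linear_params(2000, 2)  # 4000 + 2
base = linear_params(20, 2000) + linear_params(2000, 2000) + head

expected_trainable = lora_total + head
reported_trainable, reported_all = 56_164, 4_100_164

print(expected_trainable)                       # 52162
print(reported_trainable - expected_trainable)  # 4002
print(reported_all - (base + lora_total))       # 4002
```

Under these assumptions, both the trainable count and the total count exceed the expected values by exactly 4,002, the size of Linear(2000, 2). That is consistent with the `modules_to_save` layer being counted twice: once as the original module and once as the trainable copy PEFT wraps it with, even though only the copy is updated during training.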
The attached LoRA.py runs easily on a single A100 GPU, though. Any thoughts?
lora.txt