
Is gradient checkpointing really needed when fine-tuning LED on 16384 tokens? #16541

@caesar-one

Description


Hi,

The documentation of LED states:

To fine-tune LED on all 16384 tokens, it is necessary to enable gradient checkpointing by executing model.gradient_checkpointing_enable().

Moreover, @patrickvonplaten in this notebook initializes the model using AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384", gradient_checkpointing=True, use_cache=False).
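For reference, this is a minimal sketch of the two ways I understand gradient checkpointing is enabled there (assuming a recent transformers version; exact kwargs may differ between releases):

```python
from transformers import AutoModelForSeq2SeqLM

# Variant used in the notebook: pass the flags at load time
# (kwargs not consumed by from_pretrained are forwarded to the model config).
model = AutoModelForSeq2SeqLM.from_pretrained(
    "allenai/led-base-16384",
    gradient_checkpointing=True,
    use_cache=False,
)

# Variant suggested by the documentation: enable it after loading.
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384")
model.gradient_checkpointing_enable()
model.config.use_cache = False  # key/value caching is not needed during training
```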

I have a 24GB GPU, and the base model fits comfortably without activating gradient checkpointing, which slows down training by a large margin. I am using a batch size of 1 with 16 steps of gradient accumulation.
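For context, this is roughly how my training is configured (a simplified sketch, not my exact script; the output directory is a placeholder and fp16 is just what I happen to use):

```python
from transformers import Seq2SeqTrainingArguments

# Effective batch size of 16 via accumulation, no gradient checkpointing.
training_args = Seq2SeqTrainingArguments(
    output_dir="./led-finetune",     # placeholder path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=False,    # the base model already fits in 24GB
    fp16=True,                       # mixed precision, my own choice
)
```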

So my question is: is my approach (just gradient accumulation) enough for fine-tuning LED, or is gradient checkpointing really necessary (for reasons I may have missed)? If so, what are the implications of using gradient checkpointing in this case, besides saving memory?

I was also wondering what exactly the use_cache=False parameter does in the from_pretrained method.
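From a quick check, the kwarg seems to just be forwarded to the model config, but I am not sure what effect that has during fine-tuning:

```python
from transformers import AutoModelForSeq2SeqLM

# Unused from_pretrained kwargs are forwarded to the config,
# so this appears to simply set config.use_cache = False.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "allenai/led-base-16384", use_cache=False
)
print(model.config.use_cache)  # False
```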

Thank you very much
