Hi,
The documentation of LED states:
To fine-tune LED on all 16384 tokens, it is necessary to enable gradient checkpointing by executing model.gradient_checkpointing_enable().
Moreover, in this notebook @patrickvonplaten initializes the model with AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384", gradient_checkpointing=True, use_cache=False).
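For reference, this is how I understand the two ways of turning it on, based on the docs and the notebook (a minimal sketch; only the model name and the two flags come from those sources, the rest is just illustrative):

```python
from transformers import AutoModelForSeq2SeqLM

# Option 1 (as in the notebook): pass the flags directly to from_pretrained
model = AutoModelForSeq2SeqLM.from_pretrained(
    "allenai/led-base-16384",
    gradient_checkpointing=True,
    use_cache=False,
)

# Option 2 (as in the documentation): enable it on an already-loaded model
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384")
model.gradient_checkpointing_enable()
```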
I have a 24GB GPU, and the base model fits comfortably without activating gradient checkpointing, which otherwise slows down training by a large margin. I am using a batch size of 1 with 16 steps of gradient accumulation.
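Concretely, my setup looks roughly like this (a simplified sketch; output_dir, the dataset variables, and everything not mentioned above are placeholders, only the batch size and accumulation steps are my actual values):

```python
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./led-base-finetuned",   # placeholder path
    per_device_train_batch_size=1,       # batch size = 1
    gradient_accumulation_steps=16,      # 16 accumulation steps -> effective batch size 16
)

trainer = Seq2SeqTrainer(
    model=model,                         # loaded WITHOUT gradient checkpointing
    args=training_args,
    train_dataset=train_dataset,         # placeholder dataset
    eval_dataset=val_dataset,            # placeholder dataset
)
trainer.train()
```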
So my question is: is my approach (just gradient accumulation) enough for fine-tuning LED, or is it really necessary to use gradient checkpointing (for reasons I may have missed)? If so, what are the implications of using gradient checkpointing in this case, besides saving memory?
I was also wondering what exactly the use_cache=False parameter does in the from_pretrained call.
Thank you very much