Add TREAD training for the FLUX model #1672
Closed
Closes #1611
I haven't had time to test it much, but you can make route configs (CLI args) like:
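Something along these lines (the exact argument name and route keys here are only illustrative, not necessarily the final CLI surface):

```
--tread_config '{"routes": [{"selection_ratio": 0.5, "start_layer_idx": 2, "end_layer_idx": -2}]}'
```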
It works, it does speed up training, and the speedup is proportional to the `selection_ratio` (the closer it is to 1.0, the more tokens are dropped). I haven't messed with the router configurations much to see what works and what doesn't.
It also worked with masked loss, where the mask prevents the masked tokens from being dropped (rough sketch below). This slows training down, since more tokens are kept.
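Roughly, that interaction looks like this (a minimal sketch using my own names and shapes, assuming the loss mask lives in pixel space and gets pooled down to the latent token grid):

```python
import torch
import torch.nn.functional as F

def force_keep_from_mask(loss_mask, latent_h, latent_w):
    """Sketch: turn a pixel-space loss mask into a per-token "never drop" flag.

    loss_mask: (batch, 1, H, W), 1 where the loss is computed.
    Returns:   (batch, latent_h * latent_w) bool, True for image tokens the
               router must keep so masked regions are never dropped.
    """
    pooled = F.adaptive_max_pool2d(loss_mask.float(), (latent_h, latent_w))
    return pooled.flatten(1) > 0
```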
The token dropping also had to be implemented for RoPE in FLUX, as this is required to make TREAD work. I'm unsure whether the implementation is 100% correct, but it seems to work.
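The gist of it, as a sketch rather than the actual code in this branch: whatever indices the router keeps for the hidden states have to be used to slice the rotary tables too, otherwise the surviving tokens end up with the wrong positions.

```python
import torch

def gather_tokens_and_rope(hidden_states, rope_cos, rope_sin, keep_idx):
    """Sketch: drop tokens and keep RoPE aligned with the survivors.

    hidden_states: (batch, seq, dim)
    rope_cos/sin:  (seq, head_dim) rotary tables for the full sequence
    keep_idx:      (batch, num_keep) indices chosen by the router
    """
    dim = hidden_states.shape[-1]
    kept = torch.gather(
        hidden_states, 1, keep_idx.unsqueeze(-1).expand(-1, -1, dim)
    )
    # The rotary tables must be indexed with the same positions, otherwise the
    # surviving tokens get rotated with someone else's phase.
    kept_cos = rope_cos[keep_idx]  # (batch, num_keep, head_dim)
    kept_sin = rope_sin[keep_idx]
    return kept, kept_cos, kept_sin
```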
The more tokens you drop, the higher the loss seems to be when you start LoRA/LoKr training. It tends to correct fairly rapidly, and the images start looking more normal. I'm unsure whether it's just the network adjusting to the smaller number of tokens being supplied to the intermediate layers.
There was also a bug in the main branch that seemed to break training with masked loss for FLUX. This was fixed with a `self.config.model_flavour == "kontext"` guard.