[Cross-entropy-loss] return mean token accuracy metric with CE loss #910
Conversation
|
@vaibhavjindal would you be able to kindly review? |
|
@shimizust this will be a breaking change, I believe, BTW |
|
@kashif could you please elaborate on how it will be a breaking change? Will it break the integration with transformers or trl? |
|
yes, if someone is using the raw functions in their lib, then those functions now return one more thing... but on the HF side this PR takes care of this |
|
@kashif got it. So if I understand correctly, it will make sure that liger remains compatible with newer versions from HF. However, I just want to confirm: will it break liger support with older transformers/trl versions? |
|
no, I believe my changes here will work with older versions of HF. I just meant non-HF frameworks |
|
TRL relies on HF integration for the CE loss so in TRL I will just pin to the liger version that has these changes |
|
@vaibhavjindal let me fix up the new qwen3-vl model to update its API |
|
@vaibhavjindal all good from my side |
Thanks a lot! I will do some final checks on correctness and benchmarks and will try to get it merged soon. |
|
thank you so much.. also see here: huggingface/trl#4302 (comment) |
|
thanks @vaibhavjindal for the typo fix and making it more robust! |
Summary
Returns the mean token accuracy metric alongside the cross-entropy loss, without materializing the logits.
https://x.com/jeremyphoward/status/1703246293802586155
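For readers unfamiliar with the metric, here is a minimal pure-Python sketch of what "mean token accuracy" means next to the CE loss: the fraction of non-ignored tokens whose argmax prediction equals the target. This is not the kernel added in this PR. The function name `ce_loss_and_token_accuracy` is hypothetical, and the ignore index of -100 is an assumption borrowed from the common PyTorch convention. Rows are consumed one at a time from an iterator, which only loosely mirrors the idea of not holding the full logits tensor.

```python
import math

IGNORE_INDEX = -100  # assumption: padding label, as in the common PyTorch convention


def ce_loss_and_token_accuracy(logit_rows, targets, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: return (mean CE loss, mean token accuracy).

    logit_rows may be any iterable yielding one list of vocabulary scores
    per token, so each row can be produced lazily and dropped after use.
    Both means are taken over non-ignored tokens only.
    """
    total_loss = 0.0
    correct = 0
    counted = 0
    for row, target in zip(logit_rows, targets):
        if target == ignore_index:
            continue  # ignored tokens affect neither the loss nor the accuracy
        # Numerically stable log-sum-exp for the softmax normalizer.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total_loss += log_sum_exp - row[target]
        # Argmax prediction; ties resolve to the lowest index.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
        counted += 1
    if counted == 0:
        return 0.0, 0.0
    return total_loss / counted, correct / counted


# Usage: three scored tokens and one ignored token.
logits = [[2.0, 0.5, 0.1], [0.1, 0.2, 3.0], [1.0, 1.5, 0.2], [0.0, 0.0, 0.0]]
targets = [0, 2, 0, IGNORE_INDEX]
loss, accuracy = ce_loss_and_token_accuracy(iter(logits), targets)
print(f"loss={loss:.4f} accuracy={accuracy:.4f}")
```

In this example two of the three non-ignored tokens are predicted correctly, so the accuracy is 2/3; the ignored token contributes to neither value.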
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence