Update rl_replacements.py: return hidden states from logprobs #3
Conversation
unsloth/models/rl_replacements.py
Outdated
# For transformers<=4.48, the logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
logits = logits.to(torch.float32)
NIT: Please remove this
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
#breakpoint()
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return hidden_states
If we're returning here, we don't need the code below. Remove perhaps?
I want to leave it there for debugging purposes; Dan had the old code below for that reason.
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
logits = logits.to(torch.float32)
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
NIT: These are logits; can we rename them appropriately?
I actually return these as hidden states so we save memory and can compute the logits on the fly over here:
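The memory argument behind returning hidden states can be sketched as follows. This is a hypothetical NumPy stand-in for the actual torch code, with made-up names (`logprobs_from_hidden_states`, `lm_head_weight`, `chunk_size`): rather than materializing the full (L, V) logits tensor, where the vocabulary size V is typically much larger than the hidden size H, we keep the (L, H) hidden states and project through the LM head one chunk at a time when per-token log-probs are needed.

```python
import numpy as np

def logprobs_from_hidden_states(hidden_states, lm_head_weight, token_ids, chunk_size=2):
    """Per-token log-probs computed chunk-wise from hidden states.

    hidden_states: (L, H) final-layer hidden states
    lm_head_weight: (V, H) LM head projection matrix
    token_ids: (L,) tokens whose log-probs we want
    """
    out = np.empty(len(token_ids))
    for start in range(0, len(token_ids), chunk_size):
        h = hidden_states[start:start + chunk_size]        # (c, H)
        logits = h @ lm_head_weight.T                      # (c, V) exists only for this chunk
        logits -= logits.max(axis=-1, keepdims=True)       # numerically stable log-softmax
        logsumexp = np.log(np.exp(logits).sum(axis=-1))
        ids = token_ids[start:start + chunk_size]
        out[start:start + chunk_size] = logits[np.arange(len(ids)), ids] - logsumexp
    return out
```

Peak memory for the logits drops from O(L·V) to O(chunk_size·V), at the cost of storing the (L, H) hidden states and repeating the LM-head matmul per chunk.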