Skip to content

Conversation

@pluesclues
Copy link

No description provided.

# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
logits = logits.to(torch.float32)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Please remove this

hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
#breakpoint()
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return hidden_states
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're returning here, we don't need the code below. Remove perhaps?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to leave it there for debugging purposes, dan had the old code below because of that

logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
logits = logits.to(torch.float32)
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: These are logits. can we rename appropriately

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually return those as hidden states so we save memory and can calculate logits on the fly over here:

https://github.com/pluesclues/unsloth-zoo/blob/8b5bfe233f819aac89876025d121d6af28713af6/unsloth_zoo/rl_replacements.py#L127-L129

@Datta0 Datta0 merged commit 2961e2d into Datta0:trl_upgrade_fix May 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants