Update rl_replacements.py returned hidden states from logprobs #3
Changes from 2 commits
unsloth_zoo/rl_replacements.py

```diff
@@ -214,18 +214,28 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
     if not hasattr(self, '_autocast_dtype'):
         self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
+    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
     with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
-        logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        logits = logits.to(torch.float32)
+        hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+        #breakpoint()
+        #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        return hidden_states
```
Owner:

If we're returning here, we don't need the code below. Remove it, perhaps?

Author:

I want to leave it there for debugging purposes; Dan had the old code below because of that.
```diff
     input_ids = input_ids[:, -logits_to_keep:]
     # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
     # See https://github.com/huggingface/trl/issues/2770
     logits = logits[:, -logits_to_keep:]
     logits = logits.to(torch.float32)

     #return logits
-    return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
+    logps = selective_log_softmax(logits, input_ids)
+
+    # row_indices, col_indices = torch.where(logps < -20)
+
+    # # Method 1: Check if tensors have elements
+    # if len(row_indices) > 0 and len(col_indices) > 0:
+    #     breakpoint() # Breakpoint triggered here
+    #     print("Found high values!")
+    return logps # compute logprobs for the input tokens
     pass
 pass
```
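For reference, `selective_log_softmax(logits, input_ids)` returns the log-probability the model assigns to each actual input token. Below is a minimal, mathematically equivalent sketch (TRL's real helper computes the same quantity in a more memory-conscious way), combined with the commented-out low-logprob check from the hunk above; the tensor shapes are toy values chosen only for illustration:

```python
import torch

# Toy shapes for illustration only: batch B, sequence length L, vocab size V.
B, L, V = 2, 8, 32
logits = torch.randn(B, L, V)
input_ids = torch.randint(0, V, (B, L))

# Same quantity selective_log_softmax(logits, input_ids) returns:
# the log-probability of each observed token under the model's distribution.
logps = torch.gather(
    logits.log_softmax(dim=-1), dim=-1, index=input_ids.unsqueeze(-1)
).squeeze(-1)  # shape (B, L)

# The debugging check sketched (commented out) in the hunk above:
# flag tokens whose log-probability is suspiciously low (exp(-20) is about 2e-9).
row_indices, col_indices = torch.where(logps < -20)
if len(row_indices) > 0 and len(col_indices) > 0:
    print(f"Found {len(row_indices)} suspiciously low log-probs")
```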
```diff
@@ -280,16 +290,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

-    old_per_token_logps = inputs["old_per_token_logps"]
+    old_hidden_states = inputs["old_per_token_logps"]
     input_ids = input_ids[:, -logits_to_keep:]
+    #breakpoint()
     if per_token_logps is not None:
+        #
         loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps, old_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
+            old_hidden_states, per_token_logps, ref_per_token_logps, input_ids, completion_mask, self.beta, advantages,
         )
     else:
         loss, completion_length, mean_kl = grpo_accumulated_loss(
-            self, _input_ids, logits_to_keep, completion_mask, advantages, old_per_token_logps,
+            self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
             n_chunks = self.args.unsloth_num_chunks
         )
```
Owner:

NIT: These are logits. Can we rename them appropriately?

Author:

I actually return those as hidden states so we save memory and can calculate the logits on the fly over here:
https://github.com/pluesclues/unsloth-zoo/blob/8b5bfe233f819aac89876025d121d6af28713af6/unsloth_zoo/rl_replacements.py#L127-L129
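To make the memory argument concrete: hidden states are (B, L, H) with H in the low thousands, while logits are (B, L, V) with V in the tens of thousands or more, so returning hidden states and projecting them through the LM head only when log-probabilities are needed keeps the peak allocation much smaller. The sketch below illustrates that idea under stated assumptions; the function name, chunk size, and direct `lm_head` call are hypothetical and not the code at the linked lines:

```python
import torch
import torch.nn as nn

def chunked_logps_from_hidden_states(hidden_states, lm_head, input_ids, chunk_size=1024):
    """Per-token log-probs computed from hidden states without materializing the
    full (B, L, V) logits tensor at once. Illustrative sketch, not the unsloth-zoo code."""
    B, L, H = hidden_states.shape
    flat_h = hidden_states.reshape(-1, H)   # (B*L, H)
    flat_ids = input_ids.reshape(-1)        # (B*L,)
    out = torch.empty(flat_ids.shape, dtype=torch.float32, device=hidden_states.device)
    for start in range(0, flat_h.shape[0], chunk_size):
        end = start + chunk_size
        # Project only this chunk through the LM head: (chunk, V) instead of (B*L, V).
        chunk_logits = lm_head(flat_h[start:end]).float()
        chunk_logps = chunk_logits.log_softmax(dim=-1)
        out[start:end] = chunk_logps.gather(
            -1, flat_ids[start:end].unsqueeze(-1)
        ).squeeze(-1)
    return out.reshape(B, L)

# Toy usage with hypothetical sizes.
B, L, H, V = 2, 16, 64, 1000
lm_head = nn.Linear(H, V, bias=False)
hidden_states = torch.randn(B, L, H)
input_ids = torch.randint(0, V, (B, L))
logps = chunked_logps_from_hidden_states(hidden_states, lm_head, input_ids, chunk_size=8)
print(logps.shape)  # torch.Size([2, 16])
```

With this kind of deferred projection, the full float32 logits tensor never exists at once; only (chunk, V) slices do, while the comparatively small hidden states are what get passed around between `_get_per_token_logps` and the loss code.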