Merged (changes from 2 commits)
unsloth/models/rl_replacements.py (28 changes: 19 additions & 9 deletions)

@@ -214,18 +214,28 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
     if not hasattr(self, '_autocast_dtype'):
         self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
+    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
     with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
-        logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        logits = logits.to(torch.float32)
+        hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
Owner:
NIT: These are logits. Can we rename them appropriately?

Author:
I actually return those as hidden states, so we save memory and can calculate logits on the fly over here:

https://github.com/pluesclues/unsloth-zoo/blob/8b5bfe233f819aac89876025d121d6af28713af6/unsloth_zoo/rl_replacements.py#L127-L129
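A minimal sketch of the memory trick being described: keep the (B, L, H) hidden states and project to the vocabulary lazily, in chunks, when log-probabilities are needed. Everything below is illustrative; chunked_logprobs, lm_head, and chunk_size are assumed names, not the unsloth-zoo API:

import torch

def chunked_logprobs(hidden_states, input_ids, lm_head, chunk_size = 1024):
    # Project to the vocabulary one chunk at a time, so the full
    # (B, L, V) logits tensor is never materialized all at once.
    logps = []
    for hs, ids in zip(hidden_states.split(chunk_size, dim = 1),
                       input_ids.split(chunk_size, dim = 1)):
        logits = lm_head(hs).float()  # (B, chunk, V), freed after this iteration
        chosen = logits.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
        logps.append(chosen - torch.logsumexp(logits, dim = -1))
    return torch.cat(logps, dim = 1)  # per-token log-probabilities

Under these assumptions, peak memory for the logits drops from O(B * L * V) to O(B * chunk_size * V).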

+        #breakpoint()
+        #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        return hidden_states
Owner:
If we're returning here, we don't need the code below. Remove it, perhaps?

Author:
I want to leave it there for debugging purposes; Dan had the old code below because of that.

         input_ids = input_ids[:, -logits_to_keep:]
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
         logits = logits[:, -logits_to_keep:]
+        logits = logits.to(torch.float32)
Owner:
NIT: Please remove this.

+        #return logits
-        return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
+        logps = selective_log_softmax(logits, input_ids)
+
+        # row_indices, col_indices = torch.where(logps < -20)
+
+        # # Method 1: check whether any offending positions were found
+        # if len(row_indices) > 0 and len(col_indices) > 0:
+        #     breakpoint() # breakpoint triggered here
+        #     print("Found unusually low log-probs!")
+        return logps # compute logprobs for the input tokens
     pass
 pass

@@ -280,16 +290,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
-    old_per_token_logps = inputs["old_per_token_logps"]
+    old_hidden_states = inputs["old_per_token_logps"]
     input_ids = input_ids[:, -logits_to_keep:]
+    #breakpoint()
     if per_token_logps is not None:
+        #
         loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps, old_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
+            old_hidden_states, per_token_logps, ref_per_token_logps, input_ids, completion_mask, self.beta, advantages,
         )
     else:
         loss, completion_length, mean_kl = grpo_accumulated_loss(
-            self, _input_ids, logits_to_keep, completion_mask, advantages, old_per_token_logps,
+            self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
             n_chunks = self.args.unsloth_num_chunks
         )
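For reference, the commented-out lines at the top of this hunk spell out the per-token objective that grpo_compute_loss_slow and grpo_accumulated_loss compute. Below is a rough sketch of that masked GRPO loss, assuming all three inputs are already per-token log-probabilities (in this PR the "old" quantity actually arrives as hidden states and is projected to log-probs first); this is illustrative, not the unsloth implementation:

import torch

def grpo_loss_sketch(per_token_logps, old_per_token_logps, ref_per_token_logps,
                     advantages, completion_mask, beta):
    # Importance ratio between the current and old policy, per token.
    # (In single-iteration GRPO, old_per_token_logps can be the detached
    # current log-probs, so the ratio is 1 in the forward pass.)
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    per_token_loss = ratio * advantages.unsqueeze(1)
    # k3 estimator of KL(current || ref), as in the TRL GRPO trainer.
    per_token_kl = (torch.exp(ref_per_token_logps - per_token_logps)
                    - (ref_per_token_logps - per_token_logps) - 1.0)
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    # Mean over completion tokens per sequence, then over the batch.
    return ((per_token_loss * completion_mask).sum(dim = 1)
            / completion_mask.sum(dim = 1)).mean()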
