Skip to content

Commit 1aa2e37

Browse files
committed
Merge remote-tracking branch 'plues/trl_update' into trl_upgrade_fix
2 parents 1aa7aa1 + 4becc0c commit 1aa2e37

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

unsloth/models/rl_replacements.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,27 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
215215
if not hasattr(self, '_autocast_dtype'):
216216
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
217217
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
218-
# os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
218+
219+
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
219220
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
220221
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
221-
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
222-
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
223-
# logits = logits.to(torch.float32)
224-
input_ids = input_ids[:, -logits_to_keep:]
222+
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
223+
#logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
224+
return hidden_states
225+
# input_ids = input_ids[:, -logits_to_keep:]
225226
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
226227
# See https://github.com/huggingface/trl/issues/2770
227-
logits = logits[:, -logits_to_keep:]
228-
return logits
228+
# logits = logits[:, -logits_to_keep:]
229+
# return logits
230+
# logps = selective_log_softmax(logits, input_ids)
231+
232+
# row_indices, col_indices = torch.where(logps < -20)
233+
234+
# # Method 1: Check if tensors have elements
235+
# if len(row_indices) > 0 and len(col_indices) > 0:
236+
# breakpoint() # Breakpoint triggered here
237+
# print("Found high values!")
238+
# return logps # compute logprobs for the input tokens
229239
pass
230240
pass
231241

@@ -280,20 +290,21 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
280290
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
281291
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
282292

283-
old_per_token_logps = inputs["old_per_token_logps"]
293+
old_hidden_states = inputs["old_per_token_logps"]
284294
input_ids = input_ids[:, -logits_to_keep:]
285295
if per_token_logps is not None:
286296
loss, completion_length, mean_kl = grpo_compute_loss_slow(
287-
ref_per_token_logps, per_token_logps, old_per_token_logps, input_ids, completion_mask, self.beta, advantages,
297+
ref_per_token_logps, per_token_logps, old_hidden_states, input_ids, completion_mask, self.beta, advantages,
288298
loss_type = self.args.loss_type,
289299
epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
290300
max_completion_length = self.args.max_completion_length,
291301
delta = self.args.delta,
292302
)
293303
else:
294304
loss, completion_length, mean_kl = grpo_accumulated_loss(
295-
self, _input_ids, logits_to_keep, completion_mask, advantages, old_per_token_logps,
296-
n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type,
305+
self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
306+
n_chunks = self.args.unsloth_num_chunks,
307+
loss_type = self.args.loss_type,
297308
epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
298309
max_completion_length = self.args.max_completion_length,
299310
delta = self.args.delta,

0 commit comments

Comments (0)