@@ -215,17 +215,27 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
         if not hasattr(self, '_autocast_dtype'):
             self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-        # os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
+
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(input_ids = input_ids, attention_mask = attention_mask, logits_to_keep = logits_to_keep + 1).logits
-            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            # logits = logits.to(torch.float32)
-            input_ids = input_ids[:, -logits_to_keep:]
+            hidden_states = model(input_ids = input_ids, attention_mask = attention_mask, logits_to_keep = logits_to_keep + 1).logits
+            # logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            return hidden_states
+            # input_ids = input_ids[:, -logits_to_keep:]
             # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
             # See https://github.com/huggingface/trl/issues/2770
-            logits = logits[:, -logits_to_keep:]
-            return logits
+            # logits = logits[:, -logits_to_keep:]
+            # return logits
+            # logps = selective_log_softmax(logits, input_ids)
+
+            # row_indices, col_indices = torch.where(logps < -20)
+
+            # # Method 1: Check if tensors have elements
+            # if len(row_indices) > 0 and len(col_indices) > 0:
+            #     breakpoint()  # Breakpoint triggered here
+            #     print("Found high values!")
+            # return logps  # compute logprobs for the input tokens
         pass
     pass
 
@@ -280,20 +290,21 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
-        old_per_token_logps = inputs["old_per_token_logps"]
+        old_hidden_states = inputs["old_per_token_logps"]
         input_ids = input_ids[:, -logits_to_keep:]
         if per_token_logps is not None:
             loss, completion_length, mean_kl = grpo_compute_loss_slow(
-                ref_per_token_logps, per_token_logps, old_per_token_logps, input_ids, completion_mask, self.beta, advantages,
+                ref_per_token_logps, per_token_logps, old_hidden_states, input_ids, completion_mask, self.beta, advantages,
                 loss_type = self.args.loss_type,
                 epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
                 max_completion_length = self.args.max_completion_length,
                 delta = self.args.delta,
             )
         else:
             loss, completion_length, mean_kl = grpo_accumulated_loss(
-                self, _input_ids, logits_to_keep, completion_mask, advantages, old_per_token_logps,
-                n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type,
+                self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
+                n_chunks = self.args.unsloth_num_chunks,
+                loss_type = self.args.loss_type,
                 epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
                 max_completion_length = self.args.max_completion_length,
                 delta = self.args.delta,
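
For context on the first hunk: setting `UNSLOTH_RETURN_HIDDEN_STATES = "1"` makes `_get_per_token_logps` return the model's hidden states directly instead of materializing the full logits and reducing them to per-token log-probabilities; the old logits path, the `selective_log_softmax` call, and the `logps < -20` debugging breakpoint are all left commented out. Below is a minimal sketch (plain PyTorch, not unsloth's or TRL's implementation) of what that commented-out reduction computed, namely the log-probability of each generated token under the model:

```python
import torch
import torch.nn.functional as F

def per_token_logps_from_logits(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Sketch of the step the commented-out `selective_log_softmax(logits, input_ids)` performed.

    logits:    (B, L, V) logits for the kept completion positions
    input_ids: (B, L)    tokens actually generated at those positions
    returns:   (B, L)    log p(token_t | prefix) for each position
    """
    logps = F.log_softmax(logits.float(), dim = -1)                # normalize over the vocabulary in fp32
    return logps.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)   # pick each generated token's log-prob
```

Deferring this reduction is presumably what lets `grpo_accumulated_loss` chunk the `lm_head` projection and log-softmax itself, and it explains why `compute_loss` now threads `old_hidden_states` (still read from `inputs["old_per_token_logps"]`) into the loss functions where per-token log-probabilities used to go.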