diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 9b0f4e4ae..4cf9174f2 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -291,6 +291,58 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
 
+def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
+    if function_name != "_get_per_token_logps_and_entropies": return function
+
+    # Copied over from the _get_per_token_logps replacement function above. For now this returns None anyway.
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False):
+        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return {"logps": None, "entropies": None} # Unsloth efficient GRPO
+        # Otherwise, calculate normally:
+        if not hasattr(self, '_autocast_dtype'):
+            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+            logits = model(
+                input_ids = input_ids,
+                attention_mask = attention_mask,
+                logits_to_keep = logits_to_keep + 1,
+            ).logits
+
+            entropies = None
+            if compute_entropy:
+                from trl.trainer.utils import entropy_from_logits
+                entropies = entropy_from_logits(logits)
+
+            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            return {"logps": logits, "entropies": entropies}
+            # input_ids = input_ids[:, -logits_to_keep:]
+            # For transformers<=4.48, the logits_to_keep argument isn't supported, so here we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            # logits = logits[:, -logits_to_keep:]
+            # return logits
+            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+            # logits = logits / self.temperature
+            # logps = selective_log_softmax(logits, input_ids)
+
+            # row_indices, col_indices = torch.where(logps < -20)
+
+            # # Method 1: Check if tensors have elements
+            # if len(row_indices) > 0 and len(col_indices) > 0:
+            #     breakpoint() # Breakpoint triggered here
+            #     print("Found high values!")
+            # return logps # compute logprobs for the input tokens
+        pass
+    pass
+
+    function = inspect.getsource(_get_per_token_logps_and_entropies)
+    return function
+pass
+RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)
+
 grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
 grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
 UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
@@ -319,14 +371,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         _input_ids = input_ids
         _logits_to_keep = logits_to_keep
 
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+        get_logps_func = lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size) if hasattr(self, "_get_per_token_logps") else self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)['logps']
+
+        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
         if self.beta != 0.0:
             with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+                ref_per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
         else:
             ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
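For context on the first hunk: when `compute_entropy` is set, the non-efficient path delegates per-token entropy to TRL's `entropy_from_logits` helper. Below is a minimal standalone sketch of the quantity being computed, not part of the patch; `per_token_entropy` is a hypothetical name used only for illustration.

```python
import torch

def per_token_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Illustrative reference computation (assumption: this mirrors what an
    # entropy-from-logits helper returns; it is not the patch's code).
    # logits: (batch, seq_len, vocab) -> entropy: (batch, seq_len)
    # H = logsumexp(logits) - sum(softmax(logits) * logits) over the vocab dim,
    # which equals -sum(p * log p) for p = softmax(logits).
    logsumexp = torch.logsumexp(logits, dim=-1)
    expected_logit = (torch.softmax(logits, dim=-1) * logits).sum(dim=-1)
    return logsumexp - expected_logit

if __name__ == "__main__":
    logits = torch.randn(2, 5, 32)          # dummy (B, L, V) logits
    print(per_token_entropy(logits).shape)  # torch.Size([2, 5])
```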
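The second hunk's `get_logps_func` lambda exists because newer TRL versions expose `_get_per_token_logps_and_entropies` (returning a dict) instead of `_get_per_token_logps`, so `compute_loss` picks whichever method the installed trainer provides. Below is a minimal standalone sketch of that hasattr-based dispatch; `OldTrainer`, `NewTrainer`, and `make_get_logps_func` are stand-in names for illustration only, not part of the patch.

```python
class OldTrainer:
    # Stand-in for a trainer built against older TRL (single-method API).
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None):
        return "logps-from-old-API"

class NewTrainer:
    # Stand-in for a trainer built against newer TRL (dict-returning API).
    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask,
                                           logits_to_keep, batch_size=None, compute_entropy=False):
        return {"logps": "logps-from-new-API", "entropies": None}

def make_get_logps_func(trainer):
    # Same shape as the lambda added in the diff: always hand back just the
    # log-probs, regardless of which TRL method the trainer exposes.
    def get_logps_func(model, input_ids, attention_mask, logits_to_keep,
                       batch_size=None, compute_entropy=False):
        if hasattr(trainer, "_get_per_token_logps"):
            return trainer._get_per_token_logps(
                model, input_ids, attention_mask, logits_to_keep, batch_size)
        return trainer._get_per_token_logps_and_entropies(
            model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)["logps"]
    return get_logps_func

if __name__ == "__main__":
    for trainer in (OldTrainer(), NewTrainer()):
        fn = make_get_logps_func(trainer)
        print(fn(None, None, None, 0))  # "logps-from-old-API", then "logps-from-new-API"
```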