58 changes: 56 additions & 2 deletions unsloth/models/rl_replacements.py
@@ -291,6 +291,58 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)

def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
if function_name != "_get_per_token_logps_and_entropies": return function

    # Copied from the _get_per_token_logps replacement above; for now this returns None anyway (the efficient GRPO path computes logps inside the fused loss).
def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return {"logps": None, "entropies": None} # Unsloth efficient GRPO
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
            # We add 1 to `logits_to_keep` because the last logit of the sequence is excluded later
logits = model(
input_ids = input_ids,
attention_mask = attention_mask,
logits_to_keep = logits_to_keep + 1,
).logits
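            # `logits` has shape (B, logits_to_keep + 1, V). Note: with
            # UNSLOTH_RETURN_HIDDEN_STATES = 1 set above, `.logits` actually holds
            # the final hidden states, so the lm_head projection can be fused
            # into the loss downstream.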

entropies = None
if compute_entropy:
from trl.trainer.utils import entropy_from_logits
entropies = entropy_from_logits(logits)
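                # entropy_from_logits returns per-token entropies of shape (B, L):
                # H_t = logsumexp(logits_t) - sum(softmax(logits_t) * logits_t)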

# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return {"logps": logits, "entropies": entropies}
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
# logits = logits[:, -logits_to_keep:]
# return logits
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
# logits = logits / self.temperature
# logps = selective_log_softmax(logits, input_ids)

            # Debugging leftover: stop when any token logprob is suspiciously low.
            # row_indices, col_indices = torch.where(logps < -20)
            # if len(row_indices) > 0 and len(col_indices) > 0:
            #     breakpoint()
            #     print("Found extreme logprob values!")
            # return logps # compute logprobs for the input tokens
pass
pass

function = inspect.getsource(_get_per_token_logps_and_entropies)
return function
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)

grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
@@ -319,14 +371,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
_input_ids = input_ids
_logits_to_keep = logits_to_keep

        # Newer TRL versions rename `_get_per_token_logps` to `_get_per_token_logps_and_entropies`;
        # dispatch to whichever method the installed trainer defines.
        def get_logps_func(model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False):
            if hasattr(self, "_get_per_token_logps"):
                return self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size)
            return self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)["logps"]

        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
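        # In Unsloth's efficient path this returns None; the fused
        # UnslothEfficientGRPO loss later recomputes logps from hidden states.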

# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
if self.beta != 0.0:
with torch.inference_mode(), model.disable_adapter():
                ref_per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
else:
ref_per_token_logps = None
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
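        # (The commented formula above is the unbiased k3 estimator of
        #  KL(new || ref): exp(t) - t - 1 with t = ref_logp - new_logp;
        #  see http://joschu.net/blog/kl-approx.html.)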