diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 9014470b55c..01c54487117 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -818,18 +818,19 @@ def __init__(
         )
 
         # Loss function
-        if args.loss_type == "nll":
-            pass  # use the default loss
-        elif args.loss_type == "dft":
-            if compute_loss_func is not None:
-                raise ValueError(
-                    "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
-                    "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
-                    "`compute_loss_func` is not allowed."
-                )
-            compute_loss_func = dft_loss
-        else:
-            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
+        if not args.use_liger_kernel:  # liger supports dft loss by just passing use_token_scaling=True
+            if args.loss_type == "nll":
+                pass  # use the default loss
+            elif args.loss_type == "dft":
+                if compute_loss_func is not None:
+                    raise ValueError(
+                        "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
+                        "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so "
+                        "passing a `compute_loss_func` is not allowed."
+                    )
+                compute_loss_func = dft_loss
+            else:
+                raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
@@ -1095,6 +1096,11 @@ def compute_loss(
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False
 
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
+        if self.args.use_liger_kernel:
+            inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"
+
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
         )
@@ -1133,8 +1139,12 @@ def compute_loss(
             self._total_train_tokens += num_tokens_in_batch
         self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
-        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if not self.args.use_liger_kernel:
+        if self.args.use_liger_kernel:
+            if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None:
+                token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
+                self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
+        else:
+            # Compute accuracy from logits using argmax (traditional method)
             with torch.no_grad():
                 if "shift_labels" in inputs:
                     # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
@@ -1172,10 +1182,12 @@ def compute_loss(
                 total_sum = total_tokens.sum()
                 accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
                 self._metrics[mode]["mean_token_accuracy"].append(accuracy)
-        if self.aux_loss_enabled:
-            aux_loss = outputs.aux_loss
-            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
-            self._metrics[mode]["aux_loss"].append(aux_loss)
+
+        # Log auxiliary loss if enabled (applies to both Liger and non-Liger)
+        if self.aux_loss_enabled:
+            aux_loss = outputs.aux_loss
+            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
+            self._metrics[mode]["aux_loss"].append(aux_loss)
 
         return (loss, outputs) if return_outputs else loss
 
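
Note: below is a minimal sketch, not the exact TRL implementation, of the DFT loss that the non-Liger branch wires in via `compute_loss_func = dft_loss`; the Liger branch gets the equivalent behavior by passing `use_token_scaling=True`, per the inline comment in the diff above. The function name `dft_loss_sketch`, the `(outputs, labels, num_items_in_batch)` signature, and the `-100` ignore index are illustrative assumptions. The key idea is that each token's NLL is weighted by the detached probability the model assigns to the target token.

import torch
import torch.nn.functional as F

def dft_loss_sketch(outputs, labels, num_items_in_batch=None):
    # Illustrative sketch: per-token NLL scaled by the detached probability of the target token.
    logits = outputs.logits
    # Shift so that tokens < n predict token n.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1)
    # Per-token negative log-likelihood; ignored positions (-100) get a loss of 0.
    per_token_nll = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="none")
    # exp(-nll) is the predicted probability of the target token; detach it so it
    # acts as a weight rather than an extra gradient path.
    scale = torch.exp(-per_token_nll).detach()
    valid = shift_labels != -100
    denom = num_items_in_batch if num_items_in_batch is not None else valid.sum()
    return (scale * per_token_nll)[valid].sum() / denom

Because the weight is the model's own confidence, well-predicted tokens keep close to their full NLL while low-probability tokens are down-weighted; this per-token scaling is what the added `inputs["use_token_scaling"]` flag requests from the Liger path according to the comment in the diff.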