diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 4965fdab032..59dd02283bc 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -573,9 +573,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if "attention_mask" in inputs: num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() elif "position_ids" in inputs: - num_tokens_in_batch = ( - self.accelerator.gather_for_metrics(torch.tensor(inputs["position_ids"].size(1))).sum().item() - ) + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() else: raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") self._total_train_tokens += num_tokens_in_batch