Commit 5c7d3c4

💠 Fix multi-gpu padding free (huggingface#3245)
1 parent 909a480

1 file changed: +2 −3 lines


trl/trainer/sft_trainer.py

Lines changed: 2 additions & 3 deletions
@@ -573,9 +573,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if "attention_mask" in inputs:
             num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
         elif "position_ids" in inputs:
-            num_tokens_in_batch = (
-                self.accelerator.gather_for_metrics(torch.tensor(inputs["position_ids"].size(1))).sum().item()
-            )
+            local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device)
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item()
         else:
             raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
         self._total_train_tokens += num_tokens_in_batch
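
For context: in padding-free (packed) training there is no attention_mask, so the token count is derived from the position_ids sequence length. The old code wrapped that length in a bare torch.tensor(...), which lives on the CPU, and gathering a CPU tensor across GPU processes typically fails under a NCCL backend. The fix builds the scalar on the same device as position_ids before gathering. Below is a minimal, self-contained sketch of the resulting pattern; the count_tokens_in_batch helper and the standalone Accelerator() setup are illustrative only, not TRL's actual API surface.

import torch
from accelerate import Accelerator


def count_tokens_in_batch(inputs: dict, accelerator: Accelerator) -> int:
    """Return the number of tokens in this step, summed over all processes.

    Hypothetical helper mirroring the pattern used in the commit above.
    """
    if "attention_mask" in inputs:
        # Padded batches: count the non-padding positions.
        local_num_tokens = inputs["attention_mask"].sum()
    elif "position_ids" in inputs:
        # Padding-free batches: the sequence length is the token count.
        # Create the scalar on the same device as the batch so the
        # cross-process gather works on GPU.
        local_num_tokens = torch.tensor(
            inputs["position_ids"].size(1), device=inputs["position_ids"].device
        )
    else:
        raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
    # gather_for_metrics collects the per-process counts; summing them gives
    # the global number of tokens seen this step.
    return accelerator.gather_for_metrics(local_num_tokens).sum().item()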
