Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/transformers/test_simple_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""
Simple test to demonstrate the DataParallel num_items_in_batch fix.
This will show the exact difference in behavior between old and new trainers.
"""

import sys
from pathlib import Path

import torch

import shutil
import tempfile

from datasets import Dataset

from transformers import (
AutoConfig,
AutoTokenizer,
DataCollatorForLanguageModeling,
GPTNeoXForCausalLM,
TrainingArguments,
)

# Import both trainer versions to compare
from transformers.trainer_new import Trainer as NewTrainer
from transformers.trainer_old import Trainer as OldTrainer


def create_simple_test():
    """Train briefly with the old and the new Trainer, one after the other.

    A small causal-LM dataset (ten copies of one repeated sentence, tokenized
    to a fixed length of 32) is fed to a freshly initialised
    ``GPTNeoXForCausalLM`` built from the ``EleutherAI/pythia-14m`` config.
    Each trainer variant runs for two optimizer steps with gradient
    accumulation, logging at every step so the reported values of the two
    variants can be compared by eye.

    Any exception raised while configuring or running a variant is printed and
    swallowed so that the remaining variant still runs. The temporary output
    directory of each variant is removed afterwards, whether or not it failed.
    """
    print(f"Found {torch.cuda.device_count()} different GPUs")

    # Only the config and the tokenizer come from the checkpoint name; the
    # model weights are initialised from scratch inside the loop below.
    checkpoint = "EleutherAI/pythia-14m"
    config = AutoConfig.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Reuse the end-of-sequence token for padding so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token

    # Ten identical samples, each the same sentence repeated eight times.
    sentence = "Hello world this is a test. " * 8
    raw_dataset = Dataset.from_dict({"text": [sentence] * 10})

    def encode_batch(examples):
        """Tokenize a batch of texts, truncating/padding every row to 32 tokens."""
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=32,
            padding="max_length",
        )

    tokenized_dataset = raw_dataset.map(
        encode_batch,
        batched=True,
        remove_columns=["text"],
    )
    # mlm=False selects causal language modeling rather than masked LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Insertion order is preserved, so OLD runs before NEW.
    trainer_variants = {"OLD": OldTrainer, "NEW": NewTrainer}

    for trainer_name, trainer_cls in trainer_variants.items():
        print(f"\n{'-' * 40}")
        print(f"Testing {trainer_name} Trainer")
        print(f"{'-' * 40}")

        # Re-seed before building the model so both variants are constructed
        # under the same RNG state.
        torch.manual_seed(42)
        model = GPTNeoXForCausalLM(config=config)

        scratch_dir = tempfile.mkdtemp(prefix=f"test_{trainer_name.lower()}_")

        # Built fresh on every iteration so no mutable value (e.g. the
        # ``report_to`` list) is shared between the two runs.
        run_settings = {
            "output_dir": scratch_dir,
            "per_device_train_batch_size": 3,
            # Original author's note: this will show the difference clearly.
            "gradient_accumulation_steps": 5,
            "max_steps": 2,
            "logging_steps": 1,
            "save_steps": 1000,
            "learning_rate": 1e-4,
            "dataloader_drop_last": True,
            "report_to": [],
            "disable_tqdm": True,
        }

        try:
            training_args = TrainingArguments(**run_settings)
            trainer = trainer_cls(
                model=model,
                args=training_args,
                train_dataset=tokenized_dataset,
                processing_class=tokenizer,
                data_collator=data_collator,
            )

            print(f"Starting {trainer_name} trainer...")
            trainer.train()
        except Exception as e:
            # Best-effort demo: report the failure and move on to the next variant.
            print(f"Error in {trainer_name} trainer: {e}")
        finally:
            shutil.rmtree(scratch_dir, ignore_errors=True)


# Run the old-vs-new trainer comparison only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    create_simple_test()
21 changes: 20 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,7 +2669,26 @@ def _inner_training_loop(
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
# Fix for DataParallel: Adjust token count for per-GPU loss scaling
# When using multiple GPUs with DataParallel, each GPU processes a different slice
# of the batch, so we need to calculate the actual tokens each GPU will see
current_num_items_in_batch = num_items_in_batch
if num_items_in_batch is not None and self.args.n_gpu > 1 and "labels" in inputs:
try:
# Count non-padding tokens in this specific batch (inputs represents one batch)
full_batch_tokens = (inputs["labels"].ne(-100)).sum() # -100 = padding tokens => ignored
tokens_per_gpu = full_batch_tokens // self.args.n_gpu # each GPU sees 1/n_gpu tokens
current_num_items_in_batch = tokens_per_gpu # Update to the per-GPU token count

# Convert scalar tensor to 1D tensor if needed (required by some model implementations)
# This ensures compatibility with models expecting a 1D tensor for num_items_in_batch
if current_num_items_in_batch.dim() == 0:
current_num_items_in_batch = current_num_items_in_batch.unsqueeze(0)
except (TypeError, AttributeError):
# Fallback for non-standard label formats (e.g., object detection, custom models)
current_num_items_in_batch = num_items_in_batch

tr_loss_step = self.training_step(model, inputs, current_num_items_in_batch)

if (
args.logging_nan_inf_filter
Expand Down
Loading