
Commit d1d0407

🏷️ Account for token_type_ids in DataCollatorForVisionLanguageModeling (#4190)
1 parent 824ff8c commit d1d0407

2 files changed: +46 −1 lines changed

tests/test_sft_trainer.py

Lines changed: 32 additions & 0 deletions
@@ -1441,6 +1441,38 @@ def test_train_vlm_prompt_completion(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
 
+    # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly handled in the collator.
+    @require_vision
+    def test_train_vlm_prompt_completion_gemma(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            max_length=None,  # For VLMs, truncating can remove image tokens, leading to errors
+            report_to="none",
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+
     # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
     # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
     @pytest.mark.slow
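For context on what this test guards: Gemma 3-style processors return a token_type_ids tensor alongside input_ids, and the model consumes it in the forward pass, so the collator must keep it aligned with input_ids token-for-token. The sketch below is a minimal illustration with toy tensors, not TRL code; it assumes, as in Gemma 3, that token_type_ids marks image positions with 1 and text positions with 0.

```python
import torch

# Toy prompt/completion features, shaped like a Gemma 3-style processor output.
prompt = {
    "input_ids": torch.tensor([[5, 6, 7]]),
    "token_type_ids": torch.tensor([[1, 1, 0]]),  # two image tokens, one text token
}
completion = {
    "input_ids": torch.tensor([[8, 9]]),
    "token_type_ids": torch.tensor([[0, 0]]),  # completions are text-only
}

# The same concatenation the patch performs in _collate_prompt_completion.
input_ids = torch.cat((prompt["input_ids"], completion["input_ids"]), dim=1)
token_type_ids = torch.cat((prompt["token_type_ids"], completion["token_type_ids"]), dim=1)

# The invariant the test exercises end-to-end: both tensors stay aligned, so
# position i of token_type_ids still describes token i of input_ids.
assert input_ids.shape == token_type_ids.shape
print(token_type_ids)  # tensor([[1, 1, 0, 0, 0]])
```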

trl/trainer/sft_trainer.py

Lines changed: 14 additions & 1 deletion
@@ -424,15 +424,26 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
         attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
         completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        if "token_type_ids" in processed_prompts:  # special case for Gemma
+            prompt_token_type_ids = processed_prompts["token_type_ids"]
+            completion_token_type_ids = processed_completions["token_type_ids"]
+            token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
 
         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        if "token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, token_type_ids
+            )
+        else:
+            attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
 
         # Truncate if necessary
         if self.max_length is not None:
             input_ids = input_ids[:, : self.max_length]
             attention_mask = attention_mask[:, : self.max_length]
             completion_mask = completion_mask[:, : self.max_length]
+            if "token_type_ids" in processed_prompts:
+                token_type_ids = token_type_ids[:, : self.max_length]
 
         # Create labels and mask padding tokens
         labels = input_ids.clone()
@@ -445,6 +456,8 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         output["input_ids"] = input_ids
         output["attention_mask"] = attention_mask
         output["labels"] = labels
+        if "token_type_ids" in processed_prompts:
+            output["token_type_ids"] = token_type_ids
         return output
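The subtle part of the patch is the flush_left call: flushing shifts every row left to squeeze out leading padding, so any per-token tensor not passed through the same call ends up misaligned with input_ids. The sketch below uses flush_left_sketch, a simplified stand-in written from the helper's apparent contract (it is not TRL's implementation), to show why token_type_ids has to ride along.

```python
import torch

def flush_left_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    """Simplified stand-in for trl's flush_left, written from its apparent
    contract: shift each row left to remove leading padding, apply the same
    shift to every companion tensor, then drop trailing all-pad columns."""
    mask = mask.clone()
    tensors = tuple(t.clone() for t in tensors)
    for i in range(mask.size(0)):
        nonzero = torch.nonzero(mask[i]).flatten()
        shift = int(nonzero[0]) if len(nonzero) else 0  # first non-pad position
        mask[i] = mask[i].roll(-shift)
        for t in tensors:
            t[i] = t[i].roll(-shift)
    seq_len = int(mask.sum(dim=1).max())  # longest row once flushed
    return (mask[:, :seq_len], *(t[:, :seq_len] for t in tensors))

# Two rows with different amounts of left padding (0 = pad).
attention_mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
input_ids      = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
token_type_ids = torch.tensor([[0, 0, 1, 0], [0, 1, 1, 0]])

attention_mask, input_ids, token_type_ids = flush_left_sketch(
    attention_mask, input_ids, token_type_ids
)
print(input_ids)       # tensor([[5, 6, 0], [7, 8, 9]])
print(token_type_ids)  # tensor([[1, 0, 0], [1, 1, 0]])
# Had token_type_ids been left out of the call (the pre-patch behavior), it
# would keep its old left-padded layout and describe the wrong tokens.
```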
