Merged
Changes from 140 commits
Commits
143 commits
552e899
Refactor image handling: replace `image_split_sizes` with `image_grid…
qgallouedec Sep 19, 2025
449ef07
simpler
qgallouedec Sep 19, 2025
c8933aa
gfpo
qgallouedec Sep 19, 2025
229c554
multi-image grpo
qgallouedec Sep 19, 2025
3ca6ad5
log with wandb
qgallouedec Sep 19, 2025
dcf4b92
no vlm reward models
qgallouedec Sep 20, 2025
30ad7ca
rloo
qgallouedec Sep 20, 2025
86cc30b
gfpo
qgallouedec Sep 20, 2025
088897b
fix
qgallouedec Sep 20, 2025
d2adc63
test peft
qgallouedec Sep 20, 2025
f4c82bf
fix gfpo
qgallouedec Sep 20, 2025
1257796
rloo test
qgallouedec Sep 20, 2025
099a39b
peft rloo
qgallouedec Sep 20, 2025
529add6
oops
qgallouedec Sep 20, 2025
fc6b11f
update test
qgallouedec Sep 20, 2025
ae1f497
generate method
qgallouedec Sep 20, 2025
f998432
debug
qgallouedec Sep 20, 2025
fa73876
skip failing test
qgallouedec Sep 20, 2025
52d8bd9
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 20, 2025
dfc0d38
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 20, 2025
fc52e68
test fixed!
qgallouedec Sep 20, 2025
4d12aeb
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 20, 2025
4fc2b5b
gfpo
qgallouedec Sep 20, 2025
b628744
rm vllm
qgallouedec Sep 20, 2025
d3a769f
fix doc
qgallouedec Sep 20, 2025
e17ec42
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 22, 2025
efbb03a
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 22, 2025
562c662
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
485781c
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
05270f8
update layers to ignore
qgallouedec Sep 22, 2025
1c53094
clarify image column desc
qgallouedec Sep 22, 2025
9b6652e
rm VLM x RM warning
qgallouedec Sep 23, 2025
c500440
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 23, 2025
a6a8c44
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
d8665e1
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
365d501
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
cdb4c76
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
c83e710
same for rloo
qgallouedec Sep 24, 2025
ec6ad25
nits style and align
qgallouedec Sep 24, 2025
b4cadde
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
b0dceb9
restart
qgallouedec Sep 25, 2025
ebe32c2
progress
qgallouedec Sep 25, 2025
0213662
progress continues
qgallouedec Sep 25, 2025
8b3a724
progress again again
qgallouedec Sep 25, 2025
c1ae6aa
back to working point
qgallouedec Sep 25, 2025
1a66b43
revert chage data utils
qgallouedec Sep 25, 2025
2dc69a6
Merge branch 'main' into generate-method
qgallouedec Sep 26, 2025
9435a94
refactor in grpo
qgallouedec Sep 26, 2025
d3f1d3c
Merge branch 'main' into refactor_generate
qgallouedec Sep 26, 2025
3d8ea27
wrong merge commit
qgallouedec Sep 26, 2025
27dc958
fix num_input_tokens_seen
qgallouedec Sep 26, 2025
53772ef
getting closer
qgallouedec Sep 26, 2025
8766fa5
consistent naming
qgallouedec Sep 26, 2025
236b78b
better
qgallouedec Sep 26, 2025
9da4830
simplify a bit + comment
qgallouedec Sep 26, 2025
b3bd0b0
another one
qgallouedec Sep 26, 2025
d79b9e1
get prompt ids from generation
qgallouedec Sep 26, 2025
8d34d54
remove pad token removal
qgallouedec Sep 26, 2025
e770efe
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
0e2ae34
rely on generator for prompt truncation
qgallouedec Sep 26, 2025
46d8eb7
revert
qgallouedec Sep 26, 2025
11acc75
rm enforce eager
qgallouedec Sep 26, 2025
acee7d8
rm truncate_with_protected_tokens
qgallouedec Sep 26, 2025
0b5865e
ensure proper truncation and side
qgallouedec Sep 26, 2025
d8af003
rm useless comment
qgallouedec Sep 26, 2025
fc263a3
rm imports
qgallouedec Sep 26, 2025
35f99fd
requires padding
qgallouedec Sep 26, 2025
8149d05
rm truncation test
qgallouedec Sep 26, 2025
9925199
move forward_kwargs outside of generate
qgallouedec Sep 26, 2025
48a1c30
don't re-prepare data
qgallouedec Sep 26, 2025
15c6620
refactor: update prepare_multimodal_messages to accept images directl…
qgallouedec Sep 26, 2025
55a2480
rloo + doc
qgallouedec Sep 26, 2025
c8041e1
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
b8c0c9b
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Sep 26, 2025
7b7a11d
test and doc
qgallouedec Sep 27, 2025
c5064d6
gfpo
qgallouedec Sep 27, 2025
effb41b
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
e82bfb4
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
4b9c126
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 27, 2025
3f02702
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Sep 27, 2025
b0e0279
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Sep 27, 2025
a01b9ca
Merge branch 'refactor_generate_4' into refactor_generate_5
qgallouedec Sep 27, 2025
f11759e
Merge branch 'main' into refactor_generate_2
qgallouedec Sep 30, 2025
e7aa945
fix vllm client server
qgallouedec Sep 30, 2025
e164ec5
repicate all_prompt_ids
qgallouedec Oct 1, 2025
49577ad
Same for RLOO
qgallouedec Oct 1, 2025
5fca5b8
fix normal generation path
qgallouedec Oct 1, 2025
5cc6af5
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 1, 2025
4dce145
remove vision tokens
qgallouedec Oct 1, 2025
ddfd3b5
same for rloo
qgallouedec Oct 1, 2025
c434fa2
truncation_side=left
qgallouedec Oct 1, 2025
377b081
rm test_training_vlm_and_prompt_truncation
qgallouedec Oct 1, 2025
d599c20
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 1, 2025
e82db74
🔣 Fix test: replace `trainer.tokenizer` by `trainer.processing_class`…
qgallouedec Oct 1, 2025
192deb3
Fix CI ImportError: FlashAttention2 and decorator order for all param…
albertvillanova Oct 1, 2025
cf9d8e7
Hotfix wrong formatting of docstrings with blockquote tips (#4187)
albertvillanova Oct 1, 2025
f9c3c3c
🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
YonatanGideoni Oct 1, 2025
6489479
Replace remaining trainer.tokenizer with trainer.processing_class in …
albertvillanova Oct 3, 2025
21a67fc
[DOCS] Lora without regret (#4181)
burtenshaw Oct 3, 2025
c1e7ad2
[DOCS/FIX] lora without regrets - fix lr (#4207)
burtenshaw Oct 6, 2025
5d34144
Remove custome_container for building the docs (#4198)
albertvillanova Oct 6, 2025
ae2a0e7
Remove tokenizer creation from `sft` example script (#4197)
sergiopaniego Oct 6, 2025
6543f51
Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
albertvillanova Oct 6, 2025
8319ce0
Replace unittest with pytest (#4188)
albertvillanova Oct 6, 2025
4fdaa4c
Updated vLLM integration guide (#4162)
sergiopaniego Oct 6, 2025
d258e36
Remove `Optional` from `processing_class` in `PPOTrainer` (#4212)
sergiopaniego Oct 6, 2025
7f5b499
Replace setup with pyproject and fix packaging unintended modules (#4…
albertvillanova Oct 6, 2025
df386f9
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
5b9a6ab
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
766bbce
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 6, 2025
ac2717f
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Oct 6, 2025
4a274d5
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
db552be
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 6, 2025
2c012dc
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Oct 6, 2025
cb1d420
Merge branch 'refactor_generate_4' into refactor_generate_5
qgallouedec Oct 6, 2025
a84325c
style
qgallouedec Oct 6, 2025
34034e7
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Oct 6, 2025
2ce6c1f
token_type_ids and RLOO
qgallouedec Oct 6, 2025
ddf3405
gfpo
qgallouedec Oct 6, 2025
e3c679c
style
qgallouedec Oct 6, 2025
ee03478
remove test case for prompt truncation
qgallouedec Oct 7, 2025
ed54e2a
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Oct 7, 2025
5e4a026
Merge branch 'refactor_generate_4' into refactor_generate_5
qgallouedec Oct 7, 2025
45290c9
Merge branch 'main' into refactor_generate_3
qgallouedec Oct 7, 2025
a0ee1e6
Merge branch 'refactor_generate_3' into refactor_generate_4
qgallouedec Oct 7, 2025
f6e7c20
Merge branch 'refactor_generate_4' into refactor_generate_5
qgallouedec Oct 7, 2025
919ff5b
Merge branch 'main' into refactor_generate_5
qgallouedec Oct 17, 2025
fe11512
dedup and some fixes
qgallouedec Oct 18, 2025
c0c8807
fix style
qgallouedec Oct 18, 2025
ba8b938
rloo
qgallouedec Oct 18, 2025
7a2936e
style
qgallouedec Oct 18, 2025
1a6f040
test
qgallouedec Oct 18, 2025
ced5450
safe prepare_multimodal_messages_vllm
qgallouedec Oct 18, 2025
23d13f9
oops
qgallouedec Oct 18, 2025
5f87ee9
fix return-dict
qgallouedec Oct 18, 2025
31913e2
fix prepare_multimodal_messages
qgallouedec Oct 18, 2025
ff6782a
fix: update documentation for prepare_multimodal_messages_vllm
qgallouedec Oct 18, 2025
7bb1ee0
require vision
qgallouedec Oct 18, 2025
0739b1f
fix import
qgallouedec Oct 19, 2025
89bbe0d
Merge branch 'main' into refactor_generate_5
qgallouedec Oct 20, 2025
1a77ba7
Merge branch 'main' into refactor_generate_5
qgallouedec Oct 21, 2025
a75b790
style and type hint
qgallouedec Oct 21, 2025
32c7880
expose prepare_multimodal_messages_vllm
qgallouedec Oct 21, 2025
4 changes: 4 additions & 0 deletions docs/source/data_utils.md
@@ -4,6 +4,10 @@

 [[autodoc]] prepare_multimodal_messages

+## prepare_multimodal_messages_vllm
+
+[[autodoc]] prepare_multimodal_messages_vllm
+
 ## is_conversational

 [[autodoc]] is_conversational
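The tests in `tests/test_data_utils.py` describe what the two helpers are expected to do: `prepare_multimodal_messages` turns string contents into typed parts and attaches the images to the first user message, and `prepare_multimodal_messages_vllm` returns a deep copy in which every `{"type": "image", "image": ...}` part becomes `{"type": "image_pil", "image_pil": ...}`. What follows is a minimal sketch of that behavior, not TRL's actual implementation. The function names `prepare_messages_sketch` and `to_vllm_sketch` are hypothetical, a plain string stands in for a PIL image, and content that is already a list is left untouched as a simplification.

```python
import copy
from typing import Any

Message = dict[str, Any]


def prepare_messages_sketch(messages: list[Message], images: list[Any]) -> list[Message]:
    """Hypothetical stand-in: wrap string contents and attach images to the first user turn."""
    messages = copy.deepcopy(messages)  # the caller's list is not mutated
    image_attached = False
    for message in messages:
        if not isinstance(message["content"], str):
            continue  # simplification: structured content is left as-is
        parts = [{"type": "text", "text": message["content"]}]
        if message["role"] == "user" and not image_attached:
            # Images go before the text, and only on the first user message.
            parts = [{"type": "image", "image": img} for img in images] + parts
            image_attached = True
        message["content"] = parts
    return messages


def to_vllm_sketch(messages: list[Message]) -> list[Message]:
    """Hypothetical stand-in: rename image parts to the `image_pil` form on a deep copy."""
    messages = copy.deepcopy(messages)
    for message in messages:
        if isinstance(message["content"], list):
            for part in message["content"]:
                if part.get("type") == "image":
                    part["type"] = "image_pil"
                    part["image_pil"] = part.pop("image")
    return messages


image = "<image placeholder>"  # assumption: stands in for a PIL.Image.Image
chat = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]
prepared = prepare_messages_sketch(chat, images=[image])
vllm_ready = to_vllm_sketch(prepared)
print(vllm_ready[0]["content"])
```

Both sketches return a new list rather than editing in place, which mirrors the move in the tests from calling the helper for its side effect to assigning its return value.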
203 changes: 176 additions & 27 deletions tests/test_data_utils.py
@@ -19,7 +19,7 @@

 from datasets import Dataset, DatasetDict
 from parameterized import parameterized
-from transformers import AutoProcessor, AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer, is_vision_available

 from trl.data_utils import (
     apply_chat_template,
@@ -32,44 +32,66 @@
     maybe_unpair_preference_dataset,
     pack_dataset,
     prepare_multimodal_messages,
+    prepare_multimodal_messages_vllm,
     truncate_dataset,
     unpair_preference_dataset,
 )

-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_vision


+if is_vision_available():
+    from PIL import Image
+
+
+@require_vision
 class TestPrepareMultimodalMessages:
     def test_basic_user_assistant_conversation(self):
         """Test basic conversation with user and assistant messages."""
         messages = [
             {"role": "user", "content": "What color is the sky?"},
             {"role": "assistant", "content": "It is blue."},
         ]
-
-        prepare_multimodal_messages(messages, num_images=1)
+        image = Image.new("RGB", (10, 10), color="blue")
+        messages = prepare_multimodal_messages(messages, images=[image])

         expected = [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
-            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+            {
+                "role": "user",
+                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
         ]

         assert messages == expected

     def test_first_user_message_gets_image(self):
-        """Test that only the first user message gets an image placeholder."""
+        """Test that only the first user message gets an image."""
         messages = [
             {"role": "user", "content": "What color is the sky?"},
             {"role": "assistant", "content": "It is blue."},
             {"role": "user", "content": "How about the grass?"},
         ]

-        prepare_multimodal_messages(messages, num_images=1)
+        image = Image.new("RGB", (10, 10), color="blue")
+        messages = prepare_multimodal_messages(messages, images=[image])

         expected = [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
-            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
-            {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
+            {
+                "role": "user",
+                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "How about the grass?"}],
+            },
         ]

         assert messages == expected
@@ -80,20 +102,23 @@ def test_multiple_images(self):
             {"role": "user", "content": "What color is the sky?"},
             {"role": "assistant", "content": "It is blue."},
         ]
-
-        prepare_multimodal_messages(messages, num_images=3)
+        images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]]
+        messages = prepare_multimodal_messages(messages, images=images)

         expected = [
             {
                 "role": "user",
                 "content": [
-                    {"type": "image"},
-                    {"type": "image"},
-                    {"type": "image"},
+                    {"type": "image", "image": images[0]},
+                    {"type": "image", "image": images[1]},
+                    {"type": "image", "image": images[2]},
                     {"type": "text", "text": "What color is the sky?"},
                 ],
             },
-            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
         ]

         assert messages == expected
@@ -105,11 +130,18 @@ def test_system_message_transformation(self):
             {"role": "user", "content": "What color is the sky?"},
         ]

-        prepare_multimodal_messages(messages, num_images=1)
+        image = Image.new("RGB", (10, 10), color="blue")
+        messages = prepare_multimodal_messages(messages, images=[image])

         expected = [
-            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant"}],
+            },
+            {
+                "role": "user",
+                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
+            },
         ]

         assert messages == expected
@@ -122,10 +154,25 @@ def test_already_prepared_messages_unchanged(self):
             {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
         ]

-        original = copy.deepcopy(messages)
-        prepare_multimodal_messages(messages, num_images=1)
+        image = Image.new("RGB", (10, 10), color="blue")
+        messages = prepare_multimodal_messages(messages, images=[image])

-        assert messages == original
+        expected = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant"}],
+            },
+            {
+                "role": "user",
+                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
+        ]
+
+        assert messages == expected

     def test_mixed_prepared_and_unprepared_messages(self):
         """Test handling of mixed prepared and unprepared messages."""
@@ -135,17 +182,119 @@ def test_mixed_prepared_and_unprepared_messages(self):
             {"role": "user", "content": "What about the grass?"},
         ]

-        prepare_multimodal_messages(messages, num_images=1)
+        image = Image.new("RGB", (10, 10), color="blue")
+        messages = prepare_multimodal_messages(messages, images=[image])

         expected = [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
-            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
-            {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
+            {
+                "role": "user",
+                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "What about the grass?"}],
+            },
         ]

         assert messages == expected


+@require_vision
+class TestPrepareMultimodalMessagesVLLM:
+    def test_single_image_conversion(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
+                    {"type": "text", "text": "What color is the sky?"},
+                ],
+            }
+        ]
+
+        result = prepare_multimodal_messages_vllm(messages)
+
+        # Original should remain unchanged (deepcopy test)
+        assert messages[0]["content"][0]["type"] == "image"
+
+        # Converted version should have correct structure
+        assert result[0]["content"][0]["type"] == "image_pil"
+        assert "image_pil" in result[0]["content"][0]
+        assert "image" not in result[0]["content"][0]
+        assert isinstance(result[0]["content"][0]["image_pil"], Image.Image)
+        assert result[0]["content"][1]["type"] == "text"
+
+    def test_mixed_content_conversion(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What color is the sky?"},
+                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
+                ],
+            }
+        ]
+
+        result = prepare_multimodal_messages_vllm(messages)
+
+        # The image part should be converted, text should be unchanged
+        assert result[0]["content"][0]["type"] == "text"
+        assert result[0]["content"][1]["type"] == "image_pil"
+
+    def test_no_images(self):
+        messages = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}]
+
+        result = prepare_multimodal_messages_vllm(messages)
+
+        # Should be identical since there are no images
+        assert result == messages
+        # And a deepcopy, not the same object
+        assert result is not messages
+        assert result[0] is not messages[0]
+
+    def test_multiple_messages(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What color is the sky?"},
+                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": "It is blue."}],
+            },
+        ]
+
+        result = prepare_multimodal_messages_vllm(messages)
+
+        assert result[0]["content"][1]["type"] == "image_pil"
+        assert result[1]["content"][0]["type"] == "text"
+        assert result[1]["content"][0]["text"] == "It is blue."
+
+    def test_deepcopy_integrity(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What color is the sky?"},
+                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
+                ],
+            },
+        ]
+        original = copy.deepcopy(messages)
+
+        _ = prepare_multimodal_messages_vllm(messages)
+
+        # Original should not be mutated
+        assert messages == original
+
+
 class TestIsConversational(TrlTestCase):
     conversational_examples = [
         { # Language modeling