Commit 0e57b4a

🧺 [3/N] Refactor _generate in GRPO/RLOO: Rely on generator for prompt truncation (#4153)
1 parent 98488e0 commit 0e57b4a

File tree: 8 files changed, +55 −334 lines changed

tests/test_grpo_trainer.py

Lines changed: 0 additions & 41 deletions

@@ -1471,47 +1471,6 @@ def reward_func(completions, **kwargs):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = GRPOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
             ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),

tests/test_rloo_trainer.py

Lines changed: 0 additions & 41 deletions

@@ -1212,47 +1212,6 @@ def reward_func(completions, **kwargs):
             elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                 assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = RLOOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = RLOOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),

tests/test_utils.py

Lines changed: 0 additions & 79 deletions

@@ -42,7 +42,6 @@
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -1009,84 +1008,6 @@ def test_multi_images(self):
         assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
 
 
-class TestTruncateWithProtectedTokens(TrlTestCase):
-    def test_basic_example(self):
-        """Test the basic example from the problem description."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [2, 3]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [2, 3, 5]
-        assert new_ids == expected_ids
-
-    def test_no_truncation_needed(self):
-        """Test when target length equals current length."""
-        prompt_ids = [1, 2, 3]
-        protected_tokens = [2]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_no_protected_tokens(self):
-        """Test truncation with no protected tokens (normal right truncation)."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = []
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]  # Last 3 tokens
-        assert new_ids == expected_ids
-
-    def test_all_tokens_protected(self):
-        """Test when all remaining tokens are protected."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [3, 4, 5]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]
-        assert new_ids == expected_ids
-
-    def test_too_many_protected_tokens(self):
-        """Test error when too many protected tokens for target length."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [1, 2, 3, 4]
-        target_length = 3
-
-        with pytest.raises(ValueError):
-            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-    def test_single_batch_single_token(self):
-        """Test edge case with single batch and single token."""
-        prompt_ids = [5]
-        protected_tokens = [5]
-        target_length = 1
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_order_preservation(self):
-        """Test that relative order is preserved."""
-        prompt_ids = [10, 2, 20, 3, 30, 40]
-        protected_tokens = [2, 3]
-        target_length = 4
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
-        # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
-        expected_ids = [2, 3, 30, 40]
-
-        assert new_ids == expected_ids
-
-
 class TestUnsplitPixelValuesByGrid(TrlTestCase):
     def test_unsplit_correctly(self):
         pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
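
Note: the behavior of the removed `truncate_with_protected_tokens` helper is fully specified by the deleted tests above. For readers who want the logic in one place, here is a minimal sketch that reproduces that behavior; it is a reconstruction from the test cases, not the actual deleted implementation.

```python
def truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens):
    """Keep every protected token, fill the remaining budget with the right-most
    non-protected tokens, and preserve the original relative order."""
    protected = set(protected_tokens)
    budget = target_length - sum(1 for t in prompt_ids if t in protected)
    if budget < 0:
        raise ValueError("Too many protected tokens for the requested target length.")
    keep = [False] * len(prompt_ids)
    for i in range(len(prompt_ids) - 1, -1, -1):  # walk right to left
        if prompt_ids[i] in protected:
            keep[i] = True
        elif budget > 0:
            keep[i] = True
            budget -= 1
    return [t for t, k in zip(prompt_ids, keep) if k]
```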

trl/extras/vllm_client.py

Lines changed: 6 additions & 0 deletions

@@ -182,6 +182,7 @@ def generate(
         top_k: int = -1,
         min_p: float = 0.0,
         max_tokens: int = 16,
+        truncate_prompt_tokens: Optional[int] = None,
         guided_decoding_regex: Optional[str] = None,
         generation_kwargs: Optional[dict] = None,
     ) -> list[list[int]]:
@@ -207,6 +208,10 @@
                 Minimum probability for sampling.
             max_tokens (`int`, *optional*, defaults to `16`):
                 Maximum number of tokens to generate for each prompt.
+            truncate_prompt_tokens (`int`, *optional*):
+                If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
+                only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
+                disabled.
             guided_decoding_regex (`str`, *optional*):
                 Regular expression to guide the decoding process.
             generation_kwargs (`dict`, *optional*):
@@ -246,6 +251,7 @@ def pil_to_base64(image):
                 "top_k": top_k,
                 "min_p": min_p,
                 "max_tokens": max_tokens,
+                "truncate_prompt_tokens": truncate_prompt_tokens,
                 "guided_decoding_regex": guided_decoding_regex,
                 "generation_kwargs": generation_kwargs or {},
             },
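
The new argument is simply forwarded to the server, so callers can opt in to left truncation per request. A hedged usage sketch follows; the default constructor arguments and the assumption that a `trl vllm-serve` server is already running are not part of this diff.

```python
from trl.extras.vllm_client import VLLMClient

# Assumes a vLLM server was started separately, e.g. `trl vllm-serve --model <model_name>`.
client = VLLMClient()  # default host/port; adjust for your deployment
completion_ids = client.generate(
    ["A prompt that may be longer than the model's context window ..."],
    max_tokens=64,
    truncate_prompt_tokens=512,  # keep only the last 512 prompt tokens (left truncation)
)
```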

trl/scripts/vllm_serve.py

Lines changed: 5 additions & 0 deletions

@@ -495,6 +495,7 @@ class GenerateRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     max_tokens: int = 16
+    truncate_prompt_tokens: Optional[int] = None
     guided_decoding_regex: Optional[str] = None
     generation_kwargs: dict = field(default_factory=dict)
 
@@ -525,6 +526,9 @@ async def generate(request: GenerateRequest):
     - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
     - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
       completion.
+    - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
+      by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
+      truncation). If set to `None`, truncation is disabled.
     - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
       model will only generate tokens that match this regex pattern.
     - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
@@ -575,6 +579,7 @@ async def generate(request: GenerateRequest):
         "top_k": request.top_k,
         "min_p": request.min_p,
         "max_tokens": request.max_tokens,
+        "truncate_prompt_tokens": request.truncate_prompt_tokens,
         "guided_decoding": guided_decoding,
         "logprobs": 0,
     }
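
On the server side, the new field rides along in the `GenerateRequest` body. A hedged sketch of a raw HTTP call is shown below; the endpoint path and port are assumptions about the default `trl vllm-serve` setup, not something this diff defines.

```python
import requests

response = requests.post(
    "http://localhost:8000/generate/",  # assumed default trl vllm-serve address; adjust as needed
    json={
        "prompts": ["A prompt that may exceed the model's context window ..."],
        "max_tokens": 64,
        "truncate_prompt_tokens": 512,  # left-truncate the prompt to its last 512 tokens
    },
)
print(response.json())
```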

trl/trainer/grpo_trainer.py

Lines changed: 22 additions & 66 deletions

@@ -14,7 +14,6 @@
 
 import inspect
 import os
-import re
 import textwrap
 from collections import defaultdict, deque
 from contextlib import nullcontext
@@ -71,7 +70,6 @@
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -275,7 +273,7 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
@@ -291,10 +289,6 @@ def __init__(
             self.pad_token = tokenizer.pad_token
             self.pad_token_id = tokenizer.pad_token_id
             self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1185,6 +1133,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     top_k=-1 if self.top_k is None else self.top_k,
                    min_p=0.0 if self.min_p is None else self.min_p,
                     max_tokens=self.max_completion_length,
+                    truncate_prompt_tokens=self.max_prompt_length,
                     guided_decoding_regex=self.guided_decoding_regex,
                     generation_kwargs=self.args.generation_kwargs,
                 )
@@ -1223,6 +1172,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 "top_k": -1 if self.top_k is None else self.top_k,
                 "min_p": 0.0 if self.min_p is None else self.min_p,
                 "max_tokens": self.max_completion_length,
+                "truncate_prompt_tokens": self.max_prompt_length,
                 "guided_decoding": guided_decoding,
                 "logprobs": 0,  # only return the logprob of the generated token
             }
@@ -1319,7 +1269,17 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
 
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1330,15 +1290,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
             prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
             # Mask everything after the first EOS token
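
The regular-generation path now delegates prompt truncation to the tokenizer itself (`truncation_side="left"` on the processor plus `max_length=...` and `truncation=True` at call time) instead of the removed token-level helper. A minimal hedged illustration of that mechanism, using a placeholder checkpoint rather than anything referenced in this diff:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; the trainer uses whatever `model.config._name_or_path` points to.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")

batch = tokenizer(
    ["word " * 1000],         # an overly long prompt
    return_tensors="pt",
    padding=True,
    padding_side="left",
    max_length=32,            # analogous to `max_prompt_length`
    truncation=True,          # keeps only the last 32 tokens because truncation_side="left"
    add_special_tokens=False,
)
print(batch["input_ids"].shape)  # torch.Size([1, 32])
```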
