
Commit e0eec05

qgallouedec, albertvillanova, YonatanGideoni, burtenshaw, and sergiopaniego authored
🧺 [4/N] Refactor _generate in GRPO/RLOO: Move forward_kwargs outside generation method (#4154)
Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: YonatanGideoni <[email protected]>
Co-authored-by: burtenshaw <[email protected]>
Co-authored-by: sergiopaniego <[email protected]>
Co-authored-by: lewtun <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
1 parent f4c554d commit e0eec05

File tree

3 files changed: 65 additions, 51 deletions
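In short, _generate and _generate_single_turn no longer return forward_kwargs; _generate_and_score_completions now rebuilds them itself just before the logit computation. A minimal sketch of the changed GRPO calling contract, using stand-in return values rather than real trainer state:

# Hedged sketch of the interface change in GRPO (toy values, not trainer code).
# Before: generation also carried the multimodal forward kwargs.
# After: it returns only ids, token counts, and sampling logprobs.

def generate_before(prompts):
    # old contract: five values, including forward_kwargs
    return [[1, 2]], [[3, 4, 5]], 3, None, {"pixel_values": "..."}

def generate_after(prompts):
    # new contract: four values; forward_kwargs is rebuilt by the caller
    return [[1, 2]], [[3, 4, 5]], 3, None

prompt_ids, completion_ids, num_items, logps, forward_kwargs = generate_before(["hi"])
prompt_ids, completion_ids, num_items, logps = generate_after(["hi"])
forward_kwargs = {}  # caller-side reconstruction (empty for text-only prompts)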

trl/experimental/gfpo/gfpo_trainer.py

Lines changed: 21 additions & 13 deletions

@@ -18,7 +18,7 @@
 import torch
 from accelerate.utils import gather_object
 
-from ...data_utils import is_conversational
+from ...data_utils import apply_chat_template, is_conversational
 from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
 from ...trainer.utils import nanmax, nanmin, nanstd, pad
 
@@ -80,13 +80,9 @@ def _generate_and_score_completions(self, inputs):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -112,18 +108,30 @@ def _generate_and_score_completions(self, inputs):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
             forward_kwargs["token_type_ids"] = torch.cat(
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
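The block added above relies on a simple extraction pattern: run the processing class over the prompt texts and images, then keep every returned tensor except input_ids and attention_mask as extra keyword arguments for the model's forward pass. A toy sketch of just that filtering step, with a plain dict standing in for the processor output:

# Toy sketch of the forward_kwargs extraction above; processor_output is a
# hand-written stand-in for what a multimodal processing_class would return.
processor_output = {
    "input_ids": [[101, 2009]],     # consumed by generation, dropped here
    "attention_mask": [[1, 1]],     # likewise dropped
    "pixel_values": [[0.1, 0.2]],   # kept: image features for forward()
    "token_type_ids": [[0, 0]],     # kept: extended later for completions
}
forward_kwargs = {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
print(sorted(forward_kwargs))  # ['pixel_values', 'token_type_ids']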

trl/trainer/grpo_trainer.py

Lines changed: 23 additions & 22 deletions

@@ -1086,13 +1086,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1307,13 +1300,13 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
             logprobs = None  # not used in this case
 
-        return prompt_ids, completion_ids, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, logprobs
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1345,7 +1338,7 @@ def _generate(self, prompts: list[str], images: Optional[list]):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1365,13 +1358,9 @@ def _generate_and_score_completions(
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1397,18 +1386,30 @@ def _generate_and_score_completions(
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
             forward_kwargs["token_type_ids"] = torch.cat(
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
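Since forward_kwargs is now built immediately before the token_type_ids handling, that handling works unchanged: the prompt's token_type_ids are zero-padded so they span the concatenated prompt-plus-completion sequence. A self-contained shape check of that step, with illustrative sizes only:

# Illustrative shapes only: B=2 prompts of length P=3, completions of C=4.
import torch

token_type_ids = torch.ones(2, 3, dtype=torch.long)   # (B, P), prompt part
completion_ids = torch.zeros(2, 4, dtype=torch.long)  # (B, C), completions
token_type_ids = torch.cat(
    [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
print(token_type_ids.shape)  # torch.Size([2, 7]) -> (B, P + C)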

trl/trainer/rloo_trainer.py

Lines changed: 21 additions & 16 deletions

@@ -1082,13 +1082,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1292,13 +1285,13 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
         prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
         completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1331,7 +1324,7 @@ def _generate(self, prompts: list[str], images: Optional[list]):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1351,7 +1344,7 @@ def _generate_and_score_completions(
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list = self._generate(prompts, images)
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1372,18 +1365,30 @@ def _generate_and_score_completions(
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
             forward_kwargs["token_type_ids"] = torch.cat(
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # Compute the per-token log probabilities for the current model
             old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
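After the refactor, the two trainers expose differently sized generation contracts: GRPO's _generate still returns completion-token counts and sampling logprobs, while RLOO's returns only the id lists. A stand-in comparison with toy values, not the trainers themselves:

# Stand-in comparison of the new return shapes (toy data, not trainer code).
def grpo_generate(prompts):
    # prompt_ids, completion_ids, total_completion_tokens, logprobs
    return [[1, 2]], [[3, 4]], 2, None

def rloo_generate(prompts):
    # prompt_ids, completion_ids
    return [[1, 2]], [[3, 4]]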
