Skip to content

Commit 42f612d

Browse files
GHGmc2vermouth1992
and authored
[rollout] refactor: Add option for rollout_log_probs, and default as False (#2072)
### Checklist Before Starting - [x] Searched for similar PR(s). - [x] Checked PR Title format - In format of: [modules] type: Title - modules are in `fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data` - type is in `feat, fix, refactor, chore` - can involve multiple modules, separated by `,` or space, like `[megatron, fsdp, doc] feat: xxx` ### What does this PR do? > As discussed in #1712, we may want to minimize communication cost on large clusters, add an option for it and default as `False` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### High-Level Design > Demonstrate the high-level design if this PR is complex. ### Specific Changes > List the specific changes. ### API > Demonstrate how the API changes if any. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Checklist Before Submitting - [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [x] Add `[BREAKING]` to the PR title `description` if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] New CI unit test(s) are added to cover the code path. - [x] Rely on existing unit tests on CI that covers the code path. --------- Co-authored-by: Chi Zhang <[email protected]>
1 parent 0077f3e commit 42f612d

File tree

6 files changed

+32
-13
lines changed

6 files changed

+32
-13
lines changed

verl/trainer/config/generation.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ rollout:
4141
disable_log_stats: True
4242
enable_chunked_prefill: True
4343
n: 1
44+
# support logging rollout prob for debugging purpose
45+
calculate_log_probs: False
4446
actor:
4547
strategy: fsdp # This is for backward-compatibility
4648
ulysses_sequence_parallel_size: 1 # sp size

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ actor_rollout_ref:
204204
# To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
205205
# Qwen/QwQ-32B, Qwen/Qwen3-xxB
206206
enable_tokenization_sanity_check: True
207+
# support logging rollout prob for debugging purpose
208+
calculate_log_probs: False
207209
# Nsight system profiler configs
208210
profiler:
209211
discrete: False

verl/trainer/config/ppo_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,9 @@ actor_rollout_ref:
505505
# Qwen/QwQ-32B, Qwen/Qwen3-xxB
506506
enable_tokenization_sanity_check: True
507507

508+
# support logging rollout prob for debugging purpose
509+
calculate_log_probs: False
510+
508511
# profiler configs
509512
profiler:
510513

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,13 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP
666666
out = _post_process_outputs(self.tokenizer, output)
667667

668668
response = out[0].to(idx.device)
669-
rollout_log_probs = out[1].to(idx.device)
669+
if self.config.calculate_log_probs:
670+
rollout_log_probs = out[1].to(idx.device)
670671

671672
if response.shape[1] < self.config.response_length:
672673
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
673-
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
674+
if self.config.calculate_log_probs:
675+
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
674676

675677
# utilize current sampling params
676678
if self.sampling_params.get("n", 1) > 1 and do_sample:
@@ -706,12 +708,14 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP
706708
"prompts": idx,
707709
"responses": response,
708710
"input_ids": seq, # here input_ids become the whole sentences
709-
"rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor
710711
"attention_mask": attention_mask,
711712
"position_ids": position_ids,
712713
},
713714
batch_size=batch_size,
714715
)
716+
if self.config.calculate_log_probs:
717+
# we will recompute old log prob with actor
718+
batch["rollout_log_probs"] = rollout_log_probs
715719

716720
# free cache engine
717721
if self.config.free_cache_engine and self._engine is not None:

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,13 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
241241
# TODO(sgm): disable logprob when recompute_log_prob is enable
242242
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
243243
response = output[0].to(idx.device)
244-
log_probs = output[1].to(idx.device)
244+
if self.config.calculate_log_probs:
245+
rollout_log_probs = output[1].to(idx.device)
245246

246247
if response.shape[1] < self.config.response_length:
247248
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
248-
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
249+
if self.config.calculate_log_probs:
250+
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
249251

250252
# utilize current sampling params
251253
if self.sampling_params.n > 1 and do_sample:
@@ -276,12 +278,14 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
276278
"prompts": idx,
277279
"responses": response,
278280
"input_ids": seq, # here input_ids become the whole sentences
279-
"rollout_log_probs": log_probs, # we will recompute old log prob with actor
280281
"attention_mask": attention_mask,
281282
"position_ids": position_ids,
282283
},
283284
batch_size=batch_size,
284285
)
286+
if self.config.calculate_log_probs:
287+
# we will recompute old log prob with actor
288+
batch["rollout_log_probs"] = rollout_log_probs
285289

286290
# free vllm cache engine
287291
if self.config.free_cache_engine:

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,16 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
298298
for sample_id in range(len(output.outputs)):
299299
response_ids = output.outputs[sample_id].token_ids
300300
response.append(response_ids)
301-
curr_log_prob = []
302-
for i, logprob in enumerate(output.outputs[sample_id].logprobs):
303-
curr_log_prob.append(logprob[response_ids[i]].logprob)
304-
rollout_log_probs.append(curr_log_prob)
301+
if self.config.calculate_log_probs:
302+
curr_log_prob = []
303+
for i, logprob in enumerate(output.outputs[sample_id].logprobs):
304+
curr_log_prob.append(logprob[response_ids[i]].logprob)
305+
rollout_log_probs.append(curr_log_prob)
305306

306307
response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
307-
rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
308-
rollout_log_probs = rollout_log_probs.to(torch.float32)
308+
if self.config.calculate_log_probs:
309+
rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
310+
rollout_log_probs = rollout_log_probs.to(torch.float32)
309311

310312
if self.sampling_params.n > 1 and do_sample:
311313
idx = _repeat_interleave(idx, self.sampling_params.n)
@@ -339,12 +341,14 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
339341
"prompts": idx,
340342
"responses": response,
341343
"input_ids": seq, # here input_ids become the whole sentences
342-
"rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor
343344
"attention_mask": attention_mask,
344345
"position_ids": position_ids,
345346
},
346347
batch_size=batch_size,
347348
)
349+
if self.config.calculate_log_probs:
350+
# we will recompute old log prob with actor
351+
batch["rollout_log_probs"] = rollout_log_probs
348352

349353
# free vllm cache engine
350354
if (

0 commit comments

Comments
 (0)