Commit 8616165

Mightentechkang authored and committed
[rollout, vllm, sglang] fix: allow user customization of repetition_penalty to avoid watchdog timeout during GRPO rollout (volcengine#3309)
Allow user customization of `repetition_penalty` to avoid watchdog timeout during GRPO rollout

### What does this PR do?

This PR adds an interface for users to specify `repetition_penalty`, which helps avoid repetition in LLM generation and prevents watchdog timeouts during GRPO rollout. If not specified, `repetition_penalty` remains at its default value of `1.0`.

### Checklist Before Starting

- [x] Search for similar PRs. No similar PRs found.
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, e.g. `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

This PR can be vetted by existing CI test cases.

### API and Usage Example

Previously, users could not specify `repetition_penalty`; this PR adds support for it. For example, users can now start GRPO training with a command like:

```bash
python -m verl.trainer.main_ppo \
    +actor_rollout_ref.rollout.repetition_penalty=1.05 \
    # other params here...
```

### Design & Code Changes

This PR adds an interface allowing users to specify `repetition_penalty` (e.g., `1.05`), while maintaining backward compatibility with the default value of `1.0`.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
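The change follows one pattern in all four touched files: resolve `repetition_penalty` from the rollout config with a fallback of `1.0`, and only fill it into the per-request sampling parameters when the caller has not already supplied one. Below is a minimal sketch of that precedence rule, using plain dicts as stand-ins for verl's actual config and sampling objects (names are illustrative only):

```python
# Minimal sketch of the precedence rule introduced by this PR.
# The dicts below are illustrative stand-ins, not verl's real config objects.

rollout_config = {"repetition_penalty": 1.05}   # e.g. +actor_rollout_ref.rollout.repetition_penalty=1.05
sampling_params = {"temperature": 1.0}          # per-request params; no repetition_penalty supplied

# Config value is used only when the request did not specify one; 1.0 is the final fallback.
sampling_params.setdefault("repetition_penalty", rollout_config.get("repetition_penalty", 1.0))
assert sampling_params["repetition_penalty"] == 1.05

# A request that sets its own value is left untouched.
explicit = {"repetition_penalty": 1.2}
explicit.setdefault("repetition_penalty", rollout_config.get("repetition_penalty", 1.0))
assert explicit["repetition_penalty"] == 1.2
```

Because `dict.setdefault` only writes a missing key, an explicit per-request `repetition_penalty` always wins over the config-level default.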
1 parent: 817d64d · commit: 8616165

4 files changed, +4 -1 lines changed


verl/workers/rollout/sglang_rollout/async_sglang_server.py (1 addition, 0 deletions)

@@ -82,6 +82,7 @@ async def generate(
         request_id: str,
         image_data: Optional[list[Any]] = None,
     ) -> TokenOutput:
+        sampling_params.setdefault("repetition_penalty", self.config.rollout.get("repetition_penalty", 1.0))
         return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id, image_data=image_data)
 
     async def wake_up(self):

verl/workers/rollout/sglang_rollout/sglang_rollout.py (1 addition, 1 deletion)

@@ -481,7 +481,7 @@ def _init_sampling_params(self, **kwargs):
             max_new_tokens=self.config.response_length,
             presence_penalty=0.0,
             frequency_penalty=0.0,
-            repetition_penalty=1.0,
+            repetition_penalty=self.config.get("repetition_penalty", 1.0),
         )
         # supporting adding any sampling params from the config file
         for k in self.config.keys():
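Note that the two async servers read the value as `self.config.rollout.get(...)` (they hold the full trainer config), while the rollout workers read `self.config.get(...)` (they hold only the rollout sub-config); both fall back to `1.0`, so behaviour is unchanged when the key is absent. A small sketch of that fallback, assuming the config behaves like an OmegaConf `DictConfig` (as verl's Hydra-based configs generally do):

```python
# Illustrative only: cfg stands in for the rollout section of the config.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"response_length": 512})          # key absent
cfg_user = OmegaConf.create(
    {"response_length": 512, "repetition_penalty": 1.05}  # user-supplied value
)

print(cfg.get("repetition_penalty", 1.0))       # 1.0  -> old behaviour preserved
print(cfg_user.get("repetition_penalty", 1.0))  # 1.05 -> user override picked up
```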

verl/workers/rollout/vllm_rollout/vllm_async_server.py (1 addition, 0 deletions)

@@ -350,6 +350,7 @@ async def generate(
     ) -> TokenOutput:
         max_tokens = self.max_model_len - len(prompt_ids)
         sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None
+        sampling_params.setdefault("repetition_penalty", self.config.rollout.get("repetition_penalty", 1.0))
         sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
         prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.processor)
         prompt = TokensPrompt(
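On the vLLM side the key is simply forwarded into `SamplingParams`, where `repetition_penalty` is a supported field whose own default is `1.0` (no penalty). A minimal sketch; the value `1.05` is just an example, not something fixed by this PR:

```python
# Sketch: the resolved config value is passed straight through to vLLM.
from vllm import SamplingParams

resolved = 1.05  # e.g. the rollout config value, or 1.0 if unspecified
params = SamplingParams(max_tokens=1024, repetition_penalty=resolved)
print(params.repetition_penalty)  # 1.05
```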

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (1 addition, 0 deletions)

@@ -206,6 +206,7 @@ def __init__(
             n=1,
             logprobs=0,  # can be set to 0 and let actor to recompute
             max_tokens=config.response_length,
+            repetition_penalty=config.get("repetition_penalty", 1.0),
         )
 
         kwargs["detokenize"] = False
