diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index aa32ffc6eaf..277364f9bfc 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -82,6 +82,7 @@ async def generate( request_id: str, image_data: Optional[list[Any]] = None, ) -> TokenOutput: + sampling_params.setdefault("repetition_penalty", self.config.rollout.get("repetition_penalty", 1.0)) return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id, image_data=image_data) async def wake_up(self): diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index c716647913a..6b8470bd398 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -481,7 +481,7 @@ def _init_sampling_params(self, **kwargs): max_new_tokens=self.config.response_length, presence_penalty=0.0, frequency_penalty=0.0, - repetition_penalty=1.0, + repetition_penalty=self.config.get("repetition_penalty", 1.0), ) # supporting adding any sampling params from the config file for k in self.config.keys(): diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index c406b8846c3..fd476227ffe 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -350,6 +350,7 @@ async def generate( ) -> TokenOutput: max_tokens = self.max_model_len - len(prompt_ids) sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None + sampling_params.setdefault("repetition_penalty", self.config.rollout.get("repetition_penalty", 1.0)) sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.processor) prompt = TokensPrompt( diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 147b6f54ad9..7520ea1f51e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -207,6 +207,7 @@ def __init__( n=1, logprobs=0, # can be set to 0 and let actor to recompute max_tokens=config.response_length, + repetition_penalty=config.get("repetition_penalty", 1.0), ) kwargs["detokenize"] = False