9 changes: 9 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -137,6 +137,7 @@ actor_rollout_ref:
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
load_weight: true
rollout:
_target_: verl.workers.config.RolloutConfig
name: ???
mode: sync
temperature: 1.0
@@ -163,18 +164,23 @@
'n': 1
multi_stage_wake_up: false
engine_kwargs:
_target_: verl.workers.config.EngineConfig
vllm:
_target_: verl.workers.config.vLLMEngineConfig
swap_space: null
disable_mm_preprocessor_cache: false
sglang:
_target_: verl.workers.config.SGLangEngineConfig
attention_backend: null
val_kwargs:
_target_: verl.workers.config.SamplingConfig
top_k: -1
top_p: 1.0
temperature: 0
'n': 1
do_sample: false
multi_turn:
_target_: verl.workers.config.MultiTurnConfig
enable: false
max_assistant_turns: null
tool_config_path: null
@@ -188,13 +194,16 @@ actor_rollout_ref:
format: hermes
calculate_log_probs: false
agent:
_target_: verl.workers.config.AgentLoopConfig
num_workers: 8
agent_loop_config_path: null
custom_async_server:
_target_: verl.workers.config.CustomAsyncServerConfig
path: null
name: null
update_weights_bucket_megabytes: 512
trace:
_target_: verl.workers.config.TraceConfig
backend: null
token2text: false
skip_rollout: false
9 changes: 9 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -112,6 +112,7 @@ actor_rollout_ref:
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
rollout:
_target_: verl.workers.config.RolloutConfig
name: ???
mode: sync
temperature: 1.0
@@ -138,18 +139,23 @@
'n': 1
multi_stage_wake_up: false
engine_kwargs:
_target_: verl.workers.config.EngineConfig
vllm:
_target_: verl.workers.config.vLLMEngineConfig
swap_space: null
disable_mm_preprocessor_cache: false
sglang:
_target_: verl.workers.config.SGLangEngineConfig
attention_backend: null
val_kwargs:
_target_: verl.workers.config.SamplingConfig
top_k: -1
top_p: 1.0
temperature: 0
'n': 1
do_sample: false
multi_turn:
_target_: verl.workers.config.MultiTurnConfig
enable: false
max_assistant_turns: null
tool_config_path: null
@@ -163,13 +169,16 @@ actor_rollout_ref:
format: hermes
calculate_log_probs: false
agent:
_target_: verl.workers.config.AgentLoopConfig
num_workers: 8
agent_loop_config_path: null
custom_async_server:
_target_: verl.workers.config.CustomAsyncServerConfig
path: null
name: null
update_weights_bucket_megabytes: 512
trace:
_target_: verl.workers.config.TraceConfig
backend: null
token2text: false
skip_rollout: false
11 changes: 11 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -1,3 +1,6 @@
# Target class for this configuration
_target_: verl.workers.config.RolloutConfig

# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future
name: ???

@@ -84,9 +87,11 @@ multi_stage_wake_up: false

# Extra inference engine arguments (vllm, sglang).
engine_kwargs:
_target_: verl.workers.config.EngineConfig

# for vllm
vllm:
_target_: verl.workers.config.vLLMEngineConfig

# Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB).
swap_space: null
@@ -96,12 +101,14 @@ engine_kwargs:

# for sglang
sglang:
_target_: verl.workers.config.SGLangEngineConfig

# The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default.
attention_backend: null

# Sampling parameters used during validation.
val_kwargs:
_target_: verl.workers.config.SamplingConfig
Collaborator: Should we move rollout sampling params to SamplingConfig?

Collaborator (Author): Yes. I left a TODO. Currently, we don't want to introduce breaking changes. RolloutConfig should match rollout.yaml exactly.

Collaborator (Author): We also want to move the log-prob batch-size (bsz) settings into actor.infer in follow-up PRs.
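
For reference, a minimal sketch of how that "match rollout.yaml exactly" invariant could be checked (not part of this PR; the YAML path is the one touched in this change, and the subset direction of the comparison is an assumption):

```python
from dataclasses import fields

from omegaconf import OmegaConf

from verl.workers.config import RolloutConfig

# Top-level keys declared in rollout.yaml, minus the _target_ marker itself.
yaml_cfg = OmegaConf.load("verl/trainer/config/rollout/rollout.yaml")
yaml_keys = {k for k in yaml_cfg.keys() if k != "_target_"}

# Public field names exposed by the dataclass.
dataclass_keys = {f.name for f in fields(RolloutConfig) if not f.name.startswith("_")}

# Every YAML key should have a matching dataclass field, otherwise the
# _target_-based conversion would have nowhere to put that value.
missing = yaml_keys - dataclass_keys
assert not missing, f"rollout.yaml keys without a RolloutConfig field: {missing}"
```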


# sampling parameters for validation
# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
@@ -121,6 +128,7 @@ val_kwargs:

# Multi-turn interaction config for tools or chat.
multi_turn:
_target_: verl.workers.config.MultiTurnConfig

# set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
enable: False
@@ -169,6 +177,7 @@ calculate_log_probs: False

# [Experimental] agent loop based rollout configs
agent:
_target_: verl.workers.config.AgentLoopConfig

# Number of agent loop workers
num_workers: 8
@@ -187,6 +196,7 @@ agent:

# custom async server configs
custom_async_server:
_target_: verl.workers.config.CustomAsyncServerConfig

# Path to the custom async server implementation
path: null
@@ -210,6 +220,7 @@ update_weights_bucket_megabytes: 512

# trace rollout data
trace:
_target_: verl.workers.config.TraceConfig

# trace backend, support mlflow, weave
backend: null
5 changes: 3 additions & 2 deletions verl/workers/config/__init__.py
@@ -16,6 +16,7 @@
from .actor import * # noqa
from .engine import * # noqa
from .optimizer import * # noqa
-from . import actor, critic, engine, optimizer
+from .rollout import * # noqa
+from . import actor, critic, engine, optimizer, rollout

-__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__
+__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__
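
With rollout re-exported here, the new config classes are importable straight from verl.workers.config. A small usage sketch with the defaults defined in this PR (assumes verl is installed with these changes):

```python
from verl.workers.config import RolloutConfig, SamplingConfig

# name has no default, so it must be supplied; nested sections fall back to
# their default_factory instances unless overridden explicitly.
cfg = RolloutConfig(
    name="vllm",
    tensor_model_parallel_size=1,
    val_kwargs=SamplingConfig(temperature=0.0, do_sample=False),
)

print(cfg.mode)                           # "sync" (dataclass default)
print(cfg.engine_kwargs.vllm.swap_space)  # None (dataclass default)
```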
141 changes: 141 additions & 0 deletions verl/workers/config/rollout.py
@@ -0,0 +1,141 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from verl.base_config import BaseConfig

__all__ = [
"SamplingConfig",
"vLLMEngineConfig",
"SGLangEngineConfig",
"EngineConfig",
"MultiTurnConfig",
"CustomAsyncServerConfig",
"AgentLoopConfig",
"TraceConfig",
"RolloutConfig",
]


@dataclass
class SamplingConfig(BaseConfig):
temperature: float = 1.0
top_k: int = -1
top_p: float = 1.0
do_sample: bool = True
n: int = 1


@dataclass
class vLLMEngineConfig(BaseConfig):
swap_space: Optional[int] = None
disable_mm_preprocessor_cache: bool = False


@dataclass
class SGLangEngineConfig(BaseConfig):
attention_backend: Optional[str] = None


@dataclass
class EngineConfig(BaseConfig):
vllm: vLLMEngineConfig = field(default_factory=vLLMEngineConfig)
sglang: SGLangEngineConfig = field(default_factory=SGLangEngineConfig)


@dataclass
class MultiTurnConfig(BaseConfig):
enable: bool = False
max_assistant_turns: Optional[int] = None
tool_config_path: Optional[str] = None
max_user_turns: Optional[int] = None
max_parallel_calls: int = 1
max_tool_response_length: int = 256
tool_response_truncate_side: str = "middle"
interaction_config_path: Optional[str] = None
use_inference_chat_template: bool = False
tokenization_sanity_check_mode: str = "strict"
format: str = "hermes"


@dataclass
class CustomAsyncServerConfig(BaseConfig):
path: Optional[str] = None
name: Optional[str] = None


@dataclass
class AgentLoopConfig(BaseConfig):
num_workers: int = 8
agent_loop_config_path: Optional[str] = None
custom_async_server: CustomAsyncServerConfig = field(default_factory=CustomAsyncServerConfig)


@dataclass
class TraceConfig(BaseConfig):
backend: Optional[str] = None
token2text: bool = False


@dataclass
class RolloutConfig(BaseConfig):
name: str
mode: str = "sync"

temperature: float = 1.0
top_k: int = -1
top_p: float = 1.0
do_sample: bool = True
n: int = 1
Comment on lines +83 to +87
Contributor (high): The sampling parameters temperature, top_k, top_p, do_sample, and n are duplicated from SamplingConfig. This can lead to inconsistencies and makes the configuration harder to manage. As noted in the TODO on line 115, these should be encapsulated within a SamplingConfig object for training rollouts, similar to how val_kwargs is handled for validation. This will improve the structure and maintainability of the configuration. I recommend replacing these individual parameters with a train_kwargs field of type SamplingConfig. This will also require updating the corresponding YAML configuration files.

Suggested change:
-temperature: float = 1.0
-top_k: int = -1
-top_p: float = 1.0
-do_sample: bool = True
-n: int = 1
+train_kwargs: SamplingConfig = field(default_factory=SamplingConfig)


prompt_length: int = 512
response_length: int = 512

dtype: str = "bfloat16"
gpu_memory_utilization: float = 0.5
ignore_eos: bool = False
enforce_eager: bool = True
cudagraph_capture_sizes: Optional[list] = None
free_cache_engine: bool = True
tensor_model_parallel_size: int = 2
max_num_batched_tokens: int = 8192

# TODO: enable train_kwargs
# train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)

val_kwargs: SamplingConfig = field(default_factory=SamplingConfig)

max_model_len: Optional[int] = None
max_num_seqs: int = 1024

# note that the logprob computation should belong to the actor
log_prob_micro_batch_size: Optional[int] = None
log_prob_micro_batch_size_per_gpu: Optional[int] = None
log_prob_use_dynamic_bsz: bool = False
log_prob_max_token_len_per_gpu: int = 16384
Contributor (high): As the comment on line 123 correctly points out, the log probability computation is an actor responsibility. Therefore, the configuration parameters log_prob_micro_batch_size, log_prob_micro_batch_size_per_gpu, log_prob_use_dynamic_bsz, and log_prob_max_token_len_per_gpu should be part of ActorConfig, not RolloutConfig. Keeping them here is confusing and structurally incorrect, even if the values are linked from the actor config in YAML. Moving these parameters to ActorConfig would make the configuration more logical and easier to maintain.
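
A hypothetical shape for that move, purely as a sketch (ActorInferConfig is an invented name and is not part of this PR; only the field names and defaults are copied from above):

```python
from dataclasses import dataclass
from typing import Optional

from verl.base_config import BaseConfig


@dataclass
class ActorInferConfig(BaseConfig):
    # Hypothetical actor-side home for the log-prob settings above, so that
    # RolloutConfig would no longer carry actor/inference responsibilities.
    log_prob_micro_batch_size: Optional[int] = None
    log_prob_micro_batch_size_per_gpu: Optional[int] = None
    log_prob_use_dynamic_bsz: bool = False
    log_prob_max_token_len_per_gpu: int = 16384
```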


disable_log_stats: bool = True

multi_stage_wake_up: bool = False
engine_kwargs: EngineConfig = field(default_factory=EngineConfig)

calculate_log_probs: bool = False

multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)

update_weights_bucket_megabytes: int = 512

skip_rollout: bool = False
skip_dump_dir: str = "/tmp/rollout_dump"
12 changes: 7 additions & 5 deletions verl/workers/fsdp_workers.py
@@ -75,7 +75,7 @@
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing
from verl.utils.py_functional import convert_to_regular_types
-from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig
+from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, RolloutConfig
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)
@@ -513,11 +513,13 @@ def _build_rollout(self, trust_remote_code=False):
"rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)

rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)

if rollout_name == "hf":
from verl.workers.rollout import HFRollout
from verl.workers.sharding_manager.base import BaseShardingManager

-rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
+rollout = HFRollout(module=self.actor_module_fsdp, config=rollout_config)
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?

@@ -535,10 +537,10 @@
# lora_kwargs = {}
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout

-vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
+vllm_rollout_cls = vLLMRollout if rollout_config.mode == "sync" else vLLMAsyncRollout
rollout = vllm_rollout_cls(
model_path=local_path,
-config=self.config.rollout,
+config=rollout_config,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
device_mesh=rollout_device_mesh,
@@ -577,7 +579,7 @@ def _build_rollout(self, trust_remote_code=False):
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
rollout = SGLangRollout(
actor_module=local_path,
-config=self.config.rollout,
+config=rollout_config,
processing_class=self.processor if self.processor is not None else self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
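
The newly added rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) line is where the _target_ keys above pay off. verl's actual helper isn't shown in this diff, but a simplified sketch of that kind of resolution looks roughly like this (assumptions: non-recursive, no missing-key handling):

```python
import importlib

from omegaconf import DictConfig, OmegaConf


def resolve_target_dataclass(cfg: DictConfig):
    """Instantiate the dataclass named by `_target_` from a config node.

    Simplified sketch: a real conversion would also recurse into nested
    sections such as engine_kwargs or val_kwargs, which carry their own
    `_target_` keys.
    """
    data = OmegaConf.to_container(cfg, resolve=True)
    module_path, class_name = data.pop("_target_").rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**data)


# Example: a val_kwargs-style node maps onto verl.workers.config.SamplingConfig.
node = OmegaConf.create(
    {"_target_": "verl.workers.config.SamplingConfig", "temperature": 0.0, "do_sample": False}
)
sampling = resolve_target_dataclass(node)
print(type(sampling).__name__, sampling.temperature)  # SamplingConfig 0.0
```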