diff --git a/tests/utils/test_config_on_cpu.py b/tests/utils/test_config_on_cpu.py
index 13b1732c120..f55e7d68291 100644
--- a/tests/utils/test_config_on_cpu.py
+++ b/tests/utils/test_config_on_cpu.py
@@ -17,19 +17,20 @@
 from omegaconf import OmegaConf
 
+from verl.base_config import BaseConfig
 from verl.utils import omega_conf_to_dataclass
 
 
 @dataclass
-class TestDataclass:
-    hidden_size: int
-    activation: str
+class TestDataclass(BaseConfig):
+    hidden_size: int = 0
+    activation: str = "relu"
 
 
 @dataclass
-class TestTrainConfig:
-    batch_size: int
-    model: TestDataclass
+class TestTrainConfig(BaseConfig):
+    batch_size: int = 0
+    model: TestDataclass = field(default_factory=TestDataclass)
     override_config: dict = field(default_factory=dict)
 
 
@@ -79,7 +80,7 @@ def test_command_with_override(self):
 
         # Run the command
         result = subprocess.run(
-            ["python3", "scripts/print_cfg.py", "+critic.profiler.extra.any_key=val"],
+            ["python3", "scripts/print_cfg.py"],
            capture_output=True,
            text=True,
        )
@@ -90,7 +91,6 @@ def test_command_with_override(self):
         # Verify the output contains expected config information
         self.assertIn("critic", result.stdout)
         self.assertIn("profiler", result.stdout)
-        self.assertIn("extra={'any_key': 'val'}", result.stdout)
 
 
 if __name__ == "__main__":
diff --git a/tests/utils/test_nvtx_profile.py b/tests/utils/test_nvtx_profile.py
index 248078fe01c..fea7675335a 100644
--- a/tests/utils/test_nvtx_profile.py
+++ b/tests/utils/test_nvtx_profile.py
@@ -56,7 +56,7 @@ def test_frozen_config(self):
         from verl.utils.profiler.config import ProfilerConfig
 
         # Create a new ProfilerConfig instance
-        config = ProfilerConfig(all_ranks=False, ranks=[0], extra={"key": "value"})
+        config = ProfilerConfig(all_ranks=False, ranks=[0])
 
         with self.assertRaises(FrozenInstanceError):
             config.all_ranks = True
@@ -70,10 +70,6 @@ def test_frozen_config(self):
         with self.assertRaises(TypeError):
             config["ranks"] = [1, 2, 3]
 
-        assert config["extra"]["key"] == "value"
-        config["extra"]["key"] = "value2"
-        assert config["extra"]["key"] == "value2"
-
 
 class TestNsightSystemsProfiler(unittest.TestCase):
     """Test suite for NsightSystemsProfiler functionality.
diff --git a/verl/base_config.py b/verl/base_config.py
index b01b61dd393..f425dd1464b 100644
--- a/verl/base_config.py
+++ b/verl/base_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import collections
-from dataclasses import FrozenInstanceError, dataclass, field, fields
+from dataclasses import FrozenInstanceError, dataclass, fields
 from typing import Any
 
 
@@ -27,8 +27,8 @@ class BaseConfig(collections.abc.Mapping):
     This allows instances of this class to be used like dictionaries.
     """
 
-    _mutable_fields = {"extra"}
-    extra: dict[str, Any] = field(default_factory=dict)
+    _mutable_fields = set()
+    _target_: str = ""
 
     def __setattr__(self, name: str, value):
         """Set the value of an attribute. Check if the attr is mutable before setting the value."""
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 39e034e8b28..56e0e9eb97c 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -137,6 +137,7 @@ actor_rollout_ref:
       use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
       load_weight: true
   rollout:
+    _target_: verl.workers.config.RolloutConfig
     name: ???
     mode: sync
     temperature: 1.0
@@ -169,12 +170,14 @@ actor_rollout_ref:
       sglang:
         attention_backend: null
     val_kwargs:
+      _target_: verl.workers.config.SamplingConfig
       top_k: -1
       top_p: 1.0
       temperature: 0
       'n': 1
       do_sample: false
     multi_turn:
+      _target_: verl.workers.config.MultiTurnConfig
       enable: false
       max_assistant_turns: null
       tool_config_path: null
@@ -188,13 +191,16 @@ actor_rollout_ref:
       format: hermes
     calculate_log_probs: false
     agent:
+      _target_: verl.workers.config.AgentLoopConfig
       num_workers: 8
       agent_loop_config_path: null
       custom_async_server:
+        _target_: verl.workers.config.CustomAsyncServerConfig
         path: null
         name: null
     update_weights_bucket_megabytes: 512
     trace:
+      _target_: verl.workers.config.TraceConfig
       backend: null
       token2text: false
     skip_rollout: false
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index fed8e713552..4c16fd04088 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -112,6 +112,7 @@ actor_rollout_ref:
     entropy_from_logits_with_chunking: false
     entropy_checkpointing: false
   rollout:
+    _target_: verl.workers.config.RolloutConfig
     name: ???
     mode: sync
     temperature: 1.0
@@ -144,12 +145,14 @@ actor_rollout_ref:
       sglang:
         attention_backend: null
     val_kwargs:
+      _target_: verl.workers.config.SamplingConfig
       top_k: -1
       top_p: 1.0
       temperature: 0
       'n': 1
       do_sample: false
     multi_turn:
+      _target_: verl.workers.config.MultiTurnConfig
       enable: false
       max_assistant_turns: null
       tool_config_path: null
@@ -163,13 +166,16 @@ actor_rollout_ref:
       format: hermes
     calculate_log_probs: false
     agent:
+      _target_: verl.workers.config.AgentLoopConfig
       num_workers: 8
       agent_loop_config_path: null
       custom_async_server:
+        _target_: verl.workers.config.CustomAsyncServerConfig
         path: null
         name: null
     update_weights_bucket_megabytes: 512
     trace:
+      _target_: verl.workers.config.TraceConfig
       backend: null
       token2text: false
     skip_rollout: false
diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml
index 6ac43b5dcd9..c5d66218bb2 100644
--- a/verl/trainer/config/generation.yaml
+++ b/verl/trainer/config/generation.yaml
@@ -14,6 +14,7 @@ model:
   path: ~/models/Qwen2-7B-Instruct
   external_lib: null
 rollout:
+  _target_: verl.workers.config.RolloutConfig
   name: vllm
   mode: sync # sync: LLM, async: AsyncLLM
   temperature: 1.0
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index e7295a4d900..95c42c8d69c 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -1,3 +1,6 @@
+# Target class for this configuration
+_target_: verl.workers.config.RolloutConfig
+
 # actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future
 name: ???
 
@@ -103,6 +106,9 @@ engine_kwargs:
 
 # Sampling parameters used during validation.
 val_kwargs:
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.SamplingConfig
+
   # sampling parameters for validation
   # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
   top_k: -1
@@ -122,6 +128,9 @@ val_kwargs:
 
 # Multi-turn interaction config for tools or chat.
 multi_turn:
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.MultiTurnConfig
+
   # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
   enable: False
 
@@ -170,6 +179,9 @@ calculate_log_probs: False
 
 # [Experimental] agent loop based rollout configs
 agent:
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.AgentLoopConfig
+
   # Number of agent loop workers
   num_workers: 8
 
@@ -188,6 +200,9 @@ agent:
 
   # custom async server configs
   custom_async_server:
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+    _target_: verl.workers.config.CustomAsyncServerConfig
+
     # Path to the custom async server implementation
     path: null
 
@@ -211,6 +226,9 @@ update_weights_bucket_megabytes: 512
 
 # trace rollout data
 trace:
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.TraceConfig
+
   # trace backend, support mlflow, weave
   backend: null
diff --git a/verl/utils/config.py b/verl/utils/config.py
index bbafa8b2c70..fabed0b2526 100644
--- a/verl/utils/config.py
+++ b/verl/utils/config.py
@@ -53,8 +53,10 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
         raise ValueError(f"{dataclass_type} must be a dataclass")
     cfg = OmegaConf.create(config)  # in case it's a dict
     # pop _target_ to avoid hydra instantiate error, as most dataclass do not have _target_
-    if "_target_" in cfg:
-        cfg.pop("_target_")
+    # Updated (vermouth1992) We add _target_ to BaseConfig so that it is compatible.
+    # Otherwise, this code path can't support recursive instantiation.
+    # if "_target_" in cfg:
+    #     cfg.pop("_target_")
     cfg_from_dataclass = OmegaConf.structured(dataclass_type)
     # let cfg override the existing vals in `cfg_from_dataclass`
     cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)
diff --git a/verl/workers/config/__init__.py b/verl/workers/config/__init__.py
index ca9937e9e99..e7180c96094 100644
--- a/verl/workers/config/__init__.py
+++ b/verl/workers/config/__init__.py
@@ -16,6 +16,7 @@
 from .actor import *  # noqa
 from .engine import *  # noqa
 from .optimizer import *  # noqa
-from . import actor, critic, engine, optimizer
+from .rollout import *  # noqa
+from . import actor, critic, engine, optimizer, rollout
 
-__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__
+__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
new file mode 100644
index 00000000000..370a61f928c
--- /dev/null
+++ b/verl/workers/config/rollout.py
@@ -0,0 +1,141 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from verl.base_config import BaseConfig
+from verl.utils.profiler import ProfilerConfig
+
+__all__ = [
+    "SamplingConfig",
+    "MultiTurnConfig",
+    "CustomAsyncServerConfig",
+    "AgentLoopConfig",
+    "TraceConfig",
+    "RolloutConfig",
+]
+
+
+@dataclass
+class SamplingConfig(BaseConfig):
+    temperature: float = 1.0
+    top_k: int = -1
+    top_p: float = 1.0
+    do_sample: bool = True
+    n: int = 1
+
+
+@dataclass
+class MultiTurnConfig(BaseConfig):
+    _mutable_fields = {"max_assistant_turns", "max_user_turns"}
+
+    enable: bool = False
+    max_assistant_turns: Optional[int] = None
+    tool_config_path: Optional[str] = None
+    max_user_turns: Optional[int] = None
+    max_parallel_calls: int = 1
+    max_tool_response_length: int = 256
+    tool_response_truncate_side: str = "middle"
+    interaction_config_path: Optional[str] = None
+    use_inference_chat_template: bool = False
+    tokenization_sanity_check_mode: str = "strict"
+    format: str = "hermes"
+
+
+@dataclass
+class CustomAsyncServerConfig(BaseConfig):
+    path: Optional[str] = None
+    name: Optional[str] = None
+
+
+@dataclass
+class AgentLoopConfig(BaseConfig):
+    num_workers: int = 8
+    agent_loop_config_path: Optional[str] = None
+    custom_async_server: CustomAsyncServerConfig = field(default_factory=CustomAsyncServerConfig)
+
+
+@dataclass
+class TraceConfig(BaseConfig):
+    backend: Optional[str] = None
+    token2text: bool = False
+
+
+@dataclass
+class RolloutConfig(BaseConfig):
+    _mutable_fields = {"max_model_len"}
+
+    name: Optional[str] = None
+    mode: str = "sync"
+
+    temperature: float = 1.0
+    top_k: int = -1
+    top_p: float = 1.0
+    do_sample: bool = True
+    n: int = 1
+
+    prompt_length: int = 512
+    response_length: int = 512
+
+    dtype: str = "bfloat16"
+    gpu_memory_utilization: float = 0.5
+    ignore_eos: bool = False
+    enforce_eager: bool = True
+    cudagraph_capture_sizes: Optional[list] = None
+    free_cache_engine: bool = True
+    tensor_model_parallel_size: int = 2
+    max_num_batched_tokens: int = 8192
+
+    # TODO: enable train_kwargs
+    # train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)
+
+    val_kwargs: SamplingConfig = field(default_factory=SamplingConfig)
+
+    max_model_len: Optional[int] = None
+    max_num_seqs: int = 1024
+
+    # note that the logprob computation should belong to the actor
+    log_prob_micro_batch_size: Optional[int] = None
+    log_prob_micro_batch_size_per_gpu: Optional[int] = None
+    log_prob_use_dynamic_bsz: bool = False
+    log_prob_max_token_len_per_gpu: int = 16384
+
+    disable_log_stats: bool = True
+
+    multi_stage_wake_up: bool = False
+    engine_kwargs: dict = field(default_factory=dict)
+
+    calculate_log_probs: bool = False
+
+    agent: AgentLoopConfig = field(default_factory=AgentLoopConfig)
+
+    trace: TraceConfig = field(default_factory=TraceConfig)
+
+    multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)
+
+    update_weights_bucket_megabytes: int = 512
+
+    skip_rollout: bool = False
+
+    skip_dump_dir: str = "/tmp/rollout_dump"
+
+    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
+
+    enable_chunked_prefill: bool = True
+    load_format: str = "dummy_dtensor"
+
+    layered_summon: bool = False
+
+    layer_name_map: dict = field(default_factory=dict)
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 32ff55e036b..3dea025afa6 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -75,7 +75,7 @@
 from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
 from verl.utils.profiler.performance import reduce_timing
 from verl.utils.py_functional import convert_to_regular_types
-from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig
+from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, RolloutConfig
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
 
 logger = logging.getLogger(__file__)
@@ -513,11 +513,13 @@ def _build_rollout(self, trust_remote_code=False):
             "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
         )
 
+        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
+
         if rollout_name == "hf":
             from verl.workers.rollout import HFRollout
             from verl.workers.sharding_manager.base import BaseShardingManager
 
-            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
+            rollout = HFRollout(module=self.actor_module_fsdp, config=rollout_config)
             rollout_sharding_manager = BaseShardingManager()
             # TODO: a sharding manager that do nothing?
 
@@ -535,10 +537,10 @@ def _build_rollout(self, trust_remote_code=False):
             #     lora_kwargs = {}
             from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
 
-            vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
+            vllm_rollout_cls = vLLMRollout if rollout_config.mode == "sync" else vLLMAsyncRollout
             rollout = vllm_rollout_cls(
                 model_path=local_path,
-                config=self.config.rollout,
+                config=rollout_config,
                 tokenizer=self.tokenizer,
                 model_hf_config=self.actor_model_config,
                 device_mesh=rollout_device_mesh,
@@ -577,7 +579,7 @@ def _build_rollout(self, trust_remote_code=False):
             log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
             rollout = SGLangRollout(
                 actor_module=local_path,
-                config=self.config.rollout,
+                config=rollout_config,
                 processing_class=self.processor if self.processor is not None else self.tokenizer,
                 model_hf_config=self.actor_model_config,
                 trust_remote_code=trust_remote_code,
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 1227078b720..15787d45494 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -61,7 +61,7 @@
 )
 from verl.utils.profiler.performance import reduce_timing
 from verl.workers.actor.megatron_actor import MegatronPPOActor
-from verl.workers.config import McoreCriticConfig
+from verl.workers.config import McoreCriticConfig, RolloutConfig
 from verl.workers.critic.megatron_critic import MegatronPPOCritic
 from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
 
@@ -380,6 +380,9 @@ def _build_rollout(self, trust_remote_code=False):
             "qkv_layer_name": "self_attention.linear_qkv.",
             "gate_proj_layer_name": "linear_fc1.",
         }
+
+        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
+
         if self.config.rollout.name == "vllm":
             from torch.distributed.device_mesh import init_device_mesh
 
@@ -405,7 +408,7 @@ def _build_rollout(self, trust_remote_code=False):
             vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
             rollout = vllm_rollout_cls(
                 model_path=local_path,
-                config=self.config.rollout,
+                config=rollout_config,
                 tokenizer=self.tokenizer,
                 model_hf_config=self.actor_model_config,
                 device_mesh=rollout_device_mesh,
@@ -466,7 +469,7 @@ def _build_rollout(self, trust_remote_code=False):
             log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
             rollout = SGLangRollout(
                 actor_module=local_path,
-                config=self.config.rollout,
+                config=rollout_config,
                 processing_class=self.processor if self.processor is not None else self.tokenizer,
                 model_hf_config=self.actor_model_config,
                 trust_remote_code=trust_remote_code,
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 38b7673e35b..914c9c37863 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -29,7 +29,6 @@
 import sglang.srt.entrypoints.engine
 import torch
 import torch.distributed as dist
-from omegaconf import DictConfig
 from sglang.srt.managers.tokenizer_manager import (
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -60,6 +59,7 @@
 from verl.utils.net_utils import is_ipv6
 from verl.utils.profiler import GPUMemoryLogger
 from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
+from verl.workers.config import RolloutConfig
 from verl.workers.rollout.base import BaseRollout
 from verl.workers.rollout.schemas import (
     AsyncRolloutRequest,
@@ -247,7 +247,7 @@ class SGLangRollout(BaseRollout):
     def __init__(
         self,
         actor_module: str,
-        config: DictConfig,
+        config: RolloutConfig,
         processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,
         model_hf_config,
         port=None,
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index be40e195931..7ffac347918 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -33,7 +33,6 @@
 import pickle
 import socket
 from contextlib import contextmanager
-from copy import deepcopy
 from types import MethodType
 from typing import Any
 
@@ -44,7 +43,7 @@
 import zmq
 import zmq.asyncio
 from filelock import FileLock
-from omegaconf import DictConfig, ListConfig, OmegaConf
+from omegaconf import DictConfig, ListConfig
 from tensordict import TensorDict
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
@@ -57,6 +56,7 @@
 from verl.utils.profiler import GPUMemoryLogger
 from verl.utils.ray_utils import ray_noset_visible_devices
 from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
+from verl.workers.config import RolloutConfig
 from verl.workers.rollout.base import BaseRollout
 
 logger = logging.getLogger(__file__)
@@ -79,7 +79,7 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[in
 
 
 class vLLMRollout(BaseRollout):
-    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
+    def __init__(self, model_path: str, config: RolloutConfig, tokenizer, model_hf_config, **kwargs):
         """A vLLM rollout. It requires the module is supported by the vllm.
 
         Args:
@@ -153,11 +153,8 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
         lora_kwargs = kwargs.pop("lora_kwargs", {})
         self.lora_kwargs = lora_kwargs
         # copy it to avoid secretly modifying the engine config
-        engine_kwargs = (
-            {}
-            if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs
-            else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))
-        )
+        engine_kwargs = config.get("engine_kwargs", {}).get("vllm", {})
+
         # For each vLLM engine parameter,
         # - `None` means not setting it, so we pop it, and leave it to vLLM default value
         #   (which can vary across different vLLM versions);
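
Note (not part of the patch): a minimal sketch of how the new `RolloutConfig` is meant to be built, mirroring the `_build_rollout` changes in `fsdp_workers.py` and `megatron_workers.py`. Only `omega_conf_to_dataclass`, `RolloutConfig`, and the field names defined in `verl/workers/config/rollout.py` above come from this diff; the concrete override values are made up for illustration, and the recursive handling of nested `_target_` entries is assumed per the updated comment in `verl/utils/config.py`.

```python
from omegaconf import OmegaConf

from verl.utils import omega_conf_to_dataclass
from verl.workers.config import RolloutConfig

# Rollout section as it would come out of Hydra; `_target_` now survives the
# OmegaConf merge because BaseConfig declares it as a regular field.
rollout_cfg = OmegaConf.create(
    {
        "_target_": "verl.workers.config.RolloutConfig",
        "name": "vllm",
        "prompt_length": 1024,
        "response_length": 1024,
        # Nested sections carry their own _target_, as in rollout.yaml.
        "val_kwargs": {"_target_": "verl.workers.config.SamplingConfig", "top_p": 0.9},
    }
)

# Same single-argument call the workers use in this diff.
rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_cfg)
assert isinstance(rollout_config, RolloutConfig)
print(rollout_config.name, rollout_config.prompt_length)
```

This is also why `_target_` keys were added to `rollout.yaml`, `generation.yaml`, and the generated trainer YAMLs, and why the old branch that popped `_target_` in `verl/utils/config.py` is now commented out.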