|
| 1 | +# Copyright 2025 Bytedance Ltd. and/or its affiliates |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from dataclasses import dataclass, field |
| 16 | +from typing import Optional |
| 17 | + |
| 18 | +from verl.base_config import BaseConfig |
| 19 | +from verl.utils.profiler import ProfilerConfig |
| 20 | + |
# Public API of this module: the configuration dataclasses that a
# star-import of this module exposes.
__all__ = [
    "SamplingConfig",
    "MultiTurnConfig",
    "CustomAsyncServerConfig",
    "AgentLoopConfig",
    "TraceConfig",
    "RolloutConfig",
]
| 29 | + |
| 30 | + |
@dataclass
class SamplingConfig(BaseConfig):
    """Bundle of sampling options: temperature, top-k, top-p, sampling switch, count.

    Every field declared here carries a default value.

    NOTE(review): the code that consumes these values is not part of this
    file, so the per-field remarks below are assumptions drawn from the field
    names — confirm them against the rollout backend.
    """

    temperature: float = 1.0

    # NOTE(review): -1 presumably means "no top-k restriction" — confirm.
    top_k: int = -1
    top_p: float = 1.0

    do_sample: bool = True

    # NOTE(review): presumably the number of samples generated per prompt — confirm.
    n: int = 1
| 38 | + |
| 39 | + |
@dataclass
class MultiTurnConfig(BaseConfig):
    """Options for multi-turn rollouts (turn limits, tool and interaction settings).

    All fields declared here have defaults; ``enable`` defaults to ``False``.

    NOTE(review): accepted values for the string-valued options
    (``tool_response_truncate_side``, ``tokenization_sanity_check_mode``,
    ``format``) are not validated in this class and are not visible in this
    file — confirm them with the code that reads this config.
    """

    # Class-level attribute (no annotation, so it is not a dataclass field).
    # NOTE(review): presumably read by BaseConfig to decide which fields may
    # be reassigned after construction — confirm in verl.base_config.
    _mutable_fields = {"max_assistant_turns", "max_user_turns"}

    enable: bool = False

    # Turn limits and tool settings; ``None`` is the default for the optional ones.
    max_assistant_turns: Optional[int] = None
    tool_config_path: Optional[str] = None
    max_user_turns: Optional[int] = None
    max_parallel_calls: int = 1

    # Tool-response handling.
    # NOTE(review): the unit of the length limit (tokens vs. characters) is
    # not shown here — confirm.
    max_tool_response_length: int = 256
    tool_response_truncate_side: str = "middle"

    interaction_config_path: Optional[str] = None

    # Chat-template / tokenization behaviour.
    use_inference_chat_template: bool = False
    tokenization_sanity_check_mode: str = "strict"
    format: str = "hermes"
| 55 | + |
| 56 | + |
@dataclass
class CustomAsyncServerConfig(BaseConfig):
    """Optional ``path`` / ``name`` pair describing a custom async server.

    Both fields default to ``None``.

    NOTE(review): presumably ``path`` locates the implementation and ``name``
    selects a class within it — confirm with the loader that reads this config.
    """

    path: Optional[str] = None
    name: Optional[str] = None
| 61 | + |
| 62 | + |
@dataclass
class AgentLoopConfig(BaseConfig):
    """Agent-loop settings: worker count, optional config path, custom server.

    ``custom_async_server`` is created through a ``default_factory``, so each
    instance receives its own ``CustomAsyncServerConfig`` object.
    """

    num_workers: int = 8
    agent_loop_config_path: Optional[str] = None

    # Nested config; a fresh default instance is created per AgentLoopConfig.
    custom_async_server: CustomAsyncServerConfig = field(
        default_factory=CustomAsyncServerConfig
    )
| 68 | + |
| 69 | + |
@dataclass
class TraceConfig(BaseConfig):
    """Tracing options: an optional backend name and a ``token2text`` switch.

    NOTE(review): the set of supported backends is not visible in this file;
    ``None`` presumably leaves tracing disabled — confirm.
    """

    backend: Optional[str] = None

    # NOTE(review): presumably controls decoding token ids to text in traces — confirm.
    token2text: bool = False
| 74 | + |
| 75 | + |
@dataclass
class RolloutConfig(BaseConfig):
    """Top-level rollout configuration.

    Combines flat scalar options with nested config objects (``val_kwargs``,
    ``agent``, ``trace``, ``multi_turn``, ``profiler``). Nested configs and
    dict-valued fields are built with ``default_factory`` so instances do not
    share mutable defaults. Every field declared here has a default.

    NOTE(review): this class performs no validation; the accepted values and
    units of individual options are defined by the code that consumes the
    config, which is not visible in this file.
    """

    # Class-level attribute (no annotation, so it is not a dataclass field).
    # NOTE(review): presumably read by BaseConfig to decide which fields may
    # be reassigned after construction — confirm in verl.base_config.
    _mutable_fields = {"max_model_len"}

    # --- Identity -----------------------------------------------------------
    name: Optional[str] = None
    mode: str = "sync"

    # --- Sampling (same names and defaults as SamplingConfig) ----------------
    temperature: float = 1.0
    top_k: int = -1
    top_p: float = 1.0
    do_sample: bool = True
    n: int = 1

    # --- Prompt / response lengths -------------------------------------------
    # NOTE(review): presumably measured in tokens — confirm.
    prompt_length: int = 512
    response_length: int = 512

    # --- Engine options ------------------------------------------------------
    dtype: str = "bfloat16"
    gpu_memory_utilization: float = 0.5
    ignore_eos: bool = False
    enforce_eager: bool = True
    cudagraph_capture_sizes: Optional[list] = None
    free_cache_engine: bool = True
    tensor_model_parallel_size: int = 2
    max_num_batched_tokens: int = 8192

    # TODO: enable train_kwargs
    # train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)

    # Sampling options kept separately under ``val_kwargs``.
    val_kwargs: SamplingConfig = field(default_factory=SamplingConfig)

    max_model_len: Optional[int] = None
    max_num_seqs: int = 1024

    # --- Log-prob computation ------------------------------------------------
    # Note: the log-prob computation should conceptually belong to the actor.
    log_prob_micro_batch_size: Optional[int] = None
    log_prob_micro_batch_size_per_gpu: Optional[int] = None
    log_prob_use_dynamic_bsz: bool = False
    log_prob_max_token_len_per_gpu: int = 16384

    # --- Miscellaneous engine switches ---------------------------------------
    disable_log_stats: bool = True
    multi_stage_wake_up: bool = False
    # Free-form dict; a new empty dict is created per instance.
    engine_kwargs: dict = field(default_factory=dict)
    calculate_log_probs: bool = False

    # --- Nested configs (fresh default instance per RolloutConfig) -----------
    agent: AgentLoopConfig = field(default_factory=AgentLoopConfig)
    trace: TraceConfig = field(default_factory=TraceConfig)
    multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)

    # --- Weight update / skip-rollout options --------------------------------
    update_weights_bucket_megabytes: int = 512
    skip_rollout: bool = False
    skip_dump_dir: str = "/tmp/rollout_dump"

    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)

    # --- Prefill / weight loading --------------------------------------------
    enable_chunked_prefill: bool = True
    load_format: str = "dummy_dtensor"
    layered_summon: bool = False
    # Free-form dict; a new empty dict is created per instance.
    layer_name_map: dict = field(default_factory=dict)
0 commit comments