
Commit bcc88df

yyDing1 authored and masoudhashemi committed
[worker] refactor: move the implementation of rm to workers.roles and polish (volcengine#3423)
1 parent d92b1df commit bcc88df

File tree

7 files changed: +174 −55 lines


verl/workers/config/reward_model.py

Lines changed: 9 additions & 46 deletions
@@ -19,21 +19,9 @@
 from verl.utils.profiler import ProfilerConfig
 
 from .model import HFModelConfig
+from .rollout import SamplingConfig, ServerConfig
 
-__all__ = ["ServerConfig", "SandboxFusionConfig", "RewardModelConfig"]
-
-
-@dataclass
-class ServerConfig(BaseConfig):
-    """
-    Configuration for SGLang server when running in server mode
-    """
-
-    timeout: float = 60.0
-    max_attempts: int = 3
-    retry_delay: float = 2.0
-    max_connections: int = 1000
-    max_start_wait_time: float = 300.0
+__all__ = ["SandboxFusionConfig", "RewardModelConfig"]
 
 
 @dataclass
@@ -53,50 +41,25 @@ class SandboxFusionConfig(BaseConfig):
 
 @dataclass
 class RewardModelConfig(BaseConfig):
-    """Configuration for reward model scoring.
-
-    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
-
-    Args:
-        enable (bool): Whether to enable reward model.
-        enable_resource_pool (bool): Whether to deploy the model to a separate resource pool.
-        n_gpus_per_node (int): Number of GPUs per node when using resource pool.
-        nnodes (int): Number of nodes when using resource pool.
-        strategy (str): FSDP strategy: "fsdp" or "fsdp2".
-        model (Dict[str, Any]): Model configuration for reward scoring.
-        micro_batch_size (Optional[int]): Global micro batch size (deprecated).
-        micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size.
-        max_length (Optional[int]): Maximum sequence length to process for scoring.
-        use_dynamic_bsz (bool): Whether to dynamically adjust batch size at runtime.
-        forward_max_token_len_per_gpu (int): Maximum number of tokens per GPU in one forward pass.
-        reward_manager (str): Reward manager type (naive or prime).
-        launch_reward_fn_async (bool): Whether to launch custom reward function asynchronously during log_prob.
-        sandbox_fusion (Dict[str, Any]): Cloud/local sandbox fusion configuration for custom reward logic.
-        profiler (Dict[str, Any]): Profiler configuration for reward model.
-    """
-
     _mutable_fields = BaseConfig._mutable_fields
 
     enable: bool = False
+    model_type: str = "discriminative"
+    name: str = "sglang"
     enable_resource_pool: bool = False
     n_gpus_per_node: int = 0
     nnodes: int = 0
-    # strategy: str = MISSING
-    # model: BaseModelConfig = field(default_factory=BaseModelConfig)
-    # micro_batch_size: Optional[int] = None
-    # micro_batch_size_per_gpu: Optional[int] = None
-    # max_length: Optional[int] = None
-    # use_dynamic_bsz: bool = False
-    # forward_max_token_len_per_gpu: int = 32768
     reward_manager: str = "naive"
     launch_reward_fn_async: bool = False
 
-    tensor_model_parallel_size: int = 2
-    engine_kwargs: dict = field(default_factory=dict)
-    max_num_seqs: int = 1024
     dtype: str = "bfloat16"
     gpu_memory_utilization: float = 0.5
     free_cache_engine: bool = True
+    tensor_model_parallel_size: int = 2
+    sampling_config: SamplingConfig = field(default_factory=SamplingConfig)
+
+    engine_kwargs: dict = field(default_factory=dict)
+    max_num_seqs: int = 1024
 
     sandbox_fusion: SandboxFusionConfig = field(default_factory=SandboxFusionConfig)
     profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
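The reworked RewardModelConfig now reuses ServerConfig and SamplingConfig from the rollout config module and selects the serving engine through the new name field. Below is a minimal usage sketch, assuming RewardModelConfig behaves like an ordinary dataclass with the defaults shown in the hunk above; the surrounding trainer wiring is not part of this commit and the argument values are illustrative only.

from verl.workers.config import RewardModelConfig

# Field names come from the diff above; values are illustrative.
rm_config = RewardModelConfig(
    enable=True,
    name="sglang",                  # key into the reward-model engine registry
    model_type="discriminative",
    tensor_model_parallel_size=2,
    gpu_memory_utilization=0.5,
)
print(rm_config.sampling_config)    # defaults to SamplingConfig()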

verl/workers/config/rollout.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
     "CustomAsyncServerConfig",
     "AgentLoopConfig",
     "TraceConfig",
+    "ServerConfig",
     "RolloutConfig",
 ]
 

verl/workers/roles/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 from .critic import CriticWorker
 
 try:
-    from .reward import RewardModelWorker
+    from .reward_model import RewardModelWorker
 except ImportError:
     RewardModelWorker = None
 

verl/workers/roles/reward.py renamed to verl/workers/roles/reward_model.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
 from verl.workers.config import HFModelConfig, RewardModelConfig
-from verl.workers.reward_model.sglang_reward_model import SGLangRewardModel
+from verl.workers.roles.reward_model_engine import get_reward_model_class
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -92,7 +92,7 @@ def _build_reward_model(self):
 
         # 4. build reward model
         log_gpu_memory_usage("Before building sglang reward model", logger=logger)
-        self.reward_model = SGLangRewardModel(
+        self.reward_model = get_reward_model_class(reward_model_config.name)(
             config=reward_model_config, model_config=model_config, device_mesh=reward_model_device_mesh
         )
         log_gpu_memory_usage("After building sglang reward model", logger=logger)
verl/workers/roles/reward_model_engine/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import get_reward_model_class
+
+__all__ = ["get_reward_model_class"]
verl/workers/roles/reward_model_engine/base.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The base class for reward model
+"""
+
+import importlib
+from abc import ABC, abstractmethod
+
+from torch.distributed.device_mesh import DeviceMesh
+
+from verl import DataProto
+from verl.workers.config import HFModelConfig, RewardModelConfig
+
+__all__ = ["BaseRewardModel"]
+
+
+class BaseRewardModel(ABC):
+    """base class for reward model"""
+
+    def __init__(
+        self,
+        config: RewardModelConfig,
+        model_config: HFModelConfig,
+        device_mesh: DeviceMesh,
+    ):
+        self.config = config
+        self.model_config = model_config
+        self.device_mesh = device_mesh
+
+    @abstractmethod
+    async def resume(self, tags: list[str]):
+        """Resume reward model weights or kv cache in GPU memory.
+
+        Args:
+            tags: weights or kv_cache.
+        """
+        pass
+
+    @abstractmethod
+    async def release(self):
+        """Release weights and kv cache in GPU memory."""
+        pass
+
+    @abstractmethod
+    def compute_reward(self, data: DataProto) -> DataProto:
+        """Computing reward given input_ids. The transformers should output a tensor with shape
+        [batch_size, sequence_length], and the value at [EOS] mask should be gathered.
+
+        Args:
+            data: must contain keys "input_ids", "attention_mask" and "position_ids".
+                - input_ids: [batch_size, sequence_length]
+                - attention_mask: [batch_size, sequence_length]
+                - position_ids: [batch_size, sequence_length]
+
+        Returns: a data pass protocol containing "reward". Only the [EOS] position contains the reward.
+            Other position should have zero reward. Note that this may change in the future if we use
+            dense reward. So, we leave the interface for general case.
+            - reward: [batch_size, sequence_length].
+
+        """
+        pass
+
+
+_REWARD_MODEL_REGISTRY = {
+    "sglang": "verl.workers.roles.reward_model_engine.sglang_reward_model.SGLangRewardModel",
+}
+
+
+def get_reward_model_class(reward_model_name: str) -> type[BaseRewardModel]:
+    """Get the reward model class by name.
+
+    Args:
+        reward_model_name: The name of the reward model.
+
+    Returns:
+        The reward model class.
+    """
+    assert reward_model_name in _REWARD_MODEL_REGISTRY, f"Reward Model {reward_model_name} with mode not found"
+    fqdn = _REWARD_MODEL_REGISTRY[reward_model_name]
+    module_name, class_name = fqdn.rsplit(".", 1)
+    reward_model_module = importlib.import_module(module_name)
+    return getattr(reward_model_module, class_name)
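The registry maps a short engine name to a fully qualified class path and imports the class lazily, so additional backends can be listed without importing every engine up front. Below is a hedged sketch of how a custom engine could plug into this pattern; MyHttpRewardModel and the my_pkg module path are hypothetical, and the commit itself does not expose a public registration helper, so the sketch mutates the module-level registry directly.

from verl.workers.roles.reward_model_engine.base import (
    BaseRewardModel,
    _REWARD_MODEL_REGISTRY,
    get_reward_model_class,
)


class MyHttpRewardModel(BaseRewardModel):
    async def resume(self, tags: list[str]):
        ...  # reload weights / kv cache onto the GPU

    async def release(self):
        ...  # free GPU memory between training steps

    def compute_reward(self, data):
        ...  # return a DataProto with a [batch, seq_len] "reward" tensor


# Register by fully qualified path, mirroring the built-in "sglang" entry
# (hypothetical module path; shown only to illustrate the pattern).
_REWARD_MODEL_REGISTRY["my_http"] = "my_pkg.rm.MyHttpRewardModel"

engine_cls = get_reward_model_class("sglang")  # resolves to SGLangRewardModel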

verl/workers/reward_model/sglang_reward_model.py renamed to verl/workers/roles/reward_model_engine/sglang_reward_model.py

Lines changed: 50 additions & 6 deletions
@@ -15,27 +15,72 @@
 
 import asyncio
 import logging
+import multiprocessing as mp
 import os
 
+import sglang.srt.entrypoints.engine
 import torch
 import torch.distributed as dist
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    assert_pkg_version,
     get_ip,
     get_open_port,
+    is_cuda,
+    set_prometheus_multiproc_dir,
+    set_ulimit,
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
 from verl import DataProto
 from verl.utils.net_utils import is_ipv6
 from verl.workers.config import HFModelConfig, RewardModelConfig
-from verl.workers.reward_model import BasePPORewardModel
+from verl.workers.roles.reward_model_engine.base import BaseRewardModel
 from verl.workers.rollout.sglang_rollout.http_server_engine import AsyncHttpServerAdapter
 from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
+# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723
+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+    os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+
+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
+    # Set ulimit
+    set_ulimit()
+
+    # Check flashinfer version
+    if server_args.attention_backend == "flashinfer":
+        assert_pkg_version(
+            "flashinfer_python",
+            "0.2.5",
+            "Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.",
+        )
+    if is_cuda():
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.1",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )
+
+    # Set mp start method
+    mp.set_start_method("spawn", force=True)
+
+
+sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config
+
+
 def _pre_process_inputs(
     attention_mask: torch.Tensor,
     prompt_token_ids: torch.Tensor,
@@ -54,7 +99,7 @@ def _map_each_output(output):
     return scores
 
 
-class SGLangRewardModel(BasePPORewardModel):
+class SGLangRewardModel(BaseRewardModel):
     def __init__(
         self,
         config: RewardModelConfig,
@@ -66,14 +111,13 @@ def __init__(
         actor_module = model_config.local_path
         trust_remote_code = model_config.trust_remote_code
         port = None
-        kwargs = {}
 
         os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
 
-        self._init_distributed_env(device_mesh_cpu=None, **kwargs)
+        self._init_distributed_env(device_mesh_cpu=None)
         self._init_inference_engine(trust_remote_code, actor_module, port)
 
-    def _init_distributed_env(self, device_mesh_cpu, **kwargs):
+    def _init_distributed_env(self, device_mesh_cpu):
         self._device_mesh_cpu = device_mesh_cpu
         os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
         self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
@@ -211,7 +255,7 @@ def compute_reward(self, data: DataProto):
         return reward_score
 
     async def resume(self, tags: list[str]):
-        """Resume rollout weights or kv cache in GPU memory.
+        """Resume reward model weights or kv cache in GPU memory.
 
         Args:
            tag: weights or kv_cache.
