From 71765ee4179659f3be02cbe61b22753dab7fa023 Mon Sep 17 00:00:00 2001
From: yuyangding
Date: Sun, 7 Sep 2025 16:21:31 +0800
Subject: [PATCH 01/14] update

---
 verl/workers/config/rollout.py                   | 1 +
 verl/workers/reward_model/sglang_reward_model.py | 6 +++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index 00e2b8928c9..94018d7b642 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -26,6 +26,7 @@
     "CustomAsyncServerConfig",
     "AgentLoopConfig",
     "TraceConfig",
+    "ServerConfig",
     "RolloutConfig",
 ]

diff --git a/verl/workers/reward_model/sglang_reward_model.py b/verl/workers/reward_model/sglang_reward_model.py
index 1c000eadca0..665268935b4 100644
--- a/verl/workers/reward_model/sglang_reward_model.py
+++ b/verl/workers/reward_model/sglang_reward_model.py
@@ -65,15 +65,15 @@ def __init__(
         actor_module = model_config.local_path
         trust_remote_code = model_config.trust_remote_code

+        self.reward_mode = self.config.mode  # discriminative or generative
         port = None
-        kwargs = {}
         os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")

-        self._init_distributed_env(device_mesh_cpu=None, **kwargs)
+        self._init_distributed_env(device_mesh_cpu=None)

         self._init_inference_engine(trust_remote_code, actor_module, port)

-    def _init_distributed_env(self, device_mesh_cpu, **kwargs):
+    def _init_distributed_env(self, device_mesh_cpu):
         self._device_mesh_cpu = device_mesh_cpu
         os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
         self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)

From a40b51b881e67965631ae960118c4dd75124874a Mon Sep 17 00:00:00 2001
From: yuyangding
Date: Sun, 7 Sep 2025 23:56:09 +0800
Subject: [PATCH 02/14] add ci test

---
 .../workers/reward_model/test_reward_model.py | 183 ++++++++----------
 verl/workers/config/reward_model.py           |  54 +-----
 .../reward_model/sglang_reward_model.py       |   1 -
 3 files changed, 89 insertions(+), 149 deletions(-)

diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py
index 36bb96fac53..8866e82e228 100644
--- a/tests/workers/reward_model/test_reward_model.py
+++ b/tests/workers/reward_model/test_reward_model.py
@@ -11,138 +11,77 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os import ray import torch -from hydra import compose, initialize_config_dir -from transformers import AutoTokenizer - -from verl.protocol import DataProto -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - -def test_agent_loop_compute_score_with_model(): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose("ppo_trainer") +from transformers import AutoModelForSequenceClassification - rm_path = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2" +from verl import DataProto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils.model import compute_position_id_with_mask +from verl.workers.config import HFModelConfig, RewardModelConfig +from verl.workers.roles import RewardModelWorker - if os.environ["LEGACY_IMPL_RM"] == "disable": - from verl.workers.config import HFModelConfig, RewardModelConfig - from verl.workers.roles import RewardModelWorker - - model_config = HFModelConfig(path=rm_path) - reward_model_config = RewardModelConfig( - enable=True, - model_config=model_config, - input_model_config=None, - tensor_model_parallel_size=1, - gpu_memory_utilization=0.8, - ) - else: - from verl.workers.fsdp_workers import RewardModelWorker - - config.reward_model.enable = True - config.reward_model.model.path = rm_path - config.reward_model.use_dynamic_bsz = True - config.reward_model.forward_max_token_len_per_gpu = 6000 - config.reward_model.micro_batch_size_per_gpu = 40 - config.reward_model.model.trust_remote_code = True - config.reward_model.model.input_tokenizer = None - reward_model_config = config.reward_model - - config.trainer.n_gpus_per_node = 2 - config.trainer.nnodes = 1 - - role_worker_mapping = {} - if reward_model_config.enable: - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = {} - mapping[Role.RewardModel] = "global_pool" - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - resource_pool_manager.create_resource_pool() - resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} - - if reward_model_config.enable: - # we create a RM here - resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=reward_model_config) - resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - all_wg = {} - for resource_pool, class_dict in resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - rm_wg = all_wg["rm"] - rm_wg.init_model() +def _create_data_samples(tokenizer) -> DataProto: convs = [ [ { "role": "user", "content": "What is the range of the numeric output of a sigmoid node in a neural network?", }, - {"role": "assistant", "content": "The output is bounded between -1 and 1."}, + {"role": 
"assistant", "content": "Between -1 and 1."}, ], [ { "role": "user", "content": "What is the range of the numeric output of a sigmoid node in a neural network?", }, - {"role": "assistant", "content": "The output is bounded between 0 and 1."}, + {"role": "assistant", "content": "Between 0 and 1."}, + ], + [ + {"role": "user", "content": "What is the capital of Australia?"}, + { + "role": "assistant", + "content": "Canberra is the capital city of Australia.", + }, + ], + [ + {"role": "user", "content": "What is the capital of Australia?"}, + { + "role": "assistant", + "content": "Sydney is the capital of Australia.", + }, ], ] - tokenizer = AutoTokenizer.from_pretrained(rm_path) prompt_length, response_length = 1024, 4096 pad_token_id = tokenizer.pad_token_id - prompts, responses, input_ids, attention_masks, position_ids = [], [], [], [], [] + prompts, responses, input_ids, attention_masks = [], [], [], [] for conv in convs: - prompt = tokenizer.apply_chat_template(conv[:1], tokenize=True) - response = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt) :] + prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True) + response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :] + + padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens + padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens)) attention_mask = ( - [0] * (prompt_length - len(prompt)) - + [1] * len(prompt) - + [1] * len(response) - + [0] * (response_length - len(response)) + [0] * (prompt_length - len(prompt_tokens)) + + [1] * len(prompt_tokens) + + [1] * len(response_tokens) + + [0] * (response_length - len(response_tokens)) ) - prompt = [pad_token_id] * (prompt_length - len(prompt)) + prompt - response = response + [pad_token_id] * (response_length - len(response)) - prompts.append(torch.tensor(prompt)) - responses.append(torch.tensor(response)) - input_ids.append(torch.tensor(prompt + response)) + prompts.append(torch.tensor(padded_prompt)) + responses.append(torch.tensor(padded_response)) + input_ids.append(torch.tensor(padded_prompt + padded_response)) attention_masks.append(torch.tensor(attention_mask)) - from verl.utils.model import compute_position_id_with_mask - prompts = torch.stack(prompts) responses = torch.stack(responses) input_ids = torch.stack(input_ids) attention_masks = torch.stack(attention_masks) position_ids = compute_position_id_with_mask(attention_masks) - data = DataProto.from_dict( + + return DataProto.from_dict( tensors={ "prompts": prompts, "responses": responses, @@ -151,8 +90,48 @@ def test_agent_loop_compute_score_with_model(): "position_ids": position_ids, }, ) + + +def test_reward_model(): + ray.init() + + rm_path = "Skywork/Skywork-Reward-V2-Llama-3.2-1B" + model_config = HFModelConfig(path=rm_path) + config = RewardModelConfig( + enable=True, + model_type="discriminative", + dtype="bfloat16", + model_config=model_config, + input_model_config=None, + tensor_model_parallel_size=2, + ) + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(RewardModelWorker), config=config) + resource_pool = RayResourcePool(process_on_nodes=[8]) + rm_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + # init model + rm_wg.init_model() + + # create data samples + tokenizer = model_config.get_processor() + data = _create_data_samples(tokenizer) + gen_batch = rm_wg.compute_rm_score(data) - rm_scores = gen_batch.batch["rm_scores"] - sample_scores = 
rm_scores.sum(dim=1) - print(sample_scores) + server_rm_scores = gen_batch.batch["rm_scores"].sum(dim=-1) + print(f"{server_rm_scores=}") + server_rm_scores_mean = torch.mean(server_rm_scores) + + hf_model = AutoModelForSequenceClassification.from_pretrained(rm_path, torch_dtype=torch.bfloat16) + hf_model.pad_token_id = tokenizer.pad_token_id + hf_output = hf_model( + input_ids=data.batch["input_ids"], + attention_mask=data.batch["attention_mask"], + ) + hf_rm_scores = hf_output.logits.squeeze().detach().to("cpu") + print(f"{hf_rm_scores=}") + + hf_rm_scores_mean = torch.mean(hf_rm_scores).to(server_rm_scores_mean) + print(hf_rm_scores_mean, server_rm_scores_mean) + + torch.testing.assert_close(hf_rm_scores_mean, server_rm_scores_mean, atol=2e-2, rtol=1e-2) + ray.shutdown() diff --git a/verl/workers/config/reward_model.py b/verl/workers/config/reward_model.py index fea41bfc5ec..ed677f87b36 100644 --- a/verl/workers/config/reward_model.py +++ b/verl/workers/config/reward_model.py @@ -19,21 +19,9 @@ from verl.utils.profiler import ProfilerConfig from .model import HFModelConfig +from .rollout import SamplingConfig, ServerConfig -__all__ = ["ServerConfig", "SandboxFusionConfig", "RewardModelConfig"] - - -@dataclass -class ServerConfig(BaseConfig): - """ - Configuration for SGLang server when running in server mode - """ - - timeout: float = 60.0 - max_attempts: int = 3 - retry_delay: float = 2.0 - max_connections: int = 1000 - max_start_wait_time: float = 300.0 +__all__ = ["SandboxFusionConfig", "RewardModelConfig"] @dataclass @@ -53,50 +41,24 @@ class SandboxFusionConfig(BaseConfig): @dataclass class RewardModelConfig(BaseConfig): - """Configuration for reward model scoring. - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - enable (bool): Whether to enable reward model. - enable_resource_pool (bool): Whether to deploy the model to a separate resource pool. - n_gpus_per_node (int): Number of GPUs per node when using resource pool. - nnodes (int): Number of nodes when using resource pool. - strategy (str): FSDP strategy: "fsdp" or "fsdp2". - model (Dict[str, Any]): Model configuration for reward scoring. - micro_batch_size (Optional[int]): Global micro batch size (deprecated). - micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size. - max_length (Optional[int]): Maximum sequence length to process for scoring. - use_dynamic_bsz (bool): Whether to dynamically adjust batch size at runtime. - forward_max_token_len_per_gpu (int): Maximum number of tokens per GPU in one forward pass. - reward_manager (str): Reward manager type (naive or prime). - launch_reward_fn_async (bool): Whether to launch custom reward function asynchronously during log_prob. - sandbox_fusion (Dict[str, Any]): Cloud/local sandbox fusion configuration for custom reward logic. - profiler (Dict[str, Any]): Profiler configuration for reward model. 
- """ - _mutable_fields = BaseConfig._mutable_fields enable: bool = False + model_type: str = "discriminative" enable_resource_pool: bool = False n_gpus_per_node: int = 0 nnodes: int = 0 - # strategy: str = MISSING - # model: BaseModelConfig = field(default_factory=BaseModelConfig) - # micro_batch_size: Optional[int] = None - # micro_batch_size_per_gpu: Optional[int] = None - # max_length: Optional[int] = None - # use_dynamic_bsz: bool = False - # forward_max_token_len_per_gpu: int = 32768 reward_manager: str = "naive" launch_reward_fn_async: bool = False - tensor_model_parallel_size: int = 2 - engine_kwargs: dict = field(default_factory=dict) - max_num_seqs: int = 1024 dtype: str = "bfloat16" gpu_memory_utilization: float = 0.5 free_cache_engine: bool = True + tensor_model_parallel_size: int = 2 + sampling_config: SamplingConfig = field(default_factory=SamplingConfig) + + engine_kwargs: dict = field(default_factory=dict) + max_num_seqs: int = 1024 sandbox_fusion: SandboxFusionConfig = field(default_factory=SandboxFusionConfig) profiler: ProfilerConfig = field(default_factory=ProfilerConfig) diff --git a/verl/workers/reward_model/sglang_reward_model.py b/verl/workers/reward_model/sglang_reward_model.py index 665268935b4..1475d46dbb2 100644 --- a/verl/workers/reward_model/sglang_reward_model.py +++ b/verl/workers/reward_model/sglang_reward_model.py @@ -65,7 +65,6 @@ def __init__( actor_module = model_config.local_path trust_remote_code = model_config.trust_remote_code - self.reward_mode = self.config.mode # discriminative or generative port = None os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") From 52b88501fd53e46fb896de9e8c41c292df673c7c Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 00:02:09 +0800 Subject: [PATCH 03/14] update --- .github/workflows/reward_model.yml | 88 ++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 .github/workflows/reward_model.yml diff --git a/.github/workflows/reward_model.yml b/.github/workflows/reward_model.yml new file mode 100644 index 00000000000..6e8977d03e1 --- /dev/null +++ b/.github/workflows/reward_model.yml @@ -0,0 +1,88 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. 
Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. +# name: Check PR Title + +name: reward_model + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + paths: + - "verl/**/*.py" + # Entrypoints + - ".github/workflows/reward_model.yml" + - "tests/workers/reward_model/**" + +# Declare permissions just read content. +permissions: + contents: read + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + discriminative_reward_model: + runs-on: [L20x8] + timeout-minutes: 20 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + container: + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + pip install --upgrade "huggingface_hub[cli]" + - name: Download model config files + run: | + hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-0.5B-Instruct + + - name: Running discriminative reward model tests on 8 L20 GPUs + run: | + pytest -s -x tests/workers/reward_model/test_reward_model.py From 08e7db780f1c06dda0776b4511befc945874182e Mon Sep 17 00:00:00 2001 From: Yuyang Ding <61647442+yyDing1@users.noreply.github.com> Date: Mon, 8 Sep 2025 00:10:28 +0800 Subject: [PATCH 04/14] Update tests/workers/reward_model/test_reward_model.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/workers/reward_model/test_reward_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py index 8866e82e228..36d82546288 100644 --- a/tests/workers/reward_model/test_reward_model.py +++ b/tests/workers/reward_model/test_reward_model.py @@ -129,9 +129,6 @@ def test_reward_model(): hf_rm_scores = hf_output.logits.squeeze().detach().to("cpu") print(f"{hf_rm_scores=}") - hf_rm_scores_mean = torch.mean(hf_rm_scores).to(server_rm_scores_mean) - print(hf_rm_scores_mean, server_rm_scores_mean) - - torch.testing.assert_close(hf_rm_scores_mean, server_rm_scores_mean, atol=2e-2, rtol=1e-2) + torch.testing.assert_close(server_rm_scores, hf_rm_scores.to(server_rm_scores.dtype), atol=2e-2, rtol=1e-2) ray.shutdown() From 2fed7e46e3991b042390cc4e129d6b7b1172aafa Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 00:17:07 +0800 Subject: [PATCH 05/14] fix --- .github/workflows/reward_model.yml | 2 +- tests/workers/reward_model/test_reward_model.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff 
--git a/.github/workflows/reward_model.yml b/.github/workflows/reward_model.yml index 6e8977d03e1..0c863b904ff 100644 --- a/.github/workflows/reward_model.yml +++ b/.github/workflows/reward_model.yml @@ -81,7 +81,7 @@ jobs: pip install --upgrade "huggingface_hub[cli]" - name: Download model config files run: | - hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-0.5B-Instruct + hf download Skywork/Skywork-Reward-V2-Llama-3.2-1B --local-dir $HOME/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B - name: Running discriminative reward model tests on 8 L20 GPUs run: | diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py index 36d82546288..95e50c31160 100644 --- a/tests/workers/reward_model/test_reward_model.py +++ b/tests/workers/reward_model/test_reward_model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import ray import torch from transformers import AutoModelForSequenceClassification @@ -95,7 +97,7 @@ def _create_data_samples(tokenizer) -> DataProto: def test_reward_model(): ray.init() - rm_path = "Skywork/Skywork-Reward-V2-Llama-3.2-1B" + rm_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") model_config = HFModelConfig(path=rm_path) config = RewardModelConfig( enable=True, From 2238f2a7bafff1ddebcf53261a2292c07224a657 Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 00:20:16 +0800 Subject: [PATCH 06/14] fix --- tests/workers/reward_model/test_reward_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py index 95e50c31160..2e48a6946df 100644 --- a/tests/workers/reward_model/test_reward_model.py +++ b/tests/workers/reward_model/test_reward_model.py @@ -97,7 +97,7 @@ def _create_data_samples(tokenizer) -> DataProto: def test_reward_model(): ray.init() - rm_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") + rm_path = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B") model_config = HFModelConfig(path=rm_path) config = RewardModelConfig( enable=True, From 5301eebfd5f8b8e8dc04778868d3fd551129278e Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 13:47:15 +0800 Subject: [PATCH 07/14] fix --- tests/workers/reward_model/test_reward_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py index 2e48a6946df..001c5f2903b 100644 --- a/tests/workers/reward_model/test_reward_model.py +++ b/tests/workers/reward_model/test_reward_model.py @@ -130,7 +130,8 @@ def test_reward_model(): ) hf_rm_scores = hf_output.logits.squeeze().detach().to("cpu") print(f"{hf_rm_scores=}") + hf_rm_scores_mean = torch.mean(hf_rm_scores).to(server_rm_scores.dtype) - torch.testing.assert_close(server_rm_scores, hf_rm_scores.to(server_rm_scores.dtype), atol=2e-2, rtol=1e-2) + torch.testing.assert_close(server_rm_scores_mean, hf_rm_scores_mean, atol=2e-2, rtol=1e-2) ray.shutdown() From a67cea5f4992dc02d15603baf58434de5087dc8b Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 14:12:20 +0800 Subject: [PATCH 08/14] fix --- .github/workflows/reward_model.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/reward_model.yml b/.github/workflows/reward_model.yml index 0c863b904ff..60808821b5c 100644 --- 
a/.github/workflows/reward_model.yml +++ b/.github/workflows/reward_model.yml @@ -68,6 +68,9 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True" + NCCL_SHM_DISABLE: "1" + NCCL_P2P_DISABLE: "1" container: image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2 options: --gpus all --shm-size=10g From 5165ddfd945bd3a04b468a88de5e71e1a2e3c00e Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 23:04:54 +0800 Subject: [PATCH 09/14] update --- .github/workflows/reward_model.yml | 6 ++---- tests/workers/reward_model/test_reward_model.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/reward_model.yml b/.github/workflows/reward_model.yml index 60808821b5c..8fa8ed07019 100644 --- a/.github/workflows/reward_model.yml +++ b/.github/workflows/reward_model.yml @@ -72,7 +72,7 @@ jobs: NCCL_SHM_DISABLE: "1" NCCL_P2P_DISABLE: "1" container: - image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2 + image: verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -80,12 +80,10 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install --no-deps -e .[test] - pip install --upgrade "huggingface_hub[cli]" + pip3 install -e .[test] - name: Download model config files run: | hf download Skywork/Skywork-Reward-V2-Llama-3.2-1B --local-dir $HOME/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B - - name: Running discriminative reward model tests on 8 L20 GPUs run: | pytest -s -x tests/workers/reward_model/test_reward_model.py diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py index 001c5f2903b..092ce527bfd 100644 --- a/tests/workers/reward_model/test_reward_model.py +++ b/tests/workers/reward_model/test_reward_model.py @@ -25,7 +25,7 @@ from verl.workers.roles import RewardModelWorker -def _create_data_samples(tokenizer) -> DataProto: +def create_data_samples(tokenizer) -> DataProto: convs = [ [ { @@ -115,7 +115,7 @@ def test_reward_model(): # create data samples tokenizer = model_config.get_processor() - data = _create_data_samples(tokenizer) + data = create_data_samples(tokenizer) gen_batch = rm_wg.compute_rm_score(data) server_rm_scores = gen_batch.batch["rm_scores"].sum(dim=-1) From b4d103eca17db4538097c8e471f0df3208bc4771 Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 23:22:13 +0800 Subject: [PATCH 10/14] restore other files and debug ci only --- verl/workers/config/reward_model.py | 56 ++++++++++++++++--- verl/workers/config/rollout.py | 3 +- .../reward_model/sglang_reward_model.py | 7 ++- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/verl/workers/config/reward_model.py b/verl/workers/config/reward_model.py index ed677f87b36..361278ecb92 100644 --- a/verl/workers/config/reward_model.py +++ b/verl/workers/config/reward_model.py @@ -19,9 +19,21 @@ from verl.utils.profiler import ProfilerConfig from .model import HFModelConfig -from .rollout import SamplingConfig, ServerConfig -__all__ = ["SandboxFusionConfig", "RewardModelConfig"] +__all__ = ["ServerConfig", "SandboxFusionConfig", "RewardModelConfig"] + + +@dataclass +class ServerConfig(BaseConfig): + """ + Configuration for SGLang server when running in server mode + """ + 
+ timeout: float = 60.0 + max_attempts: int = 3 + retry_delay: float = 2.0 + max_connections: int = 1000 + max_start_wait_time: float = 300.0 @dataclass @@ -41,28 +53,54 @@ class SandboxFusionConfig(BaseConfig): @dataclass class RewardModelConfig(BaseConfig): + """Configuration for reward model scoring. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + enable (bool): Whether to enable reward model. + enable_resource_pool (bool): Whether to deploy the model to a separate resource pool. + n_gpus_per_node (int): Number of GPUs per node when using resource pool. + nnodes (int): Number of nodes when using resource pool. + strategy (str): FSDP strategy: "fsdp" or "fsdp2". + model (Dict[str, Any]): Model configuration for reward scoring. + micro_batch_size (Optional[int]): Global micro batch size (deprecated). + micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size. + max_length (Optional[int]): Maximum sequence length to process for scoring. + use_dynamic_bsz (bool): Whether to dynamically adjust batch size at runtime. + forward_max_token_len_per_gpu (int): Maximum number of tokens per GPU in one forward pass. + reward_manager (str): Reward manager type (naive or prime). + launch_reward_fn_async (bool): Whether to launch custom reward function asynchronously during log_prob. + sandbox_fusion (Dict[str, Any]): Cloud/local sandbox fusion configuration for custom reward logic. + profiler (Dict[str, Any]): Profiler configuration for reward model. + """ + _mutable_fields = BaseConfig._mutable_fields enable: bool = False - model_type: str = "discriminative" enable_resource_pool: bool = False n_gpus_per_node: int = 0 nnodes: int = 0 + # strategy: str = MISSING + # model: BaseModelConfig = field(default_factory=BaseModelConfig) + # micro_batch_size: Optional[int] = None + # micro_batch_size_per_gpu: Optional[int] = None + # max_length: Optional[int] = None + # use_dynamic_bsz: bool = False + # forward_max_token_len_per_gpu: int = 32768 reward_manager: str = "naive" launch_reward_fn_async: bool = False - dtype: str = "bfloat16" - gpu_memory_utilization: float = 0.5 - free_cache_engine: bool = True tensor_model_parallel_size: int = 2 - sampling_config: SamplingConfig = field(default_factory=SamplingConfig) - engine_kwargs: dict = field(default_factory=dict) max_num_seqs: int = 1024 + dtype: str = "bfloat16" + gpu_memory_utilization: float = 0.5 + free_cache_engine: bool = True sandbox_fusion: SandboxFusionConfig = field(default_factory=SandboxFusionConfig) profiler: ProfilerConfig = field(default_factory=ProfilerConfig) input_model_config: HFModelConfig = field(default_factory=HFModelConfig) model_config: HFModelConfig = field(default_factory=HFModelConfig) # Server configuration for sglang server mode - server: ServerConfig = field(default_factory=ServerConfig) + server: ServerConfig = field(default_factory=ServerConfig) \ No newline at end of file diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 94018d7b642..388d0434b12 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -26,7 +26,6 @@ "CustomAsyncServerConfig", "AgentLoopConfig", "TraceConfig", - "ServerConfig", "RolloutConfig", ] @@ -165,4 +164,4 @@ class RolloutConfig(BaseConfig): sglang_engine_mode: str = "local" - limit_images: Optional[int] = None + limit_images: Optional[int] = None \ No newline at end of file diff --git a/verl/workers/reward_model/sglang_reward_model.py 
b/verl/workers/reward_model/sglang_reward_model.py index 1475d46dbb2..ecf6debfd19 100644 --- a/verl/workers/reward_model/sglang_reward_model.py +++ b/verl/workers/reward_model/sglang_reward_model.py @@ -66,13 +66,14 @@ def __init__( actor_module = model_config.local_path trust_remote_code = model_config.trust_remote_code port = None + kwargs = {} os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") - self._init_distributed_env(device_mesh_cpu=None) + self._init_distributed_env(device_mesh_cpu=None, **kwargs) self._init_inference_engine(trust_remote_code, actor_module, port) - def _init_distributed_env(self, device_mesh_cpu): + def _init_distributed_env(self, device_mesh_cpu, **kwargs): self._device_mesh_cpu = device_mesh_cpu os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) @@ -221,4 +222,4 @@ async def resume(self, tags: list[str]): async def release(self): """Release weights and kv cache in GPU memory.""" if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine: - await self._engine.release_memory_occupation(tags=["kv_cache", "weights"]) + await self._engine.release_memory_occupation(tags=["kv_cache", "weights"]) \ No newline at end of file From 01ec0188c0c043a46643511b0d8081bd3e3990b9 Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 23:24:07 +0800 Subject: [PATCH 11/14] restore other files and debug ci only --- verl/workers/config/reward_model.py | 2 +- verl/workers/config/rollout.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/workers/config/reward_model.py b/verl/workers/config/reward_model.py index 361278ecb92..fea41bfc5ec 100644 --- a/verl/workers/config/reward_model.py +++ b/verl/workers/config/reward_model.py @@ -103,4 +103,4 @@ class RewardModelConfig(BaseConfig): input_model_config: HFModelConfig = field(default_factory=HFModelConfig) model_config: HFModelConfig = field(default_factory=HFModelConfig) # Server configuration for sglang server mode - server: ServerConfig = field(default_factory=ServerConfig) \ No newline at end of file + server: ServerConfig = field(default_factory=ServerConfig) diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 388d0434b12..00e2b8928c9 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -164,4 +164,4 @@ class RolloutConfig(BaseConfig): sglang_engine_mode: str = "local" - limit_images: Optional[int] = None \ No newline at end of file + limit_images: Optional[int] = None From 9d62ca9e2b307176f7b146d1eb91151495e2dc27 Mon Sep 17 00:00:00 2001 From: yuyangding Date: Mon, 8 Sep 2025 23:24:33 +0800 Subject: [PATCH 12/14] restore other files and debug ci only --- verl/workers/reward_model/sglang_reward_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/reward_model/sglang_reward_model.py b/verl/workers/reward_model/sglang_reward_model.py index ecf6debfd19..1c000eadca0 100644 --- a/verl/workers/reward_model/sglang_reward_model.py +++ b/verl/workers/reward_model/sglang_reward_model.py @@ -222,4 +222,4 @@ async def resume(self, tags: list[str]): async def release(self): """Release weights and kv cache in GPU memory.""" if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine: - await self._engine.release_memory_occupation(tags=["kv_cache", "weights"]) \ No newline at end of file + await 
self._engine.release_memory_occupation(tags=["kv_cache", "weights"])

From 93a3b7d915b505d9666809ee01394f769ceb2a25 Mon Sep 17 00:00:00 2001
From: yuyangding
Date: Mon, 8 Sep 2025 23:26:41 +0800
Subject: [PATCH 13/14] fix

---
 tests/workers/reward_model/test_reward_model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/workers/reward_model/test_reward_model.py b/tests/workers/reward_model/test_reward_model.py
index 092ce527bfd..a708230795d 100644
--- a/tests/workers/reward_model/test_reward_model.py
+++ b/tests/workers/reward_model/test_reward_model.py
@@ -101,7 +101,6 @@ def test_reward_model():
     model_config = HFModelConfig(path=rm_path)
     config = RewardModelConfig(
         enable=True,
-        model_type="discriminative",
         dtype="bfloat16",
         model_config=model_config,
         input_model_config=None,

From 6a0d181422566ccffdc6ee5c3fc303eba7f44c24 Mon Sep 17 00:00:00 2001
From: Yuyang Ding <61647442+yyDing1@users.noreply.github.com>
Date: Tue, 9 Sep 2025 02:49:28 +0800
Subject: [PATCH 14/14] Update reward_model.yml

---
 .github/workflows/reward_model.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/reward_model.yml b/.github/workflows/reward_model.yml
index 8fa8ed07019..cb9d0fa48eb 100644
--- a/.github/workflows/reward_model.yml
+++ b/.github/workflows/reward_model.yml
@@ -86,4 +86,5 @@ jobs:
           hf download Skywork/Skywork-Reward-V2-Llama-3.2-1B --local-dir $HOME/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B
       - name: Running discriminative reward model tests on 8 L20 GPUs
         run: |
+          unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
           pytest -s -x tests/workers/reward_model/test_reward_model.py
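
Note on the reference path exercised by the test in this series: the discriminative check boils down to scoring a chat with a plain transformers sequence-classification head. Below is a minimal, self-contained sketch of that reference scoring; the checkpoint and conversation come from the test above, while the padding, batching, and the verl worker itself are deliberately left out, so this is illustrative rather than part of the patches.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Reference-only sketch: score a single conversation with the same discriminative
# reward model the test uses. No padding or batching is needed for one sequence.
rm_path = "Skywork/Skywork-Reward-V2-Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(rm_path)
model = AutoModelForSequenceClassification.from_pretrained(rm_path, torch_dtype=torch.bfloat16)
model.eval()

conv = [
    {"role": "user", "content": "What is the capital of Australia?"},
    {"role": "assistant", "content": "Canberra is the capital city of Australia."},
]

# apply_chat_template returns the token ids of the full conversation as a tensor.
input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_tensors="pt")
with torch.no_grad():
    # A single-label classification head emits one scalar logit per sequence,
    # which is taken directly as the reward score.
    score = model(input_ids).logits[0][0].item()
print(score)

In the batched test itself, prompts are left-padded and responses right-padded to fixed lengths, the worker's token-level rm_scores are summed over the sequence dimension, and the mean over the samples is compared against the mean of the corresponding HF logits with atol=2e-2 / rtol=1e-2.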