90 changes: 90 additions & 0 deletions .github/workflows/reward_model.yml
@@ -0,0 +1,90 @@
# # Tests layout

# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:
# - `tests/trainer` for testing functionality related to `verl/trainer`
# - `tests/models` for testing functionality related to `verl/models`
# - ...

# There are a few folders with the `special_` prefix, created for special purposes:
# - `special_distributed`: unit tests that must run with multiple GPUs
# - `special_e2e`: end-to-end tests with training/generation scripts
# - `special_npu`: tests for NPUs
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of tests designed to run in dedicated environments

# Accelerators for tests
# - By default, tests run with GPUs available, except for those under `special_npu` and any test script whose name ends with `on_cpu.py`.
# - Test scripts with the `on_cpu.py` suffix are run on CPU resources in a Linux environment.

# # Workflow layout

# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:
# 1. A list of always-triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `pre-commit.yml`, `doc.yml`
# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`
# 3. End-to-end tests: `e2e_*.yml`
# 4. Unit tests
# - `cpu_unit_tests.yml`, runs pytest on all scripts matching the file name pattern `tests/**/test_*_on_cpu.py`
# - `gpu_unit_tests.yml`, runs pytest on all test scripts whose file names do not end with the `on_cpu.py` suffix.
# - Since the CPU/GPU unit test workflows run all tests under `tests/` by default, please make sure tests are manually excluded from them when
# - a new workflow yaml is added to `.github/workflows`
# - new tests are added to a workflow mentioned in 2.
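# For illustration only (the exact pytest invocation is an assumption, not copied from
# those workflow files): the CPU unit-test suite amounts to roughly
#   pytest -s tests/**/test_*_on_cpu.py
# while the GPU unit-test suite runs the remaining tests/**/test_*.py files, excluding
# paths already covered by the dedicated workflows above.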

name: reward_model

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
- v0.*
pull_request:
branches:
- main
- v0.*
paths:
- "verl/**/*.py"
# Entrypoints
- ".github/workflows/reward_model.yml"
- "tests/workers/reward_model/**"

# Declare read-only permissions for repository contents.
permissions:
contents: read

# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
discriminative_reward_model:
runs-on: [L20x8]
timeout-minutes: 20 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True"
NCCL_SHM_DISABLE: "1"
NCCL_P2P_DISABLE: "1"
container:
image: verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install -e .[test]
- name: Download the reward model
run: |
hf download Skywork/Skywork-Reward-V2-Llama-3.2-1B --local-dir $HOME/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B
- name: Run discriminative reward model tests on 8 L20 GPUs
run: |
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
pytest -s -x tests/workers/reward_model/test_reward_model.py
180 changes: 79 additions & 101 deletions tests/workers/reward_model/test_reward_model.py
@@ -11,138 +11,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import ray
import torch
from hydra import compose, initialize_config_dir
from transformers import AutoTokenizer

from verl.protocol import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role


def test_agent_loop_compute_score_with_model():
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
}
}
)
from transformers import AutoModelForSequenceClassification

with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose("ppo_trainer")
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.model import compute_position_id_with_mask
from verl.workers.config import HFModelConfig, RewardModelConfig
from verl.workers.roles import RewardModelWorker

rm_path = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"

if os.environ["LEGACY_IMPL_RM"] == "disable":
from verl.workers.config import HFModelConfig, RewardModelConfig
from verl.workers.roles import RewardModelWorker

model_config = HFModelConfig(path=rm_path)
reward_model_config = RewardModelConfig(
enable=True,
model_config=model_config,
input_model_config=None,
tensor_model_parallel_size=1,
gpu_memory_utilization=0.8,
)
else:
from verl.workers.fsdp_workers import RewardModelWorker

config.reward_model.enable = True
config.reward_model.model.path = rm_path
config.reward_model.use_dynamic_bsz = True
config.reward_model.forward_max_token_len_per_gpu = 6000
config.reward_model.micro_batch_size_per_gpu = 40
config.reward_model.model.trust_remote_code = True
config.reward_model.model.input_tokenizer = None
reward_model_config = config.reward_model

config.trainer.n_gpus_per_node = 2
config.trainer.nnodes = 1

role_worker_mapping = {}
if reward_model_config.enable:
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)

global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {}
mapping[Role.RewardModel] = "global_pool"
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
resource_pool_manager.create_resource_pool()
resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}

if reward_model_config.enable:
# we create a RM here
resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=reward_model_config)
resource_pool_to_cls[resource_pool]["rm"] = rm_cls

all_wg = {}
for resource_pool, class_dict in resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)

rm_wg = all_wg["rm"]
rm_wg.init_model()

def create_data_samples(tokenizer) -> DataProto:
convs = [
[
{
"role": "user",
"content": "What is the range of the numeric output of a sigmoid node in a neural network?",
},
{"role": "assistant", "content": "The output is bounded between -1 and 1."},
{"role": "assistant", "content": "Between -1 and 1."},
],
[
{
"role": "user",
"content": "What is the range of the numeric output of a sigmoid node in a neural network?",
},
{"role": "assistant", "content": "The output is bounded between 0 and 1."},
{"role": "assistant", "content": "Between 0 and 1."},
],
[
{"role": "user", "content": "What is the capital of Australia?"},
{
"role": "assistant",
"content": "Canberra is the capital city of Australia.",
},
],
[
{"role": "user", "content": "What is the capital of Australia?"},
{
"role": "assistant",
"content": "Sydney is the capital of Australia.",
},
],
]
tokenizer = AutoTokenizer.from_pretrained(rm_path)

prompt_length, response_length = 1024, 4096
pad_token_id = tokenizer.pad_token_id
prompts, responses, input_ids, attention_masks, position_ids = [], [], [], [], []
prompts, responses, input_ids, attention_masks = [], [], [], []
for conv in convs:
prompt = tokenizer.apply_chat_template(conv[:1], tokenize=True)
response = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt) :]
prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]

padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
attention_mask = (
[0] * (prompt_length - len(prompt))
+ [1] * len(prompt)
+ [1] * len(response)
+ [0] * (response_length - len(response))
[0] * (prompt_length - len(prompt_tokens))
+ [1] * len(prompt_tokens)
+ [1] * len(response_tokens)
+ [0] * (response_length - len(response_tokens))
)
prompt = [pad_token_id] * (prompt_length - len(prompt)) + prompt
response = response + [pad_token_id] * (response_length - len(response))
prompts.append(torch.tensor(prompt))
responses.append(torch.tensor(response))
input_ids.append(torch.tensor(prompt + response))
prompts.append(torch.tensor(padded_prompt))
responses.append(torch.tensor(padded_response))
input_ids.append(torch.tensor(padded_prompt + padded_response))
attention_masks.append(torch.tensor(attention_mask))

from verl.utils.model import compute_position_id_with_mask

prompts = torch.stack(prompts)
responses = torch.stack(responses)
input_ids = torch.stack(input_ids)
attention_masks = torch.stack(attention_masks)
position_ids = compute_position_id_with_mask(attention_masks)
data = DataProto.from_dict(

return DataProto.from_dict(
tensors={
"prompts": prompts,
"responses": responses,
@@ -151,8 +92,45 @@ def test_agent_loop_compute_score_with_model():
"position_ids": position_ids,
},
)
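# Editorial sketch (not part of this PR): assumed behavior of verl's
# compute_position_id_with_mask, shown to clarify why left-padded prompts still
# get position ids starting at 0. The cumsum formulation is an assumption inferred
# from how the attention mask is built above, not verl's actual implementation.
_mask_sketch = torch.tensor([[0, 0, 1, 1, 1, 0]])  # 2 leading pads, 3 real tokens, 1 trailing pad
_position_ids_sketch = torch.clip(torch.cumsum(_mask_sketch, dim=-1) - 1, min=0)
# -> tensor([[0, 0, 0, 1, 2, 2]]): positions advance only on unmasked tokens.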


def test_reward_model():
ray.init()

rm_path = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B")
model_config = HFModelConfig(path=rm_path)
config = RewardModelConfig(
enable=True,
dtype="bfloat16",
model_config=model_config,
input_model_config=None,
tensor_model_parallel_size=2,
)
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(RewardModelWorker), config=config)
resource_pool = RayResourcePool(process_on_nodes=[8])
Contributor (review comment, severity: high):

Hardcoding 8 GPUs for a CI test might make it flaky if the CI environment doesn't have that many GPUs available. Since tensor_model_parallel_size is 2, using 2 GPUs should be sufficient and more robust for a CI environment.

Suggested change:
-    resource_pool = RayResourcePool(process_on_nodes=[8])
+    resource_pool = RayResourcePool(process_on_nodes=[2])

Collaborator (author) reply:

Data parallel with multiple server instances should be tested, where more GPUs are needed.
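For context, a minimal sketch of the arithmetic behind this reply (an assumed deployment layout, not verl API): with 8 worker processes and tensor_model_parallel_size=2, each reward-model server instance is sharded across 2 GPUs, so the 8-GPU pool hosts several data-parallel instances and the test also exercises routing across them.

# Illustration only; instance count assumed to be world_size // tp_size.
world_size = 8                        # process_on_nodes=[8]
tensor_model_parallel_size = 2        # from RewardModelConfig above
num_server_instances = world_size // tensor_model_parallel_size  # -> 4 data-parallel instances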

rm_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
# init model
rm_wg.init_model()

# create data samples
tokenizer = model_config.get_processor()
data = create_data_samples(tokenizer)

gen_batch = rm_wg.compute_rm_score(data)
rm_scores = gen_batch.batch["rm_scores"]
sample_scores = rm_scores.sum(dim=1)
print(sample_scores)
server_rm_scores = gen_batch.batch["rm_scores"].sum(dim=-1)
print(f"{server_rm_scores=}")
server_rm_scores_mean = torch.mean(server_rm_scores)

hf_model = AutoModelForSequenceClassification.from_pretrained(rm_path, torch_dtype=torch.bfloat16)
hf_model.pad_token_id = tokenizer.pad_token_id
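# Forward the same padded batch through the HF sequence-classification head; its mean
# score should match the RM worker's output within the tolerance asserted below.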
hf_output = hf_model(
input_ids=data.batch["input_ids"],
attention_mask=data.batch["attention_mask"],
)
hf_rm_scores = hf_output.logits.squeeze().detach().to("cpu")
print(f"{hf_rm_scores=}")
hf_rm_scores_mean = torch.mean(hf_rm_scores).to(server_rm_scores.dtype)

torch.testing.assert_close(server_rm_scores_mean, hf_rm_scores_mean, atol=2e-2, rtol=1e-2)

ray.shutdown()