
Commit d3dc066

yyDing1 and gemini-code-assist[bot] authored and committed
[ci] refactor: add ci test for refactored reward worker and add some args to GenRM config (volcengine#3385)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

- Add a CI test for the refactored reward model worker: an accuracy check comparing the scores of the server-mode RM against the plain HF RM (a sketch of this check follows the description below).
- Add some arguments to the GenRM config (e.g., `reward_type` and sampling parameters).

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
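The accuracy check referenced above scores the same padded batch with the server-mode reward worker and with a plain Hugging Face sequence-classification forward pass, then compares the mean scores. Below is a minimal sketch of that comparison, assuming the batch is the one the worker already scored; the model path and tolerances are taken from the new test, and `check_against_hf` is an illustrative helper, not part of verl:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Reference scoring with the plain HF reward model, mirroring the check in
# tests/workers/reward_model/test_reward_model.py.
rm_path = "Skywork/Skywork-Reward-V2-Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(rm_path)
hf_model = AutoModelForSequenceClassification.from_pretrained(rm_path, torch_dtype=torch.bfloat16)
# Make sure the HF model knows the pad token so the score is pooled at the last real token.
hf_model.config.pad_token_id = tokenizer.pad_token_id


def check_against_hf(server_rm_scores: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> None:
    """Compare per-sample scores from the server-mode RM worker with a plain HF forward pass."""
    hf_logits = hf_model(input_ids=input_ids, attention_mask=attention_mask).logits
    hf_scores = hf_logits.squeeze(-1).detach().to("cpu")
    # The CI test compares the mean scores within a small tolerance.
    torch.testing.assert_close(
        server_rm_scores.mean(),
        hf_scores.mean().to(server_rm_scores.dtype),
        atol=2e-2,
        rtol=1e-2,
    )
```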
1 parent e3cc337 commit d3dc066

File tree: 2 files changed, +169 −101 lines


.github/workflows/reward_model.yml

Lines changed: 90 additions & 0 deletions
```yaml
# # Tests layout

# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:
# - `tests/trainer` for testing functionality related to `verl/trainer`
# - `tests/models` for testing functionality related to `verl/models`
# - ...

# There are a few folders with `special_` prefix, created for special purposes:
# - `special_distributed`: unit tests that must run with multiple GPUs
# - `special_e2e`: end-to-end tests with training/generation scripts
# - `special_npu`: tests for NPUs
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of test that are designed to run in dedicated environments

# Accelerators for tests
# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.
# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment.

# # Workflow layout

# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:
# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml`
# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`
# 3. End-to-end tests: `e2e_*.yml`
# 4. Unit tests
#   - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`
#   - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix.
#   - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when
#     - new workflow yaml is added to `.github/workflows`
#     - new tests are added to workflow mentioned in 2.
# name: Check PR Title

name: reward_model

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
      - v0.*
  pull_request:
    branches:
      - main
      - v0.*
    paths:
      - "verl/**/*.py"
      # Entrypoints
      - ".github/workflows/reward_model.yml"
      - "tests/workers/reward_model/**"

# Declare permissions just read content.
permissions:
  contents: read

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  discriminative_reward_model:
    runs-on: [L20x8]
    timeout-minutes: 20 # Increase this timeout value as needed
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
      HF_ENDPOINT: "https://hf-mirror.com"
      HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
      SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True"
      NCCL_SHM_DISABLE: "1"
      NCCL_P2P_DISABLE: "1"
    container:
      image: verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install -e .[test]
      - name: Download model config files
        run: |
          hf download Skywork/Skywork-Reward-V2-Llama-3.2-1B --local-dir $HOME/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B
      - name: Running discriminative reward model tests on 8 L20 GPUs
        run: |
          unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
          pytest -s -x tests/workers/reward_model/test_reward_model.py
```
tests/workers/reward_model/test_reward_model.py

Lines changed: 79 additions & 101 deletions
```diff
@@ -11,138 +11,79 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 
 import ray
 import torch
-from hydra import compose, initialize_config_dir
-from transformers import AutoTokenizer
-
-from verl.protocol import DataProto
-from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
-from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-
-
-def test_agent_loop_compute_score_with_model():
-    ray.init(
-        runtime_env={
-            "env_vars": {
-                "TOKENIZERS_PARALLELISM": "true",
-                "NCCL_DEBUG": "WARN",
-                "VLLM_LOGGING_LEVEL": "INFO",
-                "VLLM_USE_V1": "1",
-            }
-        }
-    )
+from transformers import AutoModelForSequenceClassification
 
-    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
-        config = compose("ppo_trainer")
+from verl import DataProto
+from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
+from verl.utils.model import compute_position_id_with_mask
+from verl.workers.config import HFModelConfig, RewardModelConfig
+from verl.workers.roles import RewardModelWorker
 
-    rm_path = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
-
-    if os.environ["LEGACY_IMPL_RM"] == "disable":
-        from verl.workers.config import HFModelConfig, RewardModelConfig
-        from verl.workers.roles import RewardModelWorker
-
-        model_config = HFModelConfig(path=rm_path)
-        reward_model_config = RewardModelConfig(
-            enable=True,
-            model_config=model_config,
-            input_model_config=None,
-            tensor_model_parallel_size=1,
-            gpu_memory_utilization=0.8,
-        )
-    else:
-        from verl.workers.fsdp_workers import RewardModelWorker
-
-        config.reward_model.enable = True
-        config.reward_model.model.path = rm_path
-        config.reward_model.use_dynamic_bsz = True
-        config.reward_model.forward_max_token_len_per_gpu = 6000
-        config.reward_model.micro_batch_size_per_gpu = 40
-        config.reward_model.model.trust_remote_code = True
-        config.reward_model.model.input_tokenizer = None
-        reward_model_config = config.reward_model
-
-    config.trainer.n_gpus_per_node = 2
-    config.trainer.nnodes = 1
-
-    role_worker_mapping = {}
-    if reward_model_config.enable:
-        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
-
-    global_pool_id = "global_pool"
-    resource_pool_spec = {
-        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-    }
-    mapping = {}
-    mapping[Role.RewardModel] = "global_pool"
-    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
-    resource_pool_manager.create_resource_pool()
-    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}
-
-    if reward_model_config.enable:
-        # we create a RM here
-        resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
-        rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=reward_model_config)
-        resource_pool_to_cls[resource_pool]["rm"] = rm_cls
-
-    all_wg = {}
-    for resource_pool, class_dict in resource_pool_to_cls.items():
-        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
-        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
-        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
-        all_wg.update(spawn_wg)
-
-    rm_wg = all_wg["rm"]
-    rm_wg.init_model()
 
+def create_data_samples(tokenizer) -> DataProto:
     convs = [
         [
             {
                 "role": "user",
                 "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
             },
-            {"role": "assistant", "content": "The output is bounded between -1 and 1."},
+            {"role": "assistant", "content": "Between -1 and 1."},
         ],
         [
             {
                 "role": "user",
                 "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
             },
-            {"role": "assistant", "content": "The output is bounded between 0 and 1."},
+            {"role": "assistant", "content": "Between 0 and 1."},
+        ],
+        [
+            {"role": "user", "content": "What is the capital of Australia?"},
+            {
+                "role": "assistant",
+                "content": "Canberra is the capital city of Australia.",
+            },
+        ],
+        [
+            {"role": "user", "content": "What is the capital of Australia?"},
+            {
+                "role": "assistant",
+                "content": "Sydney is the capital of Australia.",
+            },
         ],
     ]
-    tokenizer = AutoTokenizer.from_pretrained(rm_path)
 
     prompt_length, response_length = 1024, 4096
     pad_token_id = tokenizer.pad_token_id
-    prompts, responses, input_ids, attention_masks, position_ids = [], [], [], [], []
+    prompts, responses, input_ids, attention_masks = [], [], [], []
    for conv in convs:
-        prompt = tokenizer.apply_chat_template(conv[:1], tokenize=True)
-        response = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt) :]
+        prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
+        response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
+
+        padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
+        padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
         attention_mask = (
-            [0] * (prompt_length - len(prompt))
-            + [1] * len(prompt)
-            + [1] * len(response)
-            + [0] * (response_length - len(response))
+            [0] * (prompt_length - len(prompt_tokens))
+            + [1] * len(prompt_tokens)
+            + [1] * len(response_tokens)
+            + [0] * (response_length - len(response_tokens))
         )
-        prompt = [pad_token_id] * (prompt_length - len(prompt)) + prompt
-        response = response + [pad_token_id] * (response_length - len(response))
-        prompts.append(torch.tensor(prompt))
-        responses.append(torch.tensor(response))
-        input_ids.append(torch.tensor(prompt + response))
+        prompts.append(torch.tensor(padded_prompt))
+        responses.append(torch.tensor(padded_response))
+        input_ids.append(torch.tensor(padded_prompt + padded_response))
         attention_masks.append(torch.tensor(attention_mask))
 
-    from verl.utils.model import compute_position_id_with_mask
-
     prompts = torch.stack(prompts)
     responses = torch.stack(responses)
     input_ids = torch.stack(input_ids)
     attention_masks = torch.stack(attention_masks)
     position_ids = compute_position_id_with_mask(attention_masks)
-    data = DataProto.from_dict(
+
+    return DataProto.from_dict(
         tensors={
             "prompts": prompts,
             "responses": responses,
@@ -151,8 +92,45 @@ def test_agent_loop_compute_score_with_model():
             "position_ids": position_ids,
         },
     )
+
+
+def test_reward_model():
+    ray.init()
+
+    rm_path = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B")
+    model_config = HFModelConfig(path=rm_path)
+    config = RewardModelConfig(
+        enable=True,
+        dtype="bfloat16",
+        model_config=model_config,
+        input_model_config=None,
+        tensor_model_parallel_size=2,
+    )
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(RewardModelWorker), config=config)
+    resource_pool = RayResourcePool(process_on_nodes=[8])
+    rm_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
+    # init model
+    rm_wg.init_model()
+
+    # create data samples
+    tokenizer = model_config.get_processor()
+    data = create_data_samples(tokenizer)
+
     gen_batch = rm_wg.compute_rm_score(data)
-    rm_scores = gen_batch.batch["rm_scores"]
-    sample_scores = rm_scores.sum(dim=1)
-    print(sample_scores)
+    server_rm_scores = gen_batch.batch["rm_scores"].sum(dim=-1)
+    print(f"{server_rm_scores=}")
+    server_rm_scores_mean = torch.mean(server_rm_scores)
+
+    hf_model = AutoModelForSequenceClassification.from_pretrained(rm_path, torch_dtype=torch.bfloat16)
+    hf_model.pad_token_id = tokenizer.pad_token_id
+    hf_output = hf_model(
+        input_ids=data.batch["input_ids"],
+        attention_mask=data.batch["attention_mask"],
+    )
+    hf_rm_scores = hf_output.logits.squeeze().detach().to("cpu")
+    print(f"{hf_rm_scores=}")
+    hf_rm_scores_mean = torch.mean(hf_rm_scores).to(server_rm_scores.dtype)
+
+    torch.testing.assert_close(server_rm_scores_mean, hf_rm_scores_mean, atol=2e-2, rtol=1e-2)
+
     ray.shutdown()
```
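For reference, the refactored reward worker can also be driven outside pytest. The helper below is a condensed, illustrative wrapper around the same calls the test makes; the function name `score_batch`, its signature, and the default pool size are not part of verl, and the input batch is assumed to be built the same way as `create_data_samples()` above:

```python
import ray

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.workers.config import HFModelConfig, RewardModelConfig
from verl.workers.roles import RewardModelWorker


def score_batch(data: DataProto, rm_path: str, num_gpus: int = 8):
    """Score a padded prompt/response batch with the server-mode reward worker.

    `data` carries prompts/responses/input_ids/attention_mask/position_ids,
    e.g. constructed like create_data_samples() in the test above.
    """
    ray.init(ignore_reinit_error=True)

    # Server-mode reward model config; values mirror the CI test.
    model_config = HFModelConfig(path=rm_path)
    config = RewardModelConfig(
        enable=True,
        dtype="bfloat16",
        model_config=model_config,
        input_model_config=None,
        tensor_model_parallel_size=2,
    )

    # One Ray worker group over the available GPUs; init_model() loads the RM into it.
    worker_cls = RayClassWithInitArgs(cls=ray.remote(RewardModelWorker), config=config)
    rm_wg = RayWorkerGroup(
        resource_pool=RayResourcePool(process_on_nodes=[num_gpus]),
        ray_cls_with_init=worker_cls,
    )
    rm_wg.init_model()

    # rm_scores is per-token; summing over the token dim gives one score per sample.
    result = rm_wg.compute_rm_score(data)
    return result.batch["rm_scores"].sum(dim=-1)
```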
