Commit 647c30e

[reward] support batch reward (#271)
1 parent c024735 commit 647c30e

10 files changed, +89 -40 lines changed
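At a glance, this commit splits the function-based reward path into a per-sample ("sequential") and a per-batch ("batch") variant and routes worker.reward.reward_type through the new classes. The two user-facing reward-function signatures introduced in verl/workers/reward/function.py (see the diff below) boil down to the following sketch; RewardScore here is only a stand-in for the RewardScore TypedDict defined in that file:

    from typing import Callable, Dict, List

    RewardScore = Dict[str, float]  # stand-in for the RewardScore TypedDict in function.py

    # reward_type=sequential: called once per rollout with (response_str, ground_truth).
    SequentialRewardFunction = Callable[[str, str], RewardScore]

    # reward_type=batch: called once per batch with all responses and ground truths,
    # returning one score dict per rollout.
    BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]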

Makefile

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 .PHONY: build commit quality style

-check_dirs := scripts verl setup.py
+check_dirs := examples scripts verl setup.py

 build:
 	python3 setup.py sdist bdist_wheel

examples/baselines/qwen2_5_vl_3b_clevr.sh

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main \
     data.format_prompt=./examples/format_prompt/r1v_format.jinja \
     worker.actor.model.model_path=${MODEL_PATH} \
     worker.rollout.tensor_parallel_size=1 \
+    worker.reward.reward_type=sequential \
     worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
     trainer.experiment_name=qwen2_5_vl_3b_clevr \
     trainer.n_gpus_per_node=2

examples/baselines/qwen2_5_vl_3b_geoqa8k.sh

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main \
     data.format_prompt=./examples/format_prompt/r1v_format.jinja \
     worker.actor.model.model_path=${MODEL_PATH} \
     worker.rollout.tensor_parallel_size=1 \
+    worker.reward.reward_type=sequential \
     worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
     trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
     trainer.n_gpus_per_node=8

examples/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ worker:
       offload_params: false

   reward:
-    reward_type: function
+    reward_type: batch
     reward_function: ./examples/reward_function/math.py:compute_score

 trainer:

examples/reward_function/math.py

Lines changed: 20 additions & 14 deletions
@@ -13,28 +13,34 @@
 # limitations under the License.

 import re
-from typing import Dict
+from typing import Dict, List

 from mathruler.grader import extract_boxed_content, grade_answer


-def format_reward(predict_str: str) -> float:
+def format_reward(predict: str) -> float:
     pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
-    format_match = re.fullmatch(pattern, predict_str)
+    format_match = re.fullmatch(pattern, predict)
     return 1.0 if format_match else 0.0


-def accuracy_reward(predict_str: str, ground_truth: str) -> float:
-    answer = extract_boxed_content(predict_str)
+def accuracy_reward(predict: str, ground_truth: str) -> float:
+    answer = extract_boxed_content(predict)
     return 1.0 if grade_answer(answer, ground_truth) else 0.0


-def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
-    predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)  # handle qwen2.5vl-32b format
-    format_score = format_reward(predict_str)
-    accuracy_score = accuracy_reward(predict_str, ground_truth)
-    return {
-        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
-        "format": format_score,
-        "accuracy": accuracy_score,
-    }
+def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+    scores = []
+    for predict, ground_truth in zip(predicts, ground_truths):
+        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # handle qwen2.5vl-32b format
+        format_score = format_reward(predict)
+        accuracy_score = accuracy_reward(predict, ground_truth)
+        scores.append(
+            {
+                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+                "format": format_score,
+                "accuracy": accuracy_score,
+            }
+        )
+
+    return scores
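
For a quick sanity check of the batch-style compute_score above, the module can be loaded from its file path the same way the reward manager does. This is only a sketch: it assumes mathruler is installed and the command runs from the repository root; the module name "math_reward" is arbitrary and the sample strings are made up.

    import importlib.util

    spec = importlib.util.spec_from_file_location("math_reward", "./examples/reward_function/math.py")
    math_reward = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(math_reward)

    predicts = [
        "<think>2 + 2 = 4</think> \\boxed{4}",  # has a <think> block and a matching boxed answer
        "\\boxed{5}",                           # no <think> block, boxed answer does not match
    ]
    ground_truths = ["4", "6"]

    # compute_score now returns one dict per rollout with "overall", "format" and "accuracy" keys.
    for score in math_reward.compute_score(predicts, ground_truths, format_weight=0.1):
        print(score)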

examples/reward_function/r1v.py

Lines changed: 8 additions & 8 deletions
@@ -18,16 +18,16 @@
 from mathruler.grader import grade_answer


-def format_reward(predict_str: str) -> float:
+def format_reward(predict: str) -> float:
     pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
-    format_match = re.fullmatch(pattern, predict_str)
+    format_match = re.fullmatch(pattern, predict)
     return 1.0 if format_match else 0.0


-def accuracy_reward(predict_str: str, ground_truth: str) -> float:
+def accuracy_reward(predict: str, ground_truth: str) -> float:
     try:
-        content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
-        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
+        content_match = re.search(r"<answer>(.*?)</answer>", predict)
+        given_answer = content_match.group(1).strip() if content_match else predict.strip()
         if grade_answer(given_answer, ground_truth.strip()):
             return 1.0

@@ -37,9 +37,9 @@ def accuracy_reward(predict_str: str, ground_truth: str) -> float:
     return 0.0


-def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
-    format_score = format_reward(predict_str)
-    accuracy_score = accuracy_reward(predict_str, ground_truth)
+def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
+    format_score = format_reward(predict)
+    accuracy_score = accuracy_reward(predict, ground_truth)
     return {
         "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
         "format": format_score,

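Note that r1v.py keeps the per-sample signature, which is why the two baseline scripts above now pass worker.reward.reward_type=sequential. Any callable with that shape works as a sequential reward function; as a hypothetical minimal example (a plain exact-match scorer, not part of this commit):

    from typing import Dict

    def compute_score(predict: str, ground_truth: str) -> Dict[str, float]:
        # Sequential interface: one decoded response and one ground truth per call.
        accuracy = 1.0 if predict.strip() == ground_truth.strip() else 0.0
        return {"overall": accuracy, "accuracy": accuracy}

As the reward managers in verl/workers/reward/function.py below show, only the "overall" key is required; every other key in the returned dict is simply logged as a metric.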
verl/trainer/main.py

Lines changed: 9 additions & 2 deletions
@@ -20,7 +20,7 @@
 from ..single_controller.ray import RayWorkerGroup
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..workers.fsdp_workers import FSDPWorker
-from ..workers.reward import FunctionRewardManager
+from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager
 from .config import PPOConfig
 from .data_loader import create_dataloader
 from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role

@@ -67,7 +67,14 @@ def run(self, config: PPOConfig):
         }
         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

-        RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
+        if config.worker.reward.reward_type == "sequential":
+            RewardManager = SequentialFunctionRewardManager
+        elif config.worker.reward.reward_type == "batch":
+            RewardManager = BatchFunctionRewardManager
+        else:
+            raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")
+
+        RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
         reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
         val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)

verl/workers/reward/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 from .config import RewardConfig
-from .function import FunctionRewardManager
+from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager


-__all__ = ["FunctionRewardManager", "RewardConfig"]
+__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]

verl/workers/reward/config.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@

 @dataclass
 class RewardConfig:
-    reward_type: str = "function"
+    reward_type: str = "batch"
     reward_function: Optional[str] = None
     reward_function_kwargs: dict = field(default_factory=dict)
     skip_special_tokens: bool = True
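
Since the default reward_type is now "batch", per-sample reward functions have to be selected explicitly, either on the command line (as the baseline scripts above do) or in code. A hedged sketch using the exported RewardConfig dataclass; it assumes all remaining fields keep their defaults:

    from verl.workers.reward import RewardConfig

    # Per-sample scorers such as examples/reward_function/r1v.py now need an explicit opt-in.
    sequential_cfg = RewardConfig(
        reward_type="sequential",
        reward_function="./examples/reward_function/r1v.py:compute_score",
    )

    # Batch scorers such as examples/reward_function/math.py match the new default.
    batch_cfg = RewardConfig(reward_function="./examples/reward_function/math.py:compute_score")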

verl/workers/reward/function.py

Lines changed: 45 additions & 11 deletions
@@ -15,6 +15,7 @@
 import importlib.util
 import os
 import sys
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, TypedDict

@@ -32,10 +33,12 @@ class RewardScore(TypedDict):
     accuracy: Optional[float]


-RewardFunction = Callable[[str, str], RewardScore]
+SequentialRewardFunction = Callable[[str, str], RewardScore]

+BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]

-class FunctionRewardManager:
+
+class FunctionRewardManager(ABC):
     """Reward manager for rule-based reward."""

     def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):

@@ -56,29 +59,60 @@ def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
         if not hasattr(module, config.reward_function_name):
             raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")

-        reward_fn: RewardFunction = getattr(module, config.reward_function_name)
+        reward_fn = getattr(module, config.reward_function_name)
         print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
         self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
         self.config = config
         self.tokenizer = tokenizer

+    @abstractmethod
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        """Compute reward for a batch of data."""
+        ...
+
+
+class SequentialFunctionRewardManager(FunctionRewardManager):
+    reward_fn: SequentialRewardFunction
+
     def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
         for i in range(len(data)):
-            data_item = data[i]  # DataProtoItem
-            response_ids = data_item.batch["responses"]
-            response_mask = data_item.batch["response_mask"]
-            valid_response_length = response_mask.sum()
-            valid_response_ids = response_ids[:valid_response_length]
-
+            valid_response_ids = response_ids[i][: response_length[i]]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
-            ground_truth = data_item.non_tensor_batch["ground_truth"]
+            ground_truth = data.non_tensor_batch["ground_truth"][i]

             score = self.reward_fn(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)
+
+        return reward_tensor, reward_metrics
+
+
+class BatchFunctionRewardManager(FunctionRewardManager):
+    reward_fn: BatchRewardFunction
+
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        response_str, ground_truth = [], []
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
+        for i in range(len(data)):
+            valid_response_ids = response_ids[i][: response_length[i]]
+            response_str.append(
+                self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
+            )
+            ground_truth.append(data.non_tensor_batch["ground_truth"][i])
+
+        scores = self.reward_fn(response_str, ground_truth)
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+        reward_metrics = defaultdict(list)
+        for i, score in enumerate(scores):
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
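
With this split, anything matching BatchRewardFunction can be plugged in via worker.reward.reward_function and is invoked once per batch by BatchFunctionRewardManager, which decodes all responses first and then places each sample's "overall" score at its last valid response token. A toy, hypothetical batch scorer satisfying the interface (not part of this commit):

    from typing import Dict, List

    def compute_score(predicts: List[str], ground_truths: List[str]) -> List[Dict[str, float]]:
        # Batch interface: all decoded responses arrive at once, so any expensive setup
        # (or a vectorized / model-based scorer) only has to run a single time per batch.
        scores = []
        for predict, ground_truth in zip(predicts, ground_truths):
            accuracy = 1.0 if predict.strip() == ground_truth.strip() else 0.0
            scores.append({"overall": accuracy, "accuracy": accuracy})
        return scores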
