2 changes: 1 addition & 1 deletion Makefile
@@ -1,6 +1,6 @@
.PHONY: build commit quality style

check_dirs := scripts verl setup.py
check_dirs := examples scripts verl setup.py

build:
python3 setup.py sdist bdist_wheel
1 change: 1 addition & 0 deletions examples/baselines/qwen2_5_vl_3b_clevr.sh
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.reward_type=sequential \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_clevr \
trainer.n_gpus_per_node=2
1 change: 1 addition & 0 deletions examples/baselines/qwen2_5_vl_3b_geoqa8k.sh
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.reward_type=sequential \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
trainer.n_gpus_per_node=8
2 changes: 1 addition & 1 deletion examples/config.yaml
@@ -71,7 +71,7 @@ worker:
offload_params: false

reward:
reward_type: function
reward_type: batch
reward_function: ./examples/reward_function/math.py:compute_score

trainer:
34 changes: 20 additions & 14 deletions examples/reward_function/math.py
@@ -13,28 +13,34 @@
# limitations under the License.

import re
from typing import Dict
from typing import Dict, List

from mathruler.grader import extract_boxed_content, grade_answer


def format_reward(predict_str: str) -> float:
def format_reward(predict: str) -> float:
pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
format_match = re.fullmatch(pattern, predict)
return 1.0 if format_match else 0.0


def accuracy_reward(predict_str: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict_str)
def accuracy_reward(predict: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict)
return 1.0 if grade_answer(answer, ground_truth) else 0.0


def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str) # handle qwen2.5vl-32b format
format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict_str, ground_truth)
return {
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
"accuracy": accuracy_score,
}
def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
scores = []
for predict, ground_truth in zip(predicts, ground_truths):
predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format
format_score = format_reward(predict)
accuracy_score = accuracy_reward(predict, ground_truth)
scores.append(
{
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
"accuracy": accuracy_score,
}
)

return scores
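
A minimal sketch of how the new list-based `compute_score` in `math.py` can be exercised on its own. The example strings are made up, the path assumes the repository root as working directory, and `mathruler` must be installed (it is already imported by the reward file):

```python
import importlib.util

# Load the reward module from its file path, similar in spirit to FunctionRewardManager.
spec = importlib.util.spec_from_file_location("math_reward", "./examples/reward_function/math.py")
math_reward = importlib.util.module_from_spec(spec)
spec.loader.exec_module(math_reward)

predicts = [
    "<think>2 + 2 = 4</think> The answer is \\boxed{4}.",  # has <think> block and \boxed{} -> full credit
    "The answer is 4.",                                     # missing <think> and \boxed{} -> no reward
]
ground_truths = ["4", "4"]

# Batch interface: one call for the whole batch, one score dict per sample.
for score in math_reward.compute_score(predicts, ground_truths):
    print(score)  # keys: "overall", "format", "accuracy"
```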
16 changes: 8 additions & 8 deletions examples/reward_function/r1v.py
@@ -18,16 +18,16 @@
from mathruler.grader import grade_answer


def format_reward(predict_str: str) -> float:
def format_reward(predict: str) -> float:
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
format_match = re.fullmatch(pattern, predict)
return 1.0 if format_match else 0.0


def accuracy_reward(predict_str: str, ground_truth: str) -> float:
def accuracy_reward(predict: str, ground_truth: str) -> float:
try:
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict)
given_answer = content_match.group(1).strip() if content_match else predict.strip()
if grade_answer(given_answer, ground_truth.strip()):
return 1.0

@@ -37,9 +37,9 @@ def accuracy_reward(predict_str: str, ground_truth: str) -> float:
return 0.0


def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict_str, ground_truth)
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict)
accuracy_score = accuracy_reward(predict, ground_truth)
return {
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
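
Unlike `math.py`, `r1v.py` keeps the one-sample-at-a-time signature, which matches the `SequentialRewardFunction` contract added in `function.py` below; this is why the two baseline scripts above now set `worker.reward.reward_type=sequential`. A quick illustrative call, with a made-up prediction string and the module loaded the same way as in the previous sketch:

```python
import importlib.util

spec = importlib.util.spec_from_file_location("r1v_reward", "./examples/reward_function/r1v.py")
r1v_reward = importlib.util.module_from_spec(spec)
spec.loader.exec_module(r1v_reward)

# Sequential interface: one prediction and one ground truth per call, one score dict back.
score = r1v_reward.compute_score("<think>3 cubes plus 2 spheres.</think><answer>5</answer>", "5")
print(score)  # {"overall": 1.0, "format": 1.0, "accuracy": 1.0} with the default format_weight=0.5
```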
11 changes: 9 additions & 2 deletions verl/trainer/main.py
@@ -20,7 +20,7 @@
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager
from .config import PPOConfig
from .data_loader import create_dataloader
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
@@ -67,7 +67,14 @@ def run(self, config: PPOConfig):
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
if config.worker.reward.reward_type == "sequential":
RewardManager = SequentialFunctionRewardManager
elif config.worker.reward.reward_type == "batch":
RewardManager = BatchFunctionRewardManager
else:
raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")

RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)

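
Condensed, the selection added to `main.py` is a two-way mapping from the config string to a manager class; a small sketch (the `reward_type` value here is illustrative):

```python
from verl.workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager

# "sequential": reward_fn scores one (prediction, ground_truth) pair per call (e.g. r1v.py).
# "batch":      reward_fn scores the whole batch in one call (e.g. math.py, the new default).
REWARD_MANAGERS = {
    "sequential": SequentialFunctionRewardManager,
    "batch": BatchFunctionRewardManager,
}

reward_type = "batch"
RewardManager = REWARD_MANAGERS[reward_type]
```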
4 changes: 2 additions & 2 deletions verl/workers/reward/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.

from .config import RewardConfig
from .function import FunctionRewardManager
from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager


__all__ = ["FunctionRewardManager", "RewardConfig"]
__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]
2 changes: 1 addition & 1 deletion verl/workers/reward/config.py
@@ -22,7 +22,7 @@

@dataclass
class RewardConfig:
reward_type: str = "function"
reward_type: str = "batch"
reward_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
56 changes: 45 additions & 11 deletions verl/workers/reward/function.py
@@ -15,6 +15,7 @@
import importlib.util
import os
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -32,10 +33,12 @@ class RewardScore(TypedDict):
accuracy: Optional[float]


RewardFunction = Callable[[str, str], RewardScore]
SequentialRewardFunction = Callable[[str, str], RewardScore]

BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]

class FunctionRewardManager:

class FunctionRewardManager(ABC):
"""Reward manager for rule-based reward."""

def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
@@ -56,29 +59,60 @@ def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
if not hasattr(module, config.reward_function_name):
raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")

reward_fn: RewardFunction = getattr(module, config.reward_function_name)
reward_fn = getattr(module, config.reward_function_name)
print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
self.config = config
self.tokenizer = tokenizer

@abstractmethod
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
"""Compute reward for a batch of data."""
...


class SequentialFunctionRewardManager(FunctionRewardManager):
reward_fn: SequentialRewardFunction

def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)):
data_item = data[i] # DataProtoItem
response_ids = data_item.batch["responses"]
response_mask = data_item.batch["response_mask"]
valid_response_length = response_mask.sum()
valid_response_ids = response_ids[:valid_response_length]

valid_response_ids = response_ids[i][: response_length[i]]
response_str = self.tokenizer.decode(
valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
)
ground_truth = data_item.non_tensor_batch["ground_truth"]
ground_truth = data.non_tensor_batch["ground_truth"][i]

score = self.reward_fn(response_str, ground_truth)
reward_tensor[i, valid_response_length - 1] = score["overall"]
reward_tensor[i, response_length[i] - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)

return reward_tensor, reward_metrics


class BatchFunctionRewardManager(FunctionRewardManager):
reward_fn: BatchRewardFunction

def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
response_str, ground_truth = [], []
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)):
valid_response_ids = response_ids[i][: response_length[i]]
response_str.append(
self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
)
ground_truth.append(data.non_tensor_batch["ground_truth"][i])

scores = self.reward_fn(response_str, ground_truth)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
for i, score in enumerate(scores):
reward_tensor[i, response_length[i] - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)

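
Finally, to illustrate the `BatchRewardFunction` contract that `BatchFunctionRewardManager` expects, here is a sketch of a user-defined batch scorer. The exact-match logic is only a placeholder; pointing `worker.reward.reward_function` at a file like this with `reward_type=batch` should behave the same way as the bundled `math.py`:

```python
from typing import Dict, List


def compute_score(predicts: List[str], ground_truths: List[str]) -> List[Dict[str, float]]:
    """Toy batch reward: takes the whole batch, returns one score dict per sample."""
    scores = []
    for predict, ground_truth in zip(predicts, ground_truths):
        accuracy = 1.0 if predict.strip() == ground_truth.strip() else 0.0
        scores.append({"overall": accuracy, "accuracy": accuracy})
    return scores
```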