Merged
2 changes: 1 addition & 1 deletion examples/baselines/qwen2_5_vl_3b_clevr.sh
@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_clevr \
trainer.n_gpus_per_node=2
2 changes: 1 addition & 1 deletion examples/baselines/qwen2_5_vl_3b_geoqa8k.sh
@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
trainer.n_gpus_per_node=8
4 changes: 2 additions & 2 deletions examples/config.yaml
@@ -7,7 +7,7 @@ data:
max_prompt_length: 2048
max_response_length: 2048
rollout_batch_size: 512
val_batch_size: -1
val_batch_size: 1024
format_prompt: ./examples/format_prompt/math_format.jinja
shuffle: true
seed: 1
@@ -71,7 +71,7 @@ worker:

reward:
reward_type: function
score_function: ./examples/score_function/math.py:compute_score
reward_function: ./examples/reward_function/math.py:compute_score

trainer:
total_episodes: 15
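
Note on the renamed reward_function entry above: the value follows a path.py:function_name convention and points at a user-supplied scoring module. Below is a minimal sketch of such a module; the (predict, ground_truth) signature and the returned metric keys are assumptions for illustration, not verl's documented API.

# custom_reward.py -- hypothetical module that could be referenced as
#   worker.reward.reward_function=./custom_reward.py:compute_score
# The (predict, ground_truth) signature is an assumption for illustration.
import re
from typing import Dict


def compute_score(predict: str, ground_truth: str) -> Dict[str, float]:
    # toy rule-based reward: 1.0 if the \boxed{...} answer matches the ground truth
    match = re.search(r"\\boxed\{(.+?)\}", predict)
    answer = match.group(1).strip() if match else ""
    accuracy = float(answer == ground_truth.strip())
    # returning a dict of metrics allows per-key logging (e.g. reward/accuracy)
    return {"overall": accuracy, "accuracy": accuracy}
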
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions verl/trainer/config.py
@@ -51,7 +51,7 @@ class DataConfig:

def post_init(self):
if self.format_prompt is not None:
if os.path.exists(self.format_prompt):
if os.path.exists(self.format_prompt): # ray job uses absolute path
self.format_prompt = os.path.abspath(self.format_prompt)
else:
self.format_prompt = None
@@ -94,7 +94,7 @@ def post_init(self):
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path
if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)

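
A short note on the repeated "# ray job uses absolute path" comments: the assumption is that Ray workers (or a packaged ray job working directory) may not share the driver's current directory, so relative paths are resolved while they are still valid, as the diff does:

import os

path = "./examples/format_prompt/math_format.jinja"
if os.path.exists(path):
    # resolve on the driver, where the relative path still points at the file,
    # so remote workers can open it regardless of their own working directory
    path = os.path.abspath(path)
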
9 changes: 4 additions & 5 deletions verl/trainer/main.py
@@ -65,12 +65,11 @@ def run(self, config: PPOConfig):
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)

train_dataloader, val_dataloader = create_dataloader(
config=config.data, tokenizer=tokenizer, processor=processor
)
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)

trainer = RayPPOTrainer(
config=config,
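
The main.py change wraps FunctionRewardManager as a Ray actor so reward computation runs off the driver. A self-contained sketch of the same pattern; EchoManager and its method are placeholders, not verl classes:

import ray

ray.init(ignore_reinit_error=True)


class EchoManager:
    def __init__(self, prefix: str):
        self.prefix = prefix

    def compute(self, text: str) -> str:
        return f"{self.prefix}: {text}"


# wrap an ordinary class as a remote actor and reserve one CPU for it
RemoteEcho = ray.remote(EchoManager).options(num_cpus=1)
manager = RemoteEcho.remote("reward")      # constructs the actor
ref = manager.compute.remote("hello")      # returns an ObjectRef immediately
print(ray.get(ref))                        # blocks only when the result is needed
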
73 changes: 29 additions & 44 deletions verl/trainer/ray_trainer.py
@@ -19,16 +19,14 @@
import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Type

import numpy as np
import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -40,9 +38,10 @@
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
@@ -162,14 +161,6 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
return data


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield

timing_raw[name] = timer.last


class RayPPOTrainer:
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
@@ -185,8 +176,8 @@ def __init__(
role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
reward_fn: Optional[FunctionRewardManager] = None,
val_reward_fn: Optional[FunctionRewardManager] = None,
):
self.tokenizer = tokenizer
self.processor = processor
@@ -307,7 +298,6 @@ def _validate(self) -> Dict[str, Any]:
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
print("validation generation end")

# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
@@ -317,7 +307,7 @@
test_batch = test_batch.union(test_output_gen_batch)

# evaluate using reward_function
reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))

# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
@@ -504,20 +494,20 @@ def fit(self):
non_tensor_batch_keys=["raw_prompt_ids"],
)

with _timer("step", timing_raw):
with timer("step", timing_raw):
# generate a batch
with _timer("gen", timing_raw): # wg: worker group
with timer("gen", timing_raw): # wg: worker group
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
with timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["temperature"] = 0
gen_baseline_batch.meta_info["n"] = 1
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

batch = batch.union(gen_baseline_output)
reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
@@ -532,19 +522,6 @@
batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None)

# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")

# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)

# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
@@ -553,30 +530,38 @@
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

# compute reward
with timer("reward", timing_raw):
reward_ref = self.reward_fn.compute_reward.remote(batch)

# recompute old_log_probs
with _timer("old", timing_raw):
with timer("old", timing_raw):
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)

# compute ref_log_probs
if self.use_reference_policy:
with _timer("ref", timing_raw):
with timer("ref", timing_raw):
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)

# compute values
if self.use_critic:
with _timer("values", timing_raw):
with timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with _timer("adv", timing_raw):
with timer("adv", timing_raw):
# get token level scores
reward_tensor, reward_metrics = ray.get(reward_ref)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
metrics.update(reward_metrics)

# apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
@@ -591,15 +576,15 @@

# update critic
if self.use_critic:
with _timer("update_critic", timing_raw):
with timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)

critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
metrics.update(critic_metrics)

# update actor
if self.config.trainer.critic_warmup <= self.global_step:
with _timer("update_actor", timing_raw):
with timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)

actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
@@ -611,13 +596,13 @@
and self.config.trainer.val_freq > 0
and self.global_step % self.config.trainer.val_freq == 0
):
with _timer("validation", timing_raw):
with timer("validation", timing_raw):
val_metrics = self._validate()

metrics.update(val_metrics)

if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw):
with timer("save_checkpoint", timing_raw):
self._save_checkpoint()

# collect metrics
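
The reward-related hunks above replace the synchronous reward_fn(batch) call with compute_reward.remote(batch), launched before the log-prob/value passes and collected with ray.get inside the adv block, so CPU-bound scoring overlaps with GPU work. A toy sketch of that overlap pattern with placeholder functions, not verl's actual workers:

import time

import ray

ray.init(ignore_reinit_error=True)


@ray.remote(num_cpus=1)
def score_batch(batch):
    time.sleep(1.0)                      # stands in for rule-based reward scoring
    return [float(len(x)) for x in batch]


def compute_log_probs(batch):
    time.sleep(1.0)                      # stands in for GPU-side log-prob recomputation
    return [0.0 for _ in batch]


batch = ["a", "bb", "ccc"]
reward_ref = score_batch.remote(batch)   # kick off reward scoring asynchronously
log_probs = compute_log_probs(batch)     # runs while the remote task is in flight
rewards = ray.get(reward_ref)            # collect just before advantages are computed
print(rewards, log_probs)                # total wall time ~1s instead of ~2s
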
10 changes: 10 additions & 0 deletions verl/utils/py_functional.py
@@ -17,11 +17,13 @@

import importlib.util
import re
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Dict, List, Union

import numpy as np
import yaml
from codetiming import Timer
from yaml import Dumper


@@ -101,3 +103,11 @@ def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") ->

def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2)


@contextmanager
def timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield

timing_raw[name] = timer.last
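
Usage sketch for the relocated timer helper; codetiming.Timer stores the elapsed seconds of the most recent run in .last, which the helper copies into the caller's dict. do_generation is a placeholder:

import time

from verl.utils.py_functional import timer


def do_generation():
    time.sleep(0.1)        # placeholder for rollout generation


timing_raw = {}
with timer("gen", timing_raw):
    do_generation()

print(timing_raw)          # e.g. {'gen': 0.100...}
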
2 changes: 1 addition & 1 deletion verl/workers/actor/config.py
@@ -33,7 +33,7 @@ def post_init(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path

if self.model_path is not None and os.path.exists(self.model_path):
if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path
self.model_path = os.path.abspath(self.model_path)

if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
4 changes: 2 additions & 2 deletions verl/workers/fsdp_workers.py
@@ -281,15 +281,15 @@ def _build_model_optimizer(
if self._is_actor or self._is_critic:
if optim_config.strategy == "adamw":
self.optimizer = torch.optim.AdamW(
self.fsdp_module.parameters(),
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
fused=True,
)
elif optim_config.strategy == "adamw_bf16":
self.optimizer = AnyPrecisionAdamW(
self.fsdp_module.parameters(),
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
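
The optimizer change above passes only parameters with requires_grad=True to AdamW / AnyPrecisionAdamW, so frozen modules (for instance a frozen vision tower, which is an assumption about the motivation) get no optimizer state. A self-contained illustration:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False   # freeze the first layer

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)

# only the unfrozen layer's weight and bias are tracked by the optimizer
print(len(trainable), sum(p.numel() for p in trainable))
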
21 changes: 11 additions & 10 deletions verl/workers/reward/config.py
@@ -23,20 +23,21 @@
@dataclass
class RewardConfig:
reward_type: str = "function"
score_function: Optional[str] = None
score_function_kwargs: dict = field(default_factory=dict)
reward_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
num_cpus: int = 1
"""auto keys"""
score_function_name: Optional[str] = field(default=None, init=False)
reward_function_name: Optional[str] = field(default=None, init=False)

def post_init(self):
if self.score_function is not None:
if ":" not in self.score_function:
self.score_function_name = "main"
if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main
if ":" not in self.reward_function:
self.reward_function_name = "main"
else:
self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)

if os.path.exists(self.score_function):
self.score_function = os.path.abspath(self.score_function)
if os.path.exists(self.reward_function): # ray job uses absolute path
self.reward_function = os.path.abspath(self.reward_function)
else:
self.score_function = None
self.reward_function = None
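
A quick illustration of the parsing in post_init above: rsplit(":", maxsplit=1) separates the module path from the function name at the last colon, and the name defaults to "main" when no colon is given (values below are examples):

spec = "./examples/reward_function/math.py:compute_score"
if ":" in spec:
    path, func_name = spec.rsplit(":", maxsplit=1)
else:
    path, func_name = spec, "main"

print(path, func_name)   # ./examples/reward_function/math.py compute_score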