3 changes: 2 additions & 1 deletion examples/config.yaml
@@ -7,13 +7,15 @@ data:
max_prompt_length: 2048
max_response_length: 2048
rollout_batch_size: 512
val_batch_size: -1
shuffle: true
seed: 1
max_pixels: 4194304
min_pixels: 262144

algorithm:
adv_estimator: grpo
disable_kl: false
use_kl_loss: true
kl_penalty: low_var_kl
kl_coef: 1.0e-2
@@ -74,7 +76,6 @@ trainer:
experiment_name: qwen2_5_7b_math_grpo
n_gpus_per_node: 8
nnodes: 1
report_kl: true
val_freq: 5 # -1 to disable
val_before_train: true
val_only: false
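Note on the new keys: `data.val_batch_size` defaults to `-1`, which keeps the previous behavior of validating on the whole validation set in a single batch, and `algorithm.disable_kl` replaces the removed `trainer.report_kl` as the single switch for turning the reference policy and KL regularization off. A minimal sketch of how these values are resolved, mirroring the trainer changes further down (the helper names and stub classes are illustrative, not the real verl config classes):

```python
from dataclasses import dataclass

# Illustrative stand-ins for the relevant config fields.
@dataclass
class DataCfg:
    val_batch_size: int = -1

@dataclass
class AlgoCfg:
    disable_kl: bool = False

def resolve_val_batch_size(data: DataCfg, val_dataset_size: int) -> int:
    # -1 means "validate the whole set in one batch" (the previous default behavior).
    return val_dataset_size if data.val_batch_size == -1 else data.val_batch_size

def kl_enabled(algo: AlgoCfg, has_ref_policy_worker: bool) -> bool:
    # disable_kl replaces the old trainer.report_kl switch.
    return has_ref_policy_worker and not algo.disable_kl

print(resolve_val_batch_size(DataCfg(), val_dataset_size=500))           # -> 500
print(kl_enabled(AlgoCfg(disable_kl=True), has_ref_policy_worker=True))  # -> False
```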
5 changes: 3 additions & 2 deletions verl/trainer/config.py
@@ -41,6 +41,7 @@ class DataConfig:
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
val_batch_size: int = -1
system_prompt: Optional[str] = None
shuffle: bool = True
seed: int = 1
@@ -53,6 +54,7 @@ class AlgorithmConfig:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "grpo"
disable_kl: bool = False
use_kl_loss: bool = False
kl_penalty: str = "kl"
kl_coef: float = 1e-3
@@ -70,7 +72,6 @@ class TrainerConfig:
logger: Tuple[str] = ("console", "wandb")
nnodes: int = 1
n_gpus_per_node: int = 8
report_kl: bool = True
critic_warmup: int = 0
val_freq: int = -1
val_before_train: bool = True
@@ -96,10 +97,10 @@ class PPOConfig:
def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length
self.worker.actor.disable_kl = self.algorithm.disable_kl
self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
self.worker.actor.kl_penalty = self.algorithm.kl_penalty
self.worker.actor.kl_coef = self.algorithm.kl_coef
self.worker.actor.report_kl = self.trainer.report_kl

def deep_post_init(self):
recursive_post_init(self)
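`post_init` now forwards the algorithm-level KL settings (`disable_kl`, `use_kl_loss`, `kl_penalty`, `kl_coef`) to the actor worker config, and `report_kl` is gone. The actor-side code is not part of this diff, so the following is only a hedged sketch of how a worker might apply these fields, using the `low_var_kl` (k3) estimator named in the example config:

```python
import torch

def maybe_add_kl_loss(
    policy_loss: torch.Tensor,
    log_probs: torch.Tensor,      # log-probs under the current policy
    ref_log_probs: torch.Tensor,  # log-probs under the frozen reference policy
    disable_kl: bool,
    use_kl_loss: bool,
    kl_coef: float,
) -> torch.Tensor:
    """Illustrative sketch only, not the verl actor implementation."""
    if disable_kl or not use_kl_loss:
        return policy_loss
    # "low_var_kl" (k3) estimator: exp(ref - logp) - (ref - logp) - 1 >= 0
    log_ratio = ref_log_probs - log_probs
    kl = torch.exp(log_ratio) - log_ratio - 1.0
    return policy_loss + kl_coef * kl.mean()

loss = maybe_add_kl_loss(
    policy_loss=torch.tensor(0.5),
    log_probs=torch.tensor([-1.0, -2.0]),
    ref_log_probs=torch.tensor([-1.1, -1.9]),
    disable_kl=False, use_kl_loss=True, kl_coef=1e-2,
)
print(loss)  # ~0.50005
```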
6 changes: 3 additions & 3 deletions verl/trainer/core_algos.py
@@ -127,7 +127,7 @@ def compute_gae_advantage_return(
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
@@ -164,15 +164,15 @@ def compute_grpo_outcome_advantage(
raise ValueError(f"no score in prompt index: {idx}")

for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores


@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
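For reference, the renamed `eps` is the small constant in the per-prompt whitening step of the GRPO outcome advantage: scalar rewards are grouped by prompt index, normalized by the group mean and standard deviation, and then broadcast over response tokens through the EOS mask. A toy reproduction of that normalization (data made up):

```python
import torch

scores = torch.tensor([1.0, 0.0, 1.0, 1.0])  # one scalar outcome reward per rollout
index = torch.tensor([0, 0, 1, 1])            # prompt id each rollout belongs to
eps = 1e-6

adv = torch.empty_like(scores)
for i in range(len(scores)):
    group = scores[index == index[i]]
    adv[i] = (scores[i] - group.mean()) / (group.std() + eps)

# Broadcast the per-rollout advantage onto every valid response token.
eos_mask = torch.ones(4, 8)                    # (batch_size, response_length)
token_advantages = adv.unsqueeze(-1) * eos_mask
print(adv)  # tensor([ 0.7071, -0.7071,  0.0000,  0.0000])
```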
39 changes: 16 additions & 23 deletions verl/trainer/ray_trainer.py
@@ -16,7 +16,6 @@
This trainer supports model-agnostic model initialization with huggingface
"""

import json
import os
import uuid
from collections import defaultdict
@@ -42,8 +41,9 @@
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.dataset import RLHFDataset, collate_fn
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..utils.tracking import Tracking, ValGenerationsLogger
from ..workers.fsdp_workers import FSDPWorker
from . import core_algos
from .config import PPOConfig
@@ -236,15 +236,15 @@ def __init__(
self.resource_pool_manager = resource_pool_manager
self.use_reward_model = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
self.val_generations_logger = ValGenerationsLogger()

# define KL control
if Role.RefPolicy in role_worker_mapping and config.trainer.report_kl:
if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
self.use_reference_policy = True
self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
else:
self.use_reference_policy = False
self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.")

if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
@@ -260,9 +260,6 @@ def __init__(
if self.use_critic and config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")

if config.algorithm.kl_coef > 1e-8 and not config.trainer.report_kl:
raise ValueError("KL coefficient must be 0 if report_kl is False.")

self._create_dataloader()

def _create_dataloader(self) -> None:
@@ -312,7 +309,9 @@ def _create_dataloader(self) -> None:
)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
batch_size=len(self.val_dataset)
if self.config.data.val_batch_size == -1
else self.config.data.val_batch_size,
shuffle=False,
num_workers=8,
collate_fn=collate_fn,
@@ -321,8 +320,9 @@ def _create_dataloader(self) -> None:
)

assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) == 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")

if self.config.trainer.max_steps is not None:
training_steps = self.config.trainer.max_steps
@@ -348,7 +348,7 @@ def _maybe_log_val_generations(self, inputs: List[str], outputs: List[str], scor
rng.shuffle(samples)

samples = samples[: self.config.trainer.val_generations_to_log]
self.val_generations_logger.log(self.config.trainer.logger, samples, self.global_step)
self.logger.log_generation(samples, self.global_step)

def _validate(self) -> Dict[str, Any]:
reward_tensor_lst = []
@@ -538,14 +538,9 @@ def fit(self):
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=self.config.to_dict(),
)
val_metrics: Optional[Dict[str, Any]] = None
self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
self.global_step = 0
val_metrics: Optional[Dict[str, Any]] = None

# load checkpoint before doing anything
self._load_checkpoint()
@@ -554,8 +549,7 @@ def fit(self):
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.val_before_train:
val_metrics = self._validate()
print(f"Initial validation metrics: {json.dumps(val_metrics, indent=2)}")
logger.log(data=val_metrics, step=self.global_step)
self.logger.log(data=val_metrics, step=self.global_step)
if self.config.trainer.val_only:
return

@@ -699,8 +693,7 @@ def fit(self):
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_step)
self.logger.log(data=metrics, step=self.global_step)

# perform validation after training
if self.val_reward_fn is not None:
@@ -710,9 +703,9 @@ def fit(self):
or self.global_step % self.config.trainer.val_freq != 0
):
val_metrics = self._validate()
logger.log(data=val_metrics, step=self.global_step)
self.logger.log(data=val_metrics, step=self.global_step)

print(f"Final validation metrics: {json.dumps(val_metrics, indent=2)}")
print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")

if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
self._save_checkpoint()
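The refactor replaces the `Tracking`/`ValGenerationsLogger` pair with a single `Tracker` kept on `self.logger`, so scalar metrics and sampled generations go through one object, and `json.dumps` is swapped for `convert_dict_to_str`. The `Tracker` implementation itself is not shown in this diff, so the sketch below only illustrates the interface as inferred from the call sites above:

```python
from verl.utils.logger import Tracker  # re-exported by the new __init__.py below

# Interface inferred from the call sites in fit() and _maybe_log_val_generations();
# the console backend is assumed to need no external service.
tracker = Tracker(loggers=("console",), config={"experiment_name": "demo"})
tracker.log(data={"val/reward_score": 0.42}, step=0)           # scalar metric dicts
tracker.log_generation([("prompt", "response", 1.0)], step=0)  # (input, output, score) samples
```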
6 changes: 6 additions & 0 deletions verl/utils/logger/__init__.py
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .logger import Tracker


__all__ = ["Tracker"]
40 changes: 0 additions & 40 deletions verl/utils/logger/aggregate_logger.py

This file was deleted.

86 changes: 86 additions & 0 deletions verl/utils/logger/gen_logger.py
@@ -0,0 +1,86 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple

from ..py_functional import is_package_available


if is_package_available("wandb"):
import wandb # type: ignore


if is_package_available("swanlab"):
import swanlab # type: ignore


@dataclass
class GenerationLogger(ABC):
@abstractmethod
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...


@dataclass
class WandbGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)

if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)

# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)

# Add new row with all data
row_data = [step]
for sample in samples:
row_data.extend(sample)

new_table.add_data(*row_data)
wandb.log({"val/generations": new_table}, step=step)
self.validation_table = new_table


@dataclass
class SwanlabGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
swanlab_text_list = []
for i, sample in enumerate(samples):
row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))

swanlab.log({"val/generations": swanlab_text_list}, step=step)


@dataclass
class AggregateGenerationsLogger:
def __init__(self, loggers: List[str]):
self.loggers: List[GenerationLogger] = []
if "wandb" in loggers:
self.loggers.append(WandbGenerationLogger())

if "swanlab" in loggers:
self.loggers.append(SwanlabGenerationLogger())

def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for logger in self.loggers:
logger.log(samples, step)
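The new `AggregateGenerationsLogger` fans validation samples out to whichever backends are configured; the wandb variant rebuilds its table on every call because incrementally extending an already-logged `wandb.Table` is not supported (see the linked issue). A small usage sketch, assuming `wandb.init(...)` has already been called by the main tracker and with made-up sample data:

```python
from verl.utils.logger.gen_logger import AggregateGenerationsLogger

gen_logger = AggregateGenerationsLogger(loggers=["wandb"])  # and/or "swanlab"
samples = [
    ("What is 2 + 2?", "4", 1.0),                  # (input, output, score)
    ("Name a prime greater than 10.", "11", 1.0),
]
gen_logger.log(samples, step=10)  # each backend logs under "val/generations"
```

In the trainer itself this presumably happens behind `Tracker.log_generation` rather than by calling the aggregate logger directly.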