diff --git a/examples/config.yaml b/examples/config.yaml
index 1bf9f527..e9827419 100644
--- a/examples/config.yaml
+++ b/examples/config.yaml
@@ -7,6 +7,7 @@ data:
   max_prompt_length: 2048
   max_response_length: 2048
   rollout_batch_size: 512
+  val_batch_size: -1
   shuffle: true
   seed: 1
   max_pixels: 4194304
@@ -14,6 +15,7 @@ data:
 
 algorithm:
   adv_estimator: grpo
+  disable_kl: false
   use_kl_loss: true
   kl_penalty: low_var_kl
   kl_coef: 1.0e-2
@@ -74,7 +76,6 @@ trainer:
   experiment_name: qwen2_5_7b_math_grpo
   n_gpus_per_node: 8
   nnodes: 1
-  report_kl: true
   val_freq: 5  # -1 to disable
   val_before_train: true
   val_only: false
diff --git a/verl/trainer/config.py b/verl/trainer/config.py
index c629bc06..d1530e14 100644
--- a/verl/trainer/config.py
+++ b/verl/trainer/config.py
@@ -41,6 +41,7 @@ class DataConfig:
     max_prompt_length: int = 512
     max_response_length: int = 512
     rollout_batch_size: int = 512
+    val_batch_size: int = -1
     system_prompt: Optional[str] = None
     shuffle: bool = True
     seed: int = 1
@@ -53,6 +54,7 @@ class AlgorithmConfig:
     gamma: float = 1.0
     lam: float = 1.0
     adv_estimator: str = "grpo"
+    disable_kl: bool = False
     use_kl_loss: bool = False
     kl_penalty: str = "kl"
     kl_coef: float = 1e-3
@@ -70,7 +72,6 @@ class TrainerConfig:
     logger: Tuple[str] = ("console", "wandb")
     nnodes: int = 1
     n_gpus_per_node: int = 8
-    report_kl: bool = True
     critic_warmup: int = 0
     val_freq: int = -1
     val_before_train: bool = True
@@ -96,10 +97,10 @@ class PPOConfig:
     def post_init(self):
         self.worker.rollout.prompt_length = self.data.max_prompt_length
         self.worker.rollout.response_length = self.data.max_response_length
+        self.worker.actor.disable_kl = self.algorithm.disable_kl
         self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
         self.worker.actor.kl_penalty = self.algorithm.kl_penalty
         self.worker.actor.kl_coef = self.algorithm.kl_coef
-        self.worker.actor.report_kl = self.trainer.report_kl
 
     def deep_post_init(self):
         recursive_post_init(self)
diff --git a/verl/trainer/core_algos.py b/verl/trainer/core_algos.py
index fccef62d..9dfa5eb9 100644
--- a/verl/trainer/core_algos.py
+++ b/verl/trainer/core_algos.py
@@ -127,7 +127,7 @@ def compute_gae_advantage_return(
 # NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
 @torch.no_grad()
 def compute_grpo_outcome_advantage(
-    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
+    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Compute advantage for GRPO, operating only on Outcome reward
@@ -164,7 +164,7 @@ def compute_grpo_outcome_advantage(
         raise ValueError(f"no score in prompt index: {idx}")
 
     for i in range(bsz):
-        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
 
     scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
     return scores, scores
@@ -172,7 +172,7 @@ def compute_grpo_outcome_advantage(
 
 @torch.no_grad()
 def compute_rloo_outcome_advantage(
-    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
+    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
diff --git a/verl/trainer/ray_trainer.py b/verl/trainer/ray_trainer.py
index 94274bba..59942784 100644
--- a/verl/trainer/ray_trainer.py
+++ b/verl/trainer/ray_trainer.py
@@ -16,7 +16,6 @@
 This trainer supports model-agnostic model initialization with huggingface
 """
 
-import json
 import os
 import uuid
 from collections import defaultdict
@@ -42,8 +41,9 @@
 from ..utils import torch_functional as VF
 from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
 from ..utils.dataset import RLHFDataset, collate_fn
+from ..utils.logger import Tracker
+from ..utils.py_functional import convert_dict_to_str
 from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
-from ..utils.tracking import Tracking, ValGenerationsLogger
 from ..workers.fsdp_workers import FSDPWorker
 from . import core_algos
 from .config import PPOConfig
@@ -236,15 +236,15 @@ def __init__(
         self.resource_pool_manager = resource_pool_manager
         self.use_reward_model = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
-        self.val_generations_logger = ValGenerationsLogger()
 
         # define KL control
-        if Role.RefPolicy in role_worker_mapping and config.trainer.report_kl:
+        if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
            self.use_reference_policy = True
            self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
         else:
            self.use_reference_policy = False
            self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
+            print("KL is disabled, so no KL metrics will be logged. To log KL metrics without a KL penalty, set `disable_kl=false` and `kl_coef=0`.")
 
         if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
             self.use_critic = True
@@ -260,9 +260,6 @@ def __init__(
         if self.use_critic and config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
             raise ValueError("Rollout batch size must be divisible by global batch size.")
 
-        if config.algorithm.kl_coef > 1e-8 and not config.trainer.report_kl:
-            raise ValueError("KL coefficient must be 0 if report_kl is False.")
-
         self._create_dataloader()
 
     def _create_dataloader(self) -> None:
@@ -312,7 +309,9 @@ def _create_dataloader(self) -> None:
         )
         self.val_dataloader = StatefulDataLoader(
             dataset=self.val_dataset,
-            batch_size=len(self.val_dataset),
+            batch_size=len(self.val_dataset)
+            if self.config.data.val_batch_size == -1
+            else self.config.data.val_batch_size,
             shuffle=False,
             num_workers=8,
             collate_fn=collate_fn,
@@ -321,8 +320,9 @@ def _create_dataloader(self) -> None:
         )
 
         assert len(self.train_dataloader) >= 1
-        assert len(self.val_dataloader) == 1
+        assert len(self.val_dataloader) >= 1
         print(f"Size of train dataloader: {len(self.train_dataloader)}")
+        print(f"Size of val dataloader: {len(self.val_dataloader)}")
 
         if self.config.trainer.max_steps is not None:
             training_steps = self.config.trainer.max_steps
@@ -348,7 +348,7 @@ def _maybe_log_val_generations(self, inputs: List[str], outputs: List[str], scor
         rng.shuffle(samples)
         samples = samples[: self.config.trainer.val_generations_to_log]
 
-        self.val_generations_logger.log(self.config.trainer.logger, samples, self.global_step)
+        self.logger.log_generation(samples, self.global_step)
 
     def _validate(self) -> Dict[str, Any]:
         reward_tensor_lst = []
@@ -538,14 +538,9 @@ def fit(self):
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
        """
-        logger = Tracking(
-            project_name=self.config.trainer.project_name,
-            experiment_name=self.config.trainer.experiment_name,
-            default_backend=self.config.trainer.logger,
-            config=self.config.to_dict(),
-        )
-        val_metrics: Optional[Dict[str, Any]] = None
+        self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
         self.global_step = 0
+        val_metrics: Optional[Dict[str, Any]] = None
 
         # load checkpoint before doing anything
         self._load_checkpoint()
@@ -554,8 +549,7 @@ def fit(self):
         # currently, we only support validation using the reward_function.
         if self.val_reward_fn is not None and self.config.trainer.val_before_train:
             val_metrics = self._validate()
-            print(f"Initial validation metrics: {json.dumps(val_metrics, indent=2)}")
-            logger.log(data=val_metrics, step=self.global_step)
+            self.logger.log(data=val_metrics, step=self.global_step)
             if self.config.trainer.val_only:
                 return
@@ -699,8 +693,7 @@ def fit(self):
                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                 metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
 
-                # TODO: make a canonical logger that supports various backend
-                logger.log(data=metrics, step=self.global_step)
+                self.logger.log(data=metrics, step=self.global_step)
 
         # perform validation after training
         if self.val_reward_fn is not None:
@@ -710,9 +703,9 @@ def fit(self):
                 or self.global_step % self.config.trainer.val_freq != 0
             ):
                 val_metrics = self._validate()
-                logger.log(data=val_metrics, step=self.global_step)
+                self.logger.log(data=val_metrics, step=self.global_step)
 
-            print(f"Final validation metrics: {json.dumps(val_metrics, indent=2)}")
+            print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")
 
         if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
             self._save_checkpoint()
diff --git a/verl/utils/logger/__init__.py b/verl/utils/logger/__init__.py
index 1ce90c5e..557c4775 100644
--- a/verl/utils/logger/__init__.py
+++ b/verl/utils/logger/__init__.py
@@ -11,3 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+from .logger import Tracker
+
+
+__all__ = ["Tracker"]
diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py
deleted file mode 100644
index 3bbf733e..00000000
--- a/verl/utils/logger/aggregate_logger.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-A Ray logger will receive logging info from different processes.
-"""
-
-import numbers
-from typing import Any, Dict
-
-
-def concat_dict_to_str(dict: Dict[str, Any], step: int) -> str:
-    output = [f"step {step}:"]
-    for k, v in dict.items():
-        if isinstance(v, numbers.Number):
-            output.append(f"{k}:{v:.3f}")
-
-    output_str = " - ".join(output)
-    return output_str
-
-
-class LocalLogger:
-    def __init__(self):
-        pass
-
-    def flush(self):
-        pass
-
-    def log(self, data: Dict[str, Any], step: int) -> None:
-        print(concat_dict_to_str(data, step=step), flush=True)
diff --git a/verl/utils/logger/gen_logger.py b/verl/utils/logger/gen_logger.py
new file mode 100644
index 00000000..880b73f6
--- /dev/null
+++ b/verl/utils/logger/gen_logger.py
@@ -0,0 +1,86 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from ..py_functional import is_package_available
+
+
+if is_package_available("wandb"):
+    import wandb  # type: ignore
+
+
+if is_package_available("swanlab"):
+    import swanlab  # type: ignore
+
+
+@dataclass
+class GenerationLogger(ABC):
+    @abstractmethod
+    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...
+
+
+@dataclass
+class WandbGenerationLogger(GenerationLogger):
+    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
+        # Create column names for all samples
+        columns = ["step"] + sum(
+            [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
+        )
+
+        if not hasattr(self, "validation_table"):
+            # Initialize the table on first call
+            self.validation_table = wandb.Table(columns=columns)
+
+        # Create a new table with same columns and existing data
+        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
+        new_table = wandb.Table(columns=columns, data=self.validation_table.data)
+
+        # Add new row with all data
+        row_data = [step]
+        for sample in samples:
+            row_data.extend(sample)
+
+        new_table.add_data(*row_data)
+        wandb.log({"val/generations": new_table}, step=step)
+        self.validation_table = new_table
+
+
+@dataclass
+class SwanlabGenerationLogger(GenerationLogger):
+    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
+        swanlab_text_list = []
+        for i, sample in enumerate(samples):
+            row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
+            swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
+
+        swanlab.log({"val/generations": swanlab_text_list}, step=step)
+
+
+@dataclass
+class AggregateGenerationsLogger:
+    def __init__(self, loggers: List[str]):
+        self.loggers: List[GenerationLogger] = []
+        if "wandb" in loggers:
+            self.loggers.append(WandbGenerationLogger())
+
+        if "swanlab" in loggers:
+            self.loggers.append(SwanlabGenerationLogger())
+
+    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
+        for logger in self.loggers:
+            logger.log(samples, step)
diff --git a/verl/utils/logger/logger.py b/verl/utils/logger/logger.py
new file mode 100644
index 00000000..a957e0fe
--- /dev/null
+++ b/verl/utils/logger/logger.py
@@ -0,0 +1,154 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +A unified tracking interface that supports logging data to different backend +""" + +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union + +from torch.utils.tensorboard import SummaryWriter + +from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict +from .gen_logger import AggregateGenerationsLogger + + +if is_package_available("mlflow"): + import mlflow # type: ignore + + +if is_package_available("wandb"): + import wandb # type: ignore + + +if is_package_available("swanlab"): + import swanlab # type: ignore + + +class Logger(ABC): + @abstractmethod + def __init__(self, config: Dict[str, Any]) -> None: ... + + @abstractmethod + def log(self, data: Dict[str, Any], step: int) -> None: ... + + def finish(self) -> None: + pass + + +class ConsoleLogger(Logger): + def __init__(self, config: Dict[str, Any]) -> None: + print("Config\n" + convert_dict_to_str(config)) + + def log(self, data: Dict[str, Any], step: int) -> None: + print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data))) + + +class MlflowLogger(Logger): + def __init__(self, config: Dict[str, Any]) -> None: + mlflow.start_run(run_name=config["trainer"]["experiment_name"]) + mlflow.log_params(flatten_dict(config)) + + def log(self, data: Dict[str, Any], step: int) -> None: + mlflow.log_metrics(metrics=data, step=step) + + +class TensorBoardLogger(Logger): + def __init__(self, config: Dict[str, Any]) -> None: + tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log") + os.makedirs(tensorboard_dir, exist_ok=True) + print(f"Saving tensorboard log to {tensorboard_dir}.") + self.writer = SummaryWriter(tensorboard_dir) + self.writer.add_hparams(flatten_dict(config)) + + def log(self, data: Dict[str, Any], step: int) -> None: + for key, value in data.items(): + self.writer.add_scalar(key, value, step) + + def finish(self): + self.writer.close() + + +class WandbLogger(Logger): + def __init__(self, config: Dict[str, Any]) -> None: + wandb.init( + project=config["trainer"]["project_name"], + name=config["trainer"]["experiment_name"], + config=config, + ) + + def log(self, data: Dict[str, Any], step: int) -> None: + wandb.log(data=data, step=step) + + def finish(self) -> None: + wandb.finish() + + +class SwanlabLogger(Logger): + def __init__(self, config: Dict[str, Any]) -> None: + swanlab_key = os.getenv("SWANLAB_API_KEY") + swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log") + swanlab_mode = os.getenv("SWANLAB_MODE", "cloud") + if swanlab_key: + swanlab.login(swanlab_key) + + swanlab.init( + project=config["trainer"]["project_name"], + experiment_name=config["trainer"]["experiment_name"], + config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config}, + logdir=swanlab_dir, + mode=swanlab_mode, + ) + + def log(self, data: Dict[str, Any], step: int) -> None: + swanlab.log(data=data, step=step) + + def finish(self) -> None: + swanlab.finish() + + +LOGGERS = { + "wandb": WandbLogger, + "mlflow": MlflowLogger, + "tensorboard": TensorBoardLogger, + "console": ConsoleLogger, + "swanlab": SwanlabLogger, +} + + +class Tracker: + def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None): + if isinstance(loggers, str): + loggers = [loggers] + + self.loggers: List[Logger] = [] + for logger in loggers: + if logger not in LOGGERS: + raise ValueError(f"{logger} is not supported.") + + self.loggers.append(LOGGERS[logger](config)) + + self.gen_logger = AggregateGenerationsLogger(loggers) 
+
+    def log(self, data: Dict[str, Any], step: int) -> None:
+        for logger in self.loggers:
+            logger.log(data=data, step=step)
+
+    def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
+        self.gen_logger.log(samples, step)
+
+    def __del__(self):
+        for logger in self.loggers:
+            logger.finish()
diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py
index 56b53da4..ce140f9a 100644
--- a/verl/utils/py_functional.py
+++ b/verl/utils/py_functional.py
@@ -15,14 +15,34 @@
 Contains small python utility functions
 """
 
-from typing import Any, Dict, List
+import importlib.util
+from functools import lru_cache
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import yaml
+from yaml import Dumper
+
+
+def numpy_representer(dumper: Dumper, value: Union[np.float32, np.float64]):
+    value = str(round(value, 3))
+    return dumper.represent_scalar("tag:yaml.org,2002:float", value)
+
+
+yaml.add_representer(np.float32, numpy_representer)
+yaml.add_representer(np.float64, numpy_representer)
+
+
+@lru_cache
+def is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None
 
 
 def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
    """Union two dicts. Raises an error if the two dicts hold different values for the same key."""
     for key in dict2.keys():
         if key in dict1:
-            assert dict1[key] == dict2[key], f"{key} in meta_dict1 and meta_dict2 are not the same object"
+            assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"
 
         dict1[key] = dict2[key]
 
@@ -30,8 +50,41 @@ def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, An
 def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
+    """Append each value in a dict to the list stored under the same key."""
     for key, val in new_data.items():
         if key not in data:
             data[key] = []
 
         data[key].append(val)
+
+
+def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
+    unflattened = {}
+    for key, value in data.items():
+        pieces = key.split(sep)
+        pointer = unflattened
+        for piece in pieces[:-1]:
+            if piece not in pointer:
+                pointer[piece] = {}
+
+            pointer = pointer[piece]
+
+        pointer[pieces[-1]] = value
+
+    return unflattened
+
+
+def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
+    flattened = {}
+    for key, value in data.items():
+        new_key = parent_key + sep + key if parent_key else key
+        if isinstance(value, dict):
+            flattened.update(flatten_dict(value, new_key, sep=sep))
+        else:
+            flattened[new_key] = value
+
+    return flattened
+
+
+def convert_dict_to_str(data: Dict[str, Any]) -> str:
+    return yaml.dump(data, indent=2)
diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index 48c1a7f1..ad684cb7 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -66,9 +66,9 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return output.view(*batch_dim)
 
 
-def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = None, eps: float = 1e-8) -> torch.Tensor:
     """Compute the mean of a tensor over masked values."""
-    return (values * mask).sum(dim=dim) / mask.sum(dim=dim)
+    return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + eps)
 
 
 def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
@@ -79,7 +79,8 @@ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True)
     if unbiased:
         mask_sum = mask.sum()
         if mask_sum <= 1:
-            raise ValueError("The sum of the mask is less than one, which can cause a division by zero.")
+            print("The sum of the mask is at most one, which would cause a division by zero in the Bessel correction.")
+            return variance
 
         bessel_correction = mask_sum / (mask_sum - 1)
         variance = variance * bessel_correction
@@ -87,10 +88,10 @@ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True)
     return variance
 
 
-def masked_whiten(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
     """Whiten values with masked values."""
     mean, var = masked_mean(values, mask), masked_var(values, mask)
-    return (values - mean) * torch.rsqrt(var + 1e-8)
+    return (values - mean) * torch.rsqrt(var + eps)
 
 
 def get_eos_mask(response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long):
diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py
deleted file mode 100644
index d69bb964..00000000
--- a/verl/utils/tracking.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-A unified tracking interface that supports logging data to different backend
-"""
-
-import os
-from dataclasses import dataclass
-from typing import List, Tuple, Union
-
-from .logger.aggregate_logger import LocalLogger
-
-
-class Tracking:
-    supported_backend = ["wandb", "mlflow", "swanlab", "console"]
-
-    def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None):
-        if isinstance(default_backend, str):
-            default_backend = [default_backend]
-
-        for backend in default_backend:
-            assert backend in self.supported_backend, f"{backend} is not supported"
-
-        self.logger = {}
-
-        if "wandb" in default_backend:
-            import wandb  # type: ignore
-
-            wandb.init(project=project_name, name=experiment_name, config=config)
-            self.logger["wandb"] = wandb
-
-        if "mlflow" in default_backend:
-            import mlflow  # type: ignore
-
-            mlflow.start_run(run_name=experiment_name)
-            mlflow.log_params(config)
-            self.logger["mlflow"] = _MlflowLoggingAdapter()
-
-        if "swanlab" in default_backend:
-            import swanlab  # type: ignore
-
-            SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
-            SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
-            SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
-            if SWANLAB_API_KEY:
-                swanlab.login(SWANLAB_API_KEY)  # NOTE: previous login information will be overwritten
-
-            swanlab.init(
-                project=project_name,
-                experiment_name=experiment_name,
-                config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
-                logdir=SWANLAB_LOG_DIR,
-                mode=SWANLAB_MODE,
-            )
-            self.logger["swanlab"] = swanlab
-
-        if "console" in default_backend:
-            self.console_logger = LocalLogger()
-            self.logger["console"] = self.console_logger
-
-    def log(self, data, step, backend=None):
-        for default_backend, logger_instance in self.logger.items():
-            if backend is None or default_backend in backend:
-                logger_instance.log(data=data, step=step)
-
-    def __del__(self):
-        if "wandb" in self.logger:
-            self.logger["wandb"].finish(exit_code=0)
-
-        if "swanlab" in self.logger:
-            self.logger["swanlab"].finish()
-
-
-class _MlflowLoggingAdapter:
-    def log(self, data, step):
-        import mlflow  # type: ignore
-
-        mlflow.log_metrics(metrics=data, step=step)
-
-
-@dataclass
-class ValGenerationsLogger:
-    def log(self, loggers: List[str], samples: List[Tuple[str, str, float]], step: int):
-        if "wandb" in loggers:
-            self.log_generations_to_wandb(samples, step)
-        if "swanlab" in loggers:
-            self.log_generations_to_swanlab(samples, step)
-
-    def log_generations_to_wandb(self, samples: List[Tuple[str, str, float]], step: int) -> None:
-        """Log samples to wandb as a table"""
-        import wandb  # type: ignore
-
-        # Create column names for all samples
-        columns = ["step"] + sum(
-            [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
-        )
-
-        if not hasattr(self, "validation_table"):
-            # Initialize the table on first call
-            self.validation_table = wandb.Table(columns=columns)
-
-        # Create a new table with same columns and existing data
-        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
-        new_table = wandb.Table(columns=columns, data=self.validation_table.data)
-
-        # Add new row with all data
-        row_data = []
-        row_data.append(step)
-        for sample in samples:
-            row_data.extend(sample)
-
-        new_table.add_data(*row_data)
-
-        # Update reference and log
-        wandb.log({"val/generations": new_table}, step=step)
-        self.validation_table = new_table
-
-    def log_generations_to_swanlab(self, samples: List[Tuple[str, str, float]], step: int) -> None:
-        """Log samples to swanlab as text"""
-        import swanlab  # type: ignore
-
-        swanlab_text_list = []
-        for i, sample in enumerate(samples):
-            row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
-            swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
-
-        # Log to swanlab
-        swanlab.log({"val/generations": swanlab_text_list}, step=step)
diff --git a/verl/workers/actor/config.py b/verl/workers/actor/config.py
index 7785a100..0e2a3456 100644
--- a/verl/workers/actor/config.py
+++ b/verl/workers/actor/config.py
@@ -82,7 +82,7 @@ class ActorConfig:
     offload: OffloadConfig = field(default_factory=OffloadConfig)
     """auto keys"""
     global_batch_size_per_device: int = field(default=-1, init=False)
-    report_kl: bool = field(default=False, init=False)
+    disable_kl: bool = field(default=False, init=False)
     use_kl_loss: bool = field(default=False, init=False)
     kl_penalty: str = field(default="kl", init=False)
     kl_coef: float = field(default=0.0, init=False)
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 55a1e2e0..1c754eb0 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -210,7 +210,7 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]:
         temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
-        if self.config.use_kl_loss and self.config.report_kl:
+        if self.config.use_kl_loss and not self.config.disable_kl:
             select_keys.append("ref_log_probs")
 
         if "multi_modal_inputs" in data.non_tensor_batch.keys():
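
# ---------------------------------------------------------------------------
# Reviewer note, not part of the patch: a minimal smoke test of the helpers
# this diff introduces. It assumes only the module paths shown above; the
# config dict below is a stand-in containing just the keys the console
# logger reads.
import torch

from verl.utils.logger import Tracker
from verl.utils.py_functional import flatten_dict, unflatten_dict
from verl.utils.torch_functional import masked_mean

# flatten_dict and unflatten_dict are inverses: metrics flatten to
# slash-separated keys for the metric backends and re-nest for console output.
metrics = {"actor": {"loss": 0.12, "kl": 0.003}, "critic": {"loss": 0.45}}
flat = flatten_dict(metrics)  # {"actor/loss": 0.12, "actor/kl": 0.003, "critic/loss": 0.45}
assert unflatten_dict(flat) == metrics

# The new eps term keeps masked_mean finite on an all-zero mask:
# 0 / (0 + 1e-8) = 0.0 instead of NaN from 0 / 0.
print(masked_mean(torch.tensor([1.0, 2.0, 3.0]), torch.zeros(3)))  # tensor(0.)

# A console-only Tracker needs no external backend: it pretty-prints the
# config once at construction, then each metrics dict (re-nested) as YAML.
tracker = Tracker(loggers="console", config={"trainer": {"project_name": "demo", "experiment_name": "smoke"}})
tracker.log(data=flat, step=0)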