From 5e7f5e774d91c90cb2ee4f2b71ca523946a47402 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 16:49:36 -0700 Subject: [PATCH 01/31] adding doc string and docs for single controller Signed-off-by: Hongpeng Guo --- docs/{ => api}/data.rst | 0 docs/api/single_controller.rst | 36 ++++++++++++++++++++++++++++++++++ docs/index.rst | 3 ++- 3 files changed, 38 insertions(+), 1 deletion(-) rename docs/{ => api}/data.rst (100%) create mode 100644 docs/api/single_controller.rst diff --git a/docs/data.rst b/docs/api/data.rst similarity index 100% rename from docs/data.rst rename to docs/api/data.rst diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst new file mode 100644 index 00000000000..37f93abd922 --- /dev/null +++ b/docs/api/single_controller.rst @@ -0,0 +1,36 @@ +Single Controller interface +============================ + +The Single Controller provides a unified interface for managing distributed workers +using Ray or other backends and executing functions across them. +It simplifies the process of dispatching tasks and collecting results, particularly +when dealing with data parallelism or model parallelism. + + +Core APIs +~~~~~~~~~~~~~~~~~ + +.. autoclass:: verl.single_controller.Worker + :members: __init__, __new__, get_fused_worker_by_name, get_master_addr_port, get_cuda_visible_devices, world_size, rank, execute_with_func_generator, execute_func_rank_zero + +.. autoclass:: verl.single_controller.WorkerGroup + :members: __init__, start_worker_aliveness_check, world_size + +.. autoclass:: verl.single_controller.ClassWithInitArgs + :members: __init__, __call__ + +.. autoclass:: verl.single_controller.ResourcePool + :members: __init__, add_node, world_size, __call__, store, local_world_size_list, local_rank_list + + +Decorator APIs +~~~~~~~~~~~~~~~~~ +.. autofunction:: verl.single_controller.base.decorator.register + +..
autoclass:: verl.single_controller.base.decorator.Dispatch + :members: RANK_ZERO, ONE_TO_ALL, ALL_TO_ALL, MEGATRON_COMPUTE, MEGATRON_PP_AS_DP, MEGATRON_PP_ONLY, MEGATRON_COMPUTE_PROTO, MEGATRON_PP_AS_DP_PROTO, DP_COMPUTE, DP_COMPUTE_PROTO, DP_COMPUTE_PROTO_WITH_FUNC, DP_COMPUTE_METRIC, DIRECT_ROLLOUT_METHOD + :member-order: bysource + +.. autoclass:: verl.single_controller.base.decorator.Execute + :members: ALL, RANK_ZERO + :member-order: bysource \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 4f5da4bd854..679c4bfdf1e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -97,7 +97,8 @@ verl is fast with: :maxdepth: 1 :caption: API References - data.rst + api/data.rst + api/single_controller.rst .. toctree:: From 968c800aa1c9a89f41469a9884d4c267e41a0d8d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 16:50:24 -0700 Subject: [PATCH 02/31] add doc string for public methods for single controller Signed-off-by: Hongpeng Guo --- verl/single_controller/base/decorator.py | 66 ++++++++++++----- verl/single_controller/base/worker.py | 59 ++++++++++++++- verl/single_controller/base/worker_group.py | 80 ++++++++++++++++----- 3 files changed, 169 insertions(+), 36 deletions(-) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index c4c778d6b46..59de926024c 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -25,26 +25,38 @@ class Dispatch(Enum): - RANK_ZERO = 0 - ONE_TO_ALL = 1 - ALL_TO_ALL = 2 - MEGATRON_COMPUTE = 3 - MEGATRON_PP_AS_DP = 4 - MEGATRON_PP_ONLY = 5 - MEGATRON_COMPUTE_PROTO = 6 - MEGATRON_PP_AS_DP_PROTO = 7 - DP_COMPUTE = 8 - DP_COMPUTE_PROTO = 9 - DP_COMPUTE_PROTO_WITH_FUNC = 10 - DP_COMPUTE_METRIC = 11 - - # This is a special dispatch mode for vllm ExternalRayDistributedExecutor - DIRECT_ROLLOUT_METHOD = 12 + """Enum class defining different dispatch modes for distributed computation. 
+ + Each mode represents a specific strategy for distributing data across + different ranks in a distributed system. The modes are used to control + how data is partitioned and processed across different worker groups. + """ + + RANK_ZERO = 0 #: Only on rank 0 + ONE_TO_ALL = 1 #: Broadcast data from one rank to all others + ALL_TO_ALL = 2 #: Distribute data across all ranks + MEGATRON_COMPUTE = 3 #: Megatron-style computation with tensor/pipeline parallelism + MEGATRON_PP_AS_DP = 4 #: Megatron PP (pipeline parallelism) as DP (data parallelism) + MEGATRON_PP_ONLY = 5 #: Only use pipeline parallelism + MEGATRON_COMPUTE_PROTO = 6 #: Megatron-style computation with tensor/pipeline parallelism and DataProto support + MEGATRON_PP_AS_DP_PROTO = 7 #: Megatron PP as DP with DataProto support + DP_COMPUTE = 8 #: Data parallelism computation + DP_COMPUTE_PROTO = 9 #: Data parallelism with DataProto support + DP_COMPUTE_PROTO_WITH_FUNC = 10 #: Data parallelism with DataProto, supporting one function argument + DP_COMPUTE_METRIC = 11 #: Data parallelism with metric computation + + DIRECT_ROLLOUT_METHOD = 12 #: Special mode for vllm ExternalRayDistributedExecutor class Execute(Enum): - ALL = 0 - RANK_ZERO = 1 + """Enum class defining different execution modes for distributed computation. + + These modes control how a function should be executed across different ranks + in a distributed system. + """ + + ALL = 0 #: Execute the function on all ranks + RANK_ZERO = 1 #: Execute the function only on rank 0 def _split_args_kwargs_data_proto(chunks, *args, **kwargs): @@ -454,6 +466,26 @@ def _materialize_futures(*args, **kwargs): def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): + """Register a function with distributed execution configuration. + + This decorator registers a function with specific dispatch and execution modes + for distributed computation.
It handles both synchronous and asynchronous + functions, and optionally materializes futures before execution. + + Args: + dispatch_mode: + Dispatch mode for computation distribution. Default: Dispatch.ALL_TO_ALL. + execute_mode: + Execute mode for computation distribution. Default: Execute.ALL. + blocking: + Whether the execution should be blocking. Defaults to True. + materialize_futures: + Whether to materialize the data before dispatching. Defaults to True. + + Returns: + A decorator that wraps the original function with distributed execution + configuration. + """ _check_dispatch_mode(dispatch_mode=dispatch_mode) _check_execute_mode(execute_mode=execute_mode) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index ce32cd134b1..b55c4ecf683 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -90,11 +90,17 @@ def to_dict(self): # we assume that in each WorkerGroup, there is a Master Worker class Worker(WorkerHelper): - """A (distributed) worker.""" + """A distributed worker that handles initialization and configuration for distributed training. + + This class manages worker initialization, configuration, and provides methods for executing + distributed operations. It handles communication settings, device configuration, and worker + metadata management. + """ fused_worker_attr_name = "fused_worker_dict" def __new__(cls, *args, **kwargs): + """Create a new Worker instance with proper initialization based on environment settings.""" instance = super().__new__(cls) # note that here we use int to distinguish @@ -112,6 +118,14 @@ def __new__(cls, *args, **kwargs): return instance def _configure_before_init(self, register_center_name: str, rank: int): + """Configure worker settings before initialization. 
+ + Args: + register_center_name (str): + Name of the register center Ray actor for worker coordination + rank (int): + Rank of the worker in the distributed setup + """ assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" if rank == 0: @@ -134,6 +148,12 @@ def _configure_before_init(self, register_center_name: str, rank: int): ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) def __init__(self, cuda_visible_devices=None) -> None: + """Initialize the worker with environment settings and device configuration. + + Args: + cuda_visible_devices (str, optional): + CUDA visible devices configuration. Defaults to None. + """ # construct a meta from environment variable. Note that the import must be inside the class because it is executed remotely import os @@ -187,11 +207,20 @@ def __init__(self, cuda_visible_devices=None) -> None: self.fused_worker_dict = {} def get_fused_worker_by_name(self, worker_name: str): + """Get a fused worker by its name. + + Args: + worker_name (str): + Name of the worker to retrieve + """ return self.fused_worker_dict.get(worker_name, None) def _configure_with_meta(self, meta: WorkerMeta): - """ - This function should only be called inside by WorkerGroup + """Configure worker settings using WorkerMeta. 
+ + Args: + meta (`WorkerMeta`): + Metadata containing worker configuration """ assert isinstance(meta, WorkerMeta) self.__dict__.update(meta.to_dict()) # this is hacky @@ -204,9 +233,11 @@ def _configure_with_meta(self, meta: WorkerMeta): os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" def get_master_addr_port(self): + """Get the master address and port for distributed communication.""" return self._master_addr, self._master_port def get_cuda_visible_devices(self): + """Get the CUDA visible devices configuration.""" import os cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") @@ -214,18 +245,40 @@ def get_cuda_visible_devices(self): @property def world_size(self): + """Get the total number of workers in the distributed setup.""" return self._world_size @property def rank(self): + """Get the rank of this worker in the distributed setup.""" return self._rank @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) def execute_with_func_generator(self, func, *args, **kwargs): + """Execute a function with function generator dispatch mode. + + Args: + func: + Function to execute + *args: + Positional arguments for the function + **kwargs: + Keyword arguments for the function + """ ret_proto = func(self, *args, **kwargs) return ret_proto @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) def execute_func_rank_zero(self, func, *args, **kwargs): + """Execute a function in rank zero execution mode. 
+ + Args: + func: + Function to execute + *args: + Positional arguments for the function + **kwargs: + Keyword arguments for the function + """ result = func(*args, **kwargs) return result diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py index d7761c40613..33bb736192b 100644 --- a/verl/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -25,9 +25,20 @@ class ResourcePool: - """The resource pool with meta info such as world_size.""" + """ + Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations. + The class provides methods to calculate world size, local world sizes, and local ranks + across all nodes in the pool. + """ def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None: + """Initialize the ResourcePool with node processes and GPU configuration. + + Args: + process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list. + max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10. + n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8. 
+ """ if process_on_nodes is None: process_on_nodes = [] self._store = process_on_nodes @@ -35,52 +46,74 @@ def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_p self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node def add_node(self, process_count): + """Add a new node to the resource pool with the specified process count.""" self._store.append(process_count) @property def world_size(self): + """Total number of processes across all nodes in the pool.""" return sum(self._store) def __call__(self) -> Any: + """Get the list of processes count per node of the resource pool.""" return self._store @property def store(self): + """The store of the process count per node.""" return self._store def local_world_size_list(self) -> List[int]: + """Returns a flat list where each process has its local world size.""" nested_local_world_size_list = [[local_world_size for _ in range(local_world_size)] for local_world_size in self._store] return [item for row in nested_local_world_size_list for item in row] def local_rank_list(self) -> List[int]: + """Returns a flat list of local ranks for all processes across all nodes.""" nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] return [item for row in nested_local_rank_list for item in row] class ClassWithInitArgs: """ - This class stores a class constructor and the args/kwargs to construct the class. - It is used to instantiate the remote class. + Wrapper class that stores constructor arguments for deferred instantiation. + This class is particularly useful for remote class instantiation where + the actual construction needs to happen at a different time or location. """ def __init__(self, cls, *args, **kwargs) -> None: + """Initialize the ClassWithInitArgs instance. 
+ + Args: + cls: + The class to be instantiated later + *args: + Positional arguments for the class constructor + **kwargs: + Keyword arguments for the class constructor + """ self.cls = cls self.args = args self.kwargs = kwargs self.fused_worker_used = False - # def add_arg(self, arg): - # self.args += (arg,) - - # def add_kwarg(self, key, value): - # self.kwargs[key] = value - def __call__(self) -> Any: + """Instantiate the stored class with the stored arguments.""" return self.cls(*self.args, **self.kwargs) def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: + """Continuously monitors worker processes and raises SIGABRT if any worker dies. + + Args: + workers (List): + List of worker objects to monitor + is_alive (Callable): + Function to check if a worker is alive + gap_time (float): + Time interval between checks + """ import time while True: @@ -92,7 +125,10 @@ def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) class WorkerGroup: - """A group of workers""" + """ + Base class for managing a group of workers in a distributed system. + The class provides methods for worker management, aliveness checking, and method binding. + """ fused_worker_execute_fn_name = "_fuw_execute" @@ -116,9 +152,11 @@ def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: self._checker_thread: threading.Thread = None def _is_worker_alive(self, worker): + """Check if a worker is alive. 
Must be implemented by derived classes.""" raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") def _block_until_all_workers_alive(self) -> None: + """Blocks until all workers in the group are alive.""" while True: all_state = [self._is_worker_alive(worker) for worker in self._workers] if False in all_state: @@ -127,6 +165,11 @@ def _block_until_all_workers_alive(self) -> None: break def start_worker_aliveness_check(self, every_n_seconds=1) -> None: + """Starts a background thread to monitor worker aliveness. + + Args: + every_n_seconds (int): Interval between aliveness checks + """ # before starting checking worker aliveness, make sure all workers are already alive self._block_until_all_workers_alive() @@ -135,16 +178,21 @@ def start_worker_aliveness_check(self, every_n_seconds=1) -> None: @property def world_size(self): + """Number of workers in the group.""" return len(self._workers) - # execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup, - # MegatronWorkerGroup, XperfWorkerGroup should skip - def _bind_worker_method(self, user_defined_cls, func_generator): - """ - Bind the worker method to the WorkerGroup - """ + """Binds worker methods to the WorkerGroup based on registered attributes. 
+ Args: + user_defined_cls (type): + The class containing methods to bind + func_generator (Callable): + Function that generates the bound method + + Returns: + List[str]: List of method names that were successfully bound + """ method_names = [] for method_name in dir(user_defined_cls): try: From 2413c8061ca65ae1d8ded64b6ac2270bdff4649e Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 16:51:21 -0700 Subject: [PATCH 03/31] fix lint Signed-off-by: Hongpeng Guo --- verl/single_controller/base/worker_group.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py index 33bb736192b..9275abce855 100644 --- a/verl/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -35,9 +35,12 @@ def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_p """Initialize the ResourcePool with node processes and GPU configuration. Args: - process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list. - max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10. - n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8. + process_on_nodes (List[int], optional): + List of process counts per node. Defaults to empty list. + max_colocate_count (int, optional): + Maximum number of processes that can be colocated. Defaults to 10. + n_gpus_per_node (int, optional): + Number of GPUs available per node. Defaults to 8. 
""" if process_on_nodes is None: process_on_nodes = [] From f823f29f10b986aebefd42aaa8eb885df1870667 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 18:54:50 -0700 Subject: [PATCH 04/31] add docs for trainer and utils Signed-off-by: Hongpeng Guo --- docs/api/trainer.rst | 28 ++++++++++ docs/api/utils.rst | 82 ++++++++++++++++++++++++++++ docs/index.rst | 2 + tests/sandbox/test_sandbox.py | 8 +-- verl/trainer/ppo/core_algos.py | 27 ++++----- verl/trainer/ppo/ray_trainer.py | 56 +++++++++++++++++++ verl/utils/fs.py | 29 +++++++++- verl/utils/py_functional.py | 27 +++++++++ verl/utils/reward_score/__init__.py | 20 ++++++- verl/workers/reward_manager/dapo.py | 4 +- verl/workers/reward_manager/naive.py | 4 +- verl/workers/reward_manager/prime.py | 4 +- 12 files changed, 262 insertions(+), 29 deletions(-) create mode 100644 docs/api/trainer.rst create mode 100644 docs/api/utils.rst diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst new file mode 100644 index 00000000000..d5174f6a887 --- /dev/null +++ b/docs/api/trainer.rst @@ -0,0 +1,28 @@ +PPO Trainer Interface +=========== + +This section documents the trainer utilities in the VERL library. + +PPO Ray Trainer +------------ + +.. automodule:: verl.trainer.ppo.ray_trainer + :members: + +PPO Metrics +------------ + +.. automodule:: verl.trainer.ppo.metric_utils + :members: + +PPO Core Algorithms +-------------------- + +.. automodule:: verl.trainer.ppo.core_algos + :members: + +PPO Reward +----------- + +.. automodule:: verl.trainer.ppo.reward + :members: \ No newline at end of file diff --git a/docs/api/utils.rst b/docs/api/utils.rst new file mode 100644 index 00000000000..b51afb1cb4c --- /dev/null +++ b/docs/api/utils.rst @@ -0,0 +1,82 @@ +Utilities +============ + +This section documents the utility functions and classes in the VERL library. + +Python Functional Utilities +------------------------- + +.. 
automodule:: verl.utils.py_functional + :members: + +File System Utilities +--------------------- + +.. automodule:: verl.utils.fs + :members: + +Tracking Utilities +------------------ + +.. automodule:: verl.utils.tracking + :members: + +Checkpoint Management +--------------------- + +.. automodule:: verl.utils.checkpoint.checkpoint_manager + :members: + +.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager + :members: + +Dataset Utilities +----------------- + +.. automodule:: verl.utils.dataset.rl_dataset + :members: + +.. automodule:: verl.utils.dataset.sft_dataset + :members: + +Torch Functional Utilities +-------------------------- + +.. automodule:: verl.utils.torch_functional + :members: + +Sequence Length Balancing +------------------------- + +.. automodule:: verl.utils.seqlen_balancing + :members: + +Ulysses Utilities +----------------- + +.. automodule:: verl.utils.ulysses + :members: + +Model Utilities +--------------- + +.. automodule:: verl.utils.model + :members: + +FSDP Utilities +-------------- + +.. automodule:: verl.utils.fsdp_utils + :members: + +Reward Score Utilities +---------------------- + +.. automodule:: verl.utils.reward_score + :members: + +Debug Utilities +--------------- + +.. automodule:: verl.utils.debug + :members: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 679c4bfdf1e..dbebeeeb9fa 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -99,6 +99,8 @@ verl is fast with: api/data.rst api/single_controller.rst + api/trainer.rst + api/utils.rst ..
toctree:: diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py index eff3fac3e2a..c6c721c9aaa 100644 --- a/tests/sandbox/test_sandbox.py +++ b/tests/sandbox/test_sandbox.py @@ -15,7 +15,7 @@ import asyncio import json -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score from verl.utils.reward_score.prime_code import apps_check_correctness from verl.workers.reward_manager.prime import parallel_compute_score_async @@ -106,7 +106,7 @@ def test_parallelism(): ground_truth.extend(prime_math_gts) data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) - scores = asyncio.run(parallel_compute_score_async(_default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) + scores = asyncio.run(parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) print(scores) @@ -116,7 +116,7 @@ def test_prime_code(): """ data_source = "codecontests" for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ @@ -131,5 +131,5 @@ def test_check_correctness(): def test_prime_math(): data_source = "numina_aops_forum" for completion, ground_truth in zip(prime_math_answers, prime_math_gts): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index a0e7cd8af48..28015e65f01 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -75,12 +75,12 @@ def compute_gae_advantage_return( Args: token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) values: 
`(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) response_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma: `(float)` + shape is (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero. + gamma: `(float)` discounted factor used in RL lam: `(float)` lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) @@ -122,9 +122,9 @@ def compute_grpo_outcome_advantage( (with only one scalar reward for each response). Args: token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) response_mask: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) norm_adv_by_std_in_grpo: (bool) whether to scale the GRPO advantage. If True, the advantage is scaled by the std, as in the original GRPO. @@ -132,9 +132,9 @@ def compute_grpo_outcome_advantage( Returns: advantages: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) Returns: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) """ scores = token_level_rewards.sum(dim=-1) @@ -317,15 +317,12 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str """ Aggregate the loss matrix into a scalar. Args: - loss_mat: `(torch.Tensor)` + loss_mat: `(torch.Tensor)`: shape: (bs, response_length) - loss_mask: `(torch.Tensor)` + loss_mask: `(torch.Tensor)`: shape: (bs, response_length) - loss_agg_mode: (str) choices: "token-mean" / - "seq-mean-token-sum" / - "seq-mean-token-mean" / - "seq-mean-token-sum-norm" / - "token-mean" is the default behavior + loss_agg_mode: (str) choices: "token-mean" (default) / "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm". + method to aggregate the loss matrix into a scalar.
Returns: loss: `a scalar torch.Tensor` aggregated loss diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index d31524444f1..493d0daf0ac 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -143,6 +143,22 @@ def _check_resource_available(self): def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ responses = data.batch["responses"] response_length = responses.size(1) token_level_scores = data.batch["token_level_scores"] @@ -176,6 +192,17 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. 
+ """ responses = data.batch["responses"] response_length = responses.size(1) attention_mask = data.batch["attention_mask"] @@ -183,6 +210,23 @@ def compute_response_mask(data: DataProto): def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True): + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ # Back-compatible with trainers that do not compute response mask in fit if "response_mask" not in data.batch: data.batch["response_mask"] = compute_response_mask(data) @@ -254,6 +298,18 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re @contextmanager def _timer(name: str, timing_raw: Dict[str, float]): + """Context manager for timing code execution. + + This utility function measures the execution time of code within its context + and accumulates the timing information in the provided dictionary. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. 
+ + Yields: + None: This is a context manager that yields control back to the code block. + """ with Timer(name=name, logger=None) as timer: yield if name not in timing_raw: diff --git a/verl/utils/fs.py b/verl/utils/fs.py index 52e5c69616e..b9eb3004d87 100644 --- a/verl/utils/fs.py +++ b/verl/utils/fs.py @@ -31,10 +31,29 @@ def is_non_local(path): + """Check if a path is a non-local (HDFS) path. + + Args: + path (str): The path to check. + + Returns: + bool: True if the path is an HDFS path, False otherwise. + """ return path.startswith(_HDFS_PREFIX) def md5_encode(path: str) -> str: + """Generate an MD5 hash of a path string. + + This function is used to create unique identifiers for paths, typically + for creating cache directories or lock files. + + Args: + path (str): The path to encode. + + Returns: + str: The hexadecimal MD5 hash of the path. + """ return hashlib.md5(path.encode()).hexdigest() @@ -58,14 +77,18 @@ def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: def copy_to_local(src: str, cache_dir=None, filelock=".file.lock", verbose=False) -> str: """Copy src from hdfs to local if src is on hdfs or directly return src. + If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if - the src name is the same between calls + the src name is the same between calls. Args: - src (str): a HDFS path of a local path + src (str): A HDFS path or a local path. + cache_dir (str, optional): Directory to store the local copy. If None, uses system temp directory. + filelock (str, optional): Name of the lock file to use for thread safety. Defaults to ".file.lock". + verbose (bool, optional): Whether to print copy operations. Defaults to False. Returns: - a local path of the copied file + str: A local path of the copied file. 
""" return copy_local_path_from_hdfs(src, cache_dir, filelock, verbose) diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index b96eb0cd1ec..f86eb63760f 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -38,6 +38,18 @@ def union_two_dict(dict1: Dict, dict2: Dict): def append_to_dict(data: Dict, new_data: Dict): + """Append values from new_data to lists in data. + + For each key in new_data, this function appends the corresponding value to a list + stored under the same key in data. If the key doesn't exist in data, a new list is created. + + Args: + data (Dict): The target dictionary containing lists as values. + new_data (Dict): The source dictionary with values to append. + + Returns: + None: The function modifies data in-place. + """ for key, val in new_data.items(): if key not in data: data[key] = [] @@ -45,6 +57,21 @@ def append_to_dict(data: Dict, new_data: Dict): class NestedNamespace(SimpleNamespace): + """A nested version of SimpleNamespace that recursively converts dictionaries to namespaces. + + This class allows for dot notation access to nested dictionary structures by recursively + converting dictionaries to NestedNamespace objects. + + Example: + config_dict = {"a": 1, "b": {"c": 2, "d": 3}} + config = NestedNamespace(config_dict) + # Access with: config.a, config.b.c, config.b.d + + Args: + dictionary: The dictionary to convert to a nested namespace. + **kwargs: Additional attributes to set on the namespace. + """ + def __init__(self, dictionary, **kwargs): super().__init__(**kwargs) for key, value in dictionary.items(): diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index fe101cc9e7e..4ee76939f80 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -14,7 +14,22 @@ # from . 
import gsm8k, math, prime_math, prime_code -def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None): +def default_compute_score(data_source, solution_str, ground_truth, extra_info=None): + """Compute the score for a given solution based on the data source. + + Args: + data_source (str): The source dataset identifier which determines the scoring method. + solution_str (str): The solution string to be evaluated. + ground_truth (str): The ground truth answer for comparison. + extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None. + + Returns: + float: The computed score as a floating point number. If the result is a dictionary, + it returns the dictionary instead. + + Raises: + NotImplementedError: If the reward function is not implemented for the given data source. + """ if data_source == "openai/gsm8k": from . import gsm8k @@ -62,3 +77,6 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N return float(res) else: return float(res[0]) + + +__all__ = ["default_compute_score"] diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index c320a42af77..399cdf05e09 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -17,7 +17,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score class DAPORewardManager: @@ -34,7 +34,7 @@ def __init__( ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key self.overlong_buffer_cfg = overlong_buffer_cfg self.max_resp_len = max_resp_len diff --git a/verl/workers/reward_manager/naive.py 
b/verl/workers/reward_manager/naive.py index 3a59dc8b23a..59ad618c4c1 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -17,7 +17,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score class NaiveRewardManager: @@ -26,7 +26,7 @@ class NaiveRewardManager: def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key def __call__(self, data: DataProto, return_dict=False): diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index 7feb1ebddef..cd779ddcc7a 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -21,7 +21,7 @@ from transformers import PreTrainedTokenizer from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): @@ -90,7 +90,7 @@ def __init__( ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key def verify(self, data): From 368b1151e51e189463cbceebffbe8f95baae6e8e Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 21:24:43 -0700 Subject: [PATCH 05/31] adding more tests for utils and also for docs Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 
54 ++++++++++++++----- tests/ray_cpu/test_ray_utils.py | 40 ++++++++++++++ .../utils/megatron/test_pipeline_parallel.py | 33 ++++++++++++ tests/verl/utils/test_seqlen_balancing.py | 47 ++++++++++++++++ verl/utils/megatron/__init__.py | 6 +++ verl/utils/megatron/pipeline_parallel.py | 14 +++++ verl/utils/megatron/sequence_parallel.py | 1 + verl/utils/ray_utils.py | 16 +++++- verl/utils/seqlen_balancing.py | 36 ++++++++----- verl/utils/tracking.py | 10 ++++ 10 files changed, 231 insertions(+), 26 deletions(-) create mode 100644 tests/ray_cpu/test_ray_utils.py create mode 100644 tests/verl/utils/megatron/test_pipeline_parallel.py create mode 100644 tests/verl/utils/test_seqlen_balancing.py diff --git a/docs/api/utils.rst b/docs/api/utils.rst index b51afb1cb4c..a6026ad70cf 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -4,25 +4,25 @@ Utilities This section documents the utility functions and classes in the VERL library. Python Functional Utilities -------------------------- +---------------------------- .. automodule:: verl.utils.py_functional :members: File System Utilities -------------------- +---------------------- .. automodule:: verl.utils.fs :members: Tracking Utilities ----------------- +------------------- .. automodule:: verl.utils.tracking :members: Checkpoint Management -------------------- +---------------------- .. automodule:: verl.utils.checkpoint.checkpoint_manager :members: @@ -31,7 +31,7 @@ Checkpoint Management :members: Dataset Utilities ---------------- +------------------ .. automodule:: verl.utils.dataset.rl_dataset :members: @@ -40,43 +40,73 @@ Dataset Utilities :members: Torch Functional Utilities ------------------------- +--------------------------- .. automodule:: verl.utils.torch_functional :members: Sequence Length Balancing ------------------------ +-------------------------- .. automodule:: verl.utils.seqlen_balancing :members: Ulysses Utilities ---------------- +------------------ .. 
automodule:: verl.utils.ulysses :members: Model Utilities -------------- +---------------- .. automodule:: verl.utils.model :members: FSDP Utilities ------------- +--------------- .. automodule:: verl.utils.fsdp_utils :members: Reward Score Utilities --------------------- +----------------------- .. automodule:: verl.utils.reward_score :members: Debug Utilities -------------- +---------------- .. automodule:: verl.utils.debug + :members: + +Ray Utilities +-------------- + +.. automodule:: verl.utils.ray_utils + :members: + +Model Utilities +------------- + +.. automodule:: verl.utils.model + :members: + +Sequence Length Balancing +----------------------- + +.. automodule:: verl.utils.seqlen_balancing + :members: + +Tracking Utilities +---------------- + +.. automodule:: verl.utils.tracking + :members: + +Megatron Utilities +---------------------------------- + +.. automodule:: verl.utils.megatron.pipeline_parallel :members: \ No newline at end of file diff --git a/tests/ray_cpu/test_ray_utils.py b/tests/ray_cpu/test_ray_utils.py new file mode 100644 index 00000000000..e9848384f5b --- /dev/null +++ b/tests/ray_cpu/test_ray_utils.py @@ -0,0 +1,40 @@ +import pytest +import ray + +from verl.utils.ray_utils import parallel_put + + +# Initialize Ray for testing if not already done globally +@pytest.fixture() +def init_ray(): + ray.init(num_cpus=4) + yield + ray.shutdown() + + +def test_parallel_put_basic(init_ray): + data = [1, "hello", {"a": 2}, [3, 4]] + refs = parallel_put(data) + assert len(refs) == len(data) + retrieved_data = [ray.get(ref) for ref in refs] + assert retrieved_data == data + + +def test_parallel_put_empty(init_ray): + data = [] + refs = parallel_put(data) + assert len(refs) == 0 + + +def test_parallel_put_workers(init_ray): + data = list(range(20)) + # Test with specific number of workers + refs = parallel_put(data, max_workers=4) + assert len(refs) == len(data) + retrieved_data = [ray.get(ref) for ref in refs] + assert retrieved_data == data + # 
Test with default workers (should cap) + refs_default = parallel_put(data) + assert len(refs_default) == len(data) + retrieved_data_default = [ray.get(ref) for ref in refs_default] + assert retrieved_data_default == data diff --git a/tests/verl/utils/megatron/test_pipeline_parallel.py b/tests/verl/utils/megatron/test_pipeline_parallel.py new file mode 100644 index 00000000000..e84114cea67 --- /dev/null +++ b/tests/verl/utils/megatron/test_pipeline_parallel.py @@ -0,0 +1,33 @@ +from verl.utils.megatron.pipeline_parallel import make_batch_generator + + +def test_make_batch_generator_no_vpp(): + batches = [1, 2, 3] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == batches + + +def test_make_batch_generator_with_vpp(): + batches = [{"data": 1}, {"data": 2}] + vpp_size = 2 + generators = make_batch_generator(batches, vpp_size) + assert isinstance(generators, list) + assert len(generators) == vpp_size + + # Check each generator yields the original batches + for gen in generators: + assert list(gen) == batches + + +def test_make_batch_generator_empty(): + batches = [] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == [] + + vpp_size = 3 + generators = make_batch_generator(batches, vpp_size) + assert len(generators) == vpp_size + for gen in generators: + assert list(gen) == [] diff --git a/tests/verl/utils/test_seqlen_balancing.py b/tests/verl/utils/test_seqlen_balancing.py new file mode 100644 index 00000000000..ac083b51c57 --- /dev/null +++ b/tests/verl/utils/test_seqlen_balancing.py @@ -0,0 +1,47 @@ +import pytest + +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions + + +def test_get_seqlen_balanced_partitions_equal_size(): + seqlen_list = [10, 20, 70, 80, 100, 120] + k_partitions = 3 + partitions = get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size=True) + + assert len(partitions) == k_partitions + all_indices = set() + partition_sums = 
[] + for p in partitions: + assert len(p) == len(seqlen_list) // k_partitions # Check equal size + all_indices.update(p) + partition_sums.append(sum(seqlen_list[i] for i in p)) + + assert all_indices == set(range(len(seqlen_list))) # Check all indices covered + # Check balance (sums should be close) - allow some tolerance + assert max(partition_sums) - min(partition_sums) <= max(seqlen_list) # Heuristic check + + +def test_get_seqlen_balanced_partitions_unequal_size(): + seqlen_list = [5, 10, 15, 20, 25, 100] + k_partitions = 2 + partitions = get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size=False) + + assert len(partitions) == k_partitions + all_indices = set() + partition_sums = [] + for p in partitions: + assert len(p) > 0 # Check not empty + all_indices.update(p) + partition_sums.append(sum(seqlen_list[i] for i in p)) + + assert all_indices == set(range(len(seqlen_list))) # Check all indices covered + # Check balance (sums should be close) + assert max(partition_sums) - min(partition_sums) <= max(seqlen_list) # Heuristic check + + +def test_get_seqlen_balanced_partitions_assertions(): + with pytest.raises(AssertionError): + get_seqlen_balanced_partitions([1, 2], 3, False) # n < k + + with pytest.raises(AssertionError): + get_seqlen_balanced_partitions([1, 2, 3], 2, True) # n % k != 0 for equal_size diff --git a/verl/utils/megatron/__init__.py b/verl/utils/megatron/__init__.py index 1ce90c5eb35..36f611bb90f 100644 --- a/verl/utils/megatron/__init__.py +++ b/verl/utils/megatron/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from .pipeline_parallel import make_batch_generator +from .sequence_parallel import pad_to_sequence_parallel +from .tensor_parallel import vocab_parallel_log_probs_from_logits_response_rmpad + +__all__ = ["make_batch_generator", "pad_to_sequence_parallel", "vocab_parallel_log_probs_from_logits_response_rmpad"] diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py index b7e272763ff..50ba6973625 100644 --- a/verl/utils/megatron/pipeline_parallel.py +++ b/verl/utils/megatron/pipeline_parallel.py @@ -47,6 +47,20 @@ def compute_transformers_input_shapes(batches, meta_info): def make_batch_generator(batches, vpp_size): + """ + Creates a batch generator suitable for Megatron pipeline parallelism, + handling virtual pipeline parallelism (VPP). + + If VPP is used (vpp_size > 1), it duplicates the batch iterator for each + virtual pipeline stage. Otherwise, it returns a single iterator. + + Args: + batches: An iterable (e.g., list) of micro-batches. + vpp_size (int): The virtual pipeline model parallel size. + + Returns: + An iterator or a list of iterators over the micro-batches. + """ if vpp_size > 1: # has vpp batch_generator = [batches] * vpp_size # number of vpp chunks diff --git a/verl/utils/megatron/sequence_parallel.py b/verl/utils/megatron/sequence_parallel.py index 9f4cbc08e87..52fda9b30cc 100644 --- a/verl/utils/megatron/sequence_parallel.py +++ b/verl/utils/megatron/sequence_parallel.py @@ -33,6 +33,7 @@ def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): unpad_tokens: (total_nnz, ...). Tokens after removing padding Returns: + the padded tokens: (total_nnz + pad_size,...) 
""" total_nnz = unpad_tokens.shape[0] diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index 49b60ef45c6..db3c990d549 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -16,11 +16,25 @@ """ import concurrent.futures +from typing import Any, List, Optional import ray -def parallel_put(data_list, max_workers=None): +def parallel_put(data_list: List[Any], max_workers: Optional[int] = None): + """ + Puts a list of data into the Ray object store in parallel using a thread pool. + + Args: + data_list (List[Any]): A list of Python objects to be put into the Ray object store. + max_workers (int, optional): The maximum number of worker threads to use. + Defaults to min(len(data_list), 16). + + Returns: + List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list, + maintaining the original order. + """ + def put_data(index, data): return index, ray.put(data) diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 16dca31c961..ff2cac2bc00 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -142,20 +142,30 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool): - """get order of seq lengths to make partitions balanced, this is - used in balacing sum of seqlength across dp ranks and microbatches - Parameters: - seqlen_list (List[int]): - seq lengths of each items - k_partitions (int): - resulting number of partitions - equal_size (bool): - if True, number of items in each partitions must be equal. - if False, only consider balancing the sum, each partition can have - variable number of items + """ + Calculates partitions of indices from seqlen_list such that the sum of sequence lengths + in each partition is balanced. Uses the Karmarkar-Karp differencing method. 
+ + This is useful for balancing workload across devices or batches, especially when + dealing with variable sequence lengths. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + k_partitions (int): The desired number of partitions. + equal_size (bool): If True, ensures that each partition has the same number of items. + Requires len(seqlen_list) to be divisible by k_partitions. + If False, partitions can have varying numbers of items, focusing + only on balancing the sum of sequence lengths. + Returns: - partitions (List[List[int]]): - return k_partitions list containing the index of items. + List[List[int]]: A list containing k_partitions lists. Each inner list contains the + original indices of the items assigned to that partition. The indices + within each partition list are sorted. + + Raises: + AssertionError: If len(seqlen_list) < k_partitions. + AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. + AssertionError: If any resulting partition is empty. """ assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 00026136aa7..0a6d20dff4e 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -23,6 +23,16 @@ class Tracking: + """A unified tracking interface for logging experiment data to multiple backends. + + This class provides a centralized way to log experiment metrics, parameters, and artifacts + to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. + + Attributes: + supported_backend: List of supported tracking backends. + logger: Dictionary of initialized logger instances for each backend. 
+ """ + supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console"] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None): From 767ccc353db704cb194f260a29d3f35572640725 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 21:43:41 -0700 Subject: [PATCH 06/31] addding more doc string and unit tests Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 14 ++++++++++---- verl/utils/megatron/__init__.py | 6 ------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/api/utils.rst b/docs/api/utils.rst index a6026ad70cf..50eafe50ca8 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -88,25 +88,31 @@ Ray Utilities :members: Model Utilities -------------- +---------------- .. automodule:: verl.utils.model :members: Sequence Length Balancing ------------------------ +-------------------------- .. automodule:: verl.utils.seqlen_balancing :members: Tracking Utilities ----------------- +-------------------- .. automodule:: verl.utils.tracking :members: Megatron Utilities ----------------------------------- +------------------- .. automodule:: verl.utils.megatron.pipeline_parallel + :members: + +.. automodule:: verl.utils.megatron.sequence_parallel + :members: + +.. automodule:: verl.utils.megatron.tensor_parallel :members: \ No newline at end of file diff --git a/verl/utils/megatron/__init__.py b/verl/utils/megatron/__init__.py index 36f611bb90f..1ce90c5eb35 100644 --- a/verl/utils/megatron/__init__.py +++ b/verl/utils/megatron/__init__.py @@ -11,9 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from .pipeline_parallel import make_batch_generator -from .sequence_parallel import pad_to_sequence_parallel -from .tensor_parallel import vocab_parallel_log_probs_from_logits_response_rmpad - -__all__ = ["make_batch_generator", "pad_to_sequence_parallel", "vocab_parallel_log_probs_from_logits_response_rmpad"] From ba4c98abe7d68b140d4beec44772c30d6283ded3 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 21:56:34 -0700 Subject: [PATCH 07/31] fix pipelines Signed-off-by: Hongpeng Guo --- .github/workflows/verl_unit_test.yml | 1 + tests/ray_cpu/test_ray_utils.py | 14 ++++++++++++++ .../verl/utils/megatron/test_pipeline_parallel.py | 14 ++++++++++++++ tests/verl/utils/test_seqlen_balancing.py | 15 +++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/.github/workflows/verl_unit_test.yml b/.github/workflows/verl_unit_test.yml index c55caca0ded..14cd2422214 100644 --- a/.github/workflows/verl_unit_test.yml +++ b/.github/workflows/verl_unit_test.yml @@ -41,6 +41,7 @@ jobs: - name: Install the current repository run: | pip install -e .[test] + pip install megatron-core - name: Running test protocol.py run: | cd tests/verl diff --git a/tests/ray_cpu/test_ray_utils.py b/tests/ray_cpu/test_ray_utils.py index e9848384f5b..a73b9fb3a36 100644 --- a/tests/ray_cpu/test_ray_utils.py +++ b/tests/ray_cpu/test_ray_utils.py @@ -1,3 +1,17 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import pytest import ray diff --git a/tests/verl/utils/megatron/test_pipeline_parallel.py b/tests/verl/utils/megatron/test_pipeline_parallel.py index e84114cea67..cf442a03b58 100644 --- a/tests/verl/utils/megatron/test_pipeline_parallel.py +++ b/tests/verl/utils/megatron/test_pipeline_parallel.py @@ -1,3 +1,17 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from verl.utils.megatron.pipeline_parallel import make_batch_generator diff --git a/tests/verl/utils/test_seqlen_balancing.py b/tests/verl/utils/test_seqlen_balancing.py index ac083b51c57..25ad3fa692a 100644 --- a/tests/verl/utils/test_seqlen_balancing.py +++ b/tests/verl/utils/test_seqlen_balancing.py @@ -1,3 +1,18 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + import pytest from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions From 752a5be209786ac53f5ccc91af406a6756e4c833 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 22:45:50 -0700 Subject: [PATCH 08/31] handle comments Signed-off-by: Hongpeng Guo --- docs/api/single_controller.rst | 9 +- verl/single_controller/ray/base.py | 171 ++++++++++++++++++++++++- verl/single_controller/ray/megatron.py | 11 ++ 3 files changed, 188 insertions(+), 3 deletions(-) diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index 37f93abd922..1356e98a401 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -11,7 +11,7 @@ Core APIs ~~~~~~~~~~~~~~~~~ .. autoclass:: verl.single_controller.Worker - :members: __init__, __new__, get_fused_worker_by_name, get_master_addr_port, get_cuda_visible_devices, world_size, rank, execute_with_func_generator, execute_func_rank_zero + :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank, execute_with_func_generator, execute_func_rank_zero .. autoclass:: verl.single_controller.WorkerGroup :members: __init__, start_worker_aliveness_check, world_size @@ -22,6 +22,13 @@ Core APIs .. autoclass:: verl.single_controller.ResourcePool :members: __init__, add_note, world_size, __call__, store, local_world_size_list, local_rank_list +.. automodule:: verl.single_controller.ray + :members: RayWorkerGroup, create_colocated_worker_cls, create_colocated_worker_cls_fused + +.. 
autoclass:: verl.single_controller.ray.megatron.NVMegatronRayWorkerGroup + :members: __init__ + + Decorator APIs ~~~~~~~~~~~~~~~~~ diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 72be01588eb..2cd06b276f5 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -145,6 +145,13 @@ def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResour class RayClassWithInitArgs(ClassWithInitArgs): + """A wrapper class for Ray actors with initialization arguments. + + This class extends ClassWithInitArgs to provide additional functionality for + configuring and creating Ray actors with specific resource requirements and + scheduling strategies. + """ + def __init__(self, cls, *args, **kwargs) -> None: # self._options = kwargs.pop('options', dict()) super().__init__(cls, *args, **kwargs) @@ -152,12 +159,34 @@ def __init__(self, cls, *args, **kwargs) -> None: self._additional_resource = {} def set_additional_resource(self, additional_resource): + """Set additional resource requirements for the actor. + + Args: + additional_resource: Dictionary specifying additional resource requirements + """ self._additional_resource = additional_resource def update_options(self, options: Dict): + """Update the Ray actor creation options. + + Args: + options: Dictionary of options to update + """ self._options.update(options) def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None) -> Any: + """Create and return a Ray actor with the configured options. 
+ + Args: + placement_group: Ray placement group for scheduling + placement_group_bundle_idx: Index of the bundle in the placement group + use_gpu: Whether to use GPU resources + num_gpus: Number of GPUs to allocate + sharing_with: Actor to share resources with + + Returns: + A Ray actor handle with the configured options + """ if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) @@ -181,6 +210,13 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = class RayWorkerGroup(WorkerGroup): + """A group of Ray workers that can be managed collectively. + + This class extends WorkerGroup to provide Ray-specific functionality for + creating and managing groups of Ray actors with specific resource requirements + and scheduling strategies. + """ + def __init__( self, resource_pool: RayResourcePool = None, @@ -192,6 +228,18 @@ def __init__( ray_wait_register_center_timeout: int = 300, **kwargs, ) -> None: + """Initialize a RayWorkerGroup. + + Args: + resource_pool: Resource pool for worker allocation + ray_cls_with_init: Class with initialization arguments for workers + bin_pack: Whether to use strict bin packing for resource allocation + name_prefix: Prefix for worker names + detached: Whether workers should be detached + worker_names: Names of existing workers to attach to + ray_wait_register_center_timeout: Timeout for waiting on register center + **kwargs: Additional keyword arguments + """ super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix @@ -217,15 +265,36 @@ def __init__( self.method_names = [] def _is_worker_alive(self, worker: ray.actor.ActorHandle): + """Check if a worker actor is still alive. 
+ + Args: + worker: Ray actor handle to check + + Returns: + bool: True if the worker is alive, False otherwise + """ worker_state_dict = get_actor(worker._actor_id.hex()) return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False def _init_with_detached_workers(self, worker_names): + """Initialize the worker group with existing detached workers. + + Args: + worker_names: Names of existing workers to attach to + """ workers = [ray.get_actor(name=name) for name in worker_names] self._workers = workers self._world_size = len(worker_names) def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): + """Initialize the worker group by creating new workers from a resource pool. + + Args: + resource_pool: Resource pool for worker allocation + ray_cls_with_init: Class with initialization arguments for workers + bin_pack: Whether to use strict bin packing for resource allocation + detached: Whether workers should be detached + """ use_gpu = resource_pool.use_gpu strategy = "PACK" @@ -321,13 +390,27 @@ def from_detached( worker_names=None, ray_cls_with_init=None, ): + """Create a worker group from existing detached workers. + + Args: + name_prefix: Prefix for worker names + worker_names: Names of existing workers to attach to + ray_cls_with_init: Class with initialization arguments for workers + + Returns: + A new RayWorkerGroup instance + """ worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names) return worker_group def spawn(self, prefix_set): - """ - spawn to a dictionary of worker groups, each with a subset of method with prefix. + """Spawn to a dictionary of worker groups, each with a subset of method with prefix. 
+ + Args: + prefix_set: Set of prefixes to create worker groups for + Returns: + Dictionary of worker groups keyed by prefix """ if self.fused_worker_used: return self.spawn_fused(prefix_set) @@ -357,6 +440,14 @@ def _rebind_actor_methods(worker_group, actor_name): return new_worker_group_dict def spawn_fused(self, prefix_set): + """Create a dictionary of worker groups for fused workers. + + Args: + prefix_set: Set of prefixes to create worker groups for + + Returns: + Dictionary of worker groups keyed by prefix + """ wg_dict = dict() for key in prefix_set: new_wg = deepcopy(self) @@ -366,6 +457,11 @@ def spawn_fused(self, prefix_set): return wg_dict def fuse(self, prefix_set): + """Fuse multiple worker groups into the current worker group. + + Args: + prefix_set: Set of prefixes to fuse into the worker group + """ if self.wg_dict is None: self.wg_dict = self.spawn(prefix_set) for role_name, role_wg in self.wg_dict.items(): @@ -373,6 +469,17 @@ def fuse(self, prefix_set): self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs): + """Execute a method on a single worker remotely. + + Args: + worker: The worker actor handle + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ if self.fused_worker_used and method_name not in self.method_names: remote_call = getattr(worker, self.fused_worker_execute_fn_name) return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) @@ -381,21 +488,81 @@ def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwarg return remote_call.remote(*args, **kwargs) def execute_rank_zero_sync(self, method_name: str, *args, **kwargs): + """Execute a method on rank zero worker synchronously. 
+ + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Result of the method execution + """ return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs)) def execute_rank_zero_async(self, method_name: str, *args, **kwargs): + """Execute a method on rank zero worker asynchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs) def execute_rank_zero(self, method_name: str, *args, **kwargs): + """Alias for execute_rank_zero_async. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ return self.execute_rank_zero_async(method_name, *args, **kwargs) def execute_all(self, method_name: str, *args, **kwargs): + """Alias for execute_all_async. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of remote object references to the method executions + """ return self.execute_all_async(method_name, *args, **kwargs) def execute_all_sync(self, method_name: str, *args, **kwargs): + """Execute a method on all workers synchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of results from all workers + """ return ray.get(self.execute_all_async(method_name, *args, **kwargs)) def execute_all_async(self, method_name: str, *args, **kwargs): + """Execute a method on all workers asynchronously. 
+ + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of remote object references to the method executions + """ # Here, we assume that if all arguments in args and kwargs are lists, # and their lengths match len(self._workers), we'll distribute each # element in these lists to the corresponding worker diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py index 8baf03e6f17..9c4aa5bbb45 100644 --- a/verl/single_controller/ray/megatron.py +++ b/verl/single_controller/ray/megatron.py @@ -30,6 +30,17 @@ class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): """ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): + """ + Initialize the NVMegatronRayWorkerGroup. + + Args: + resource_pool (RayResourcePool): + The resource pool containing worker resources + ray_cls_with_init (RayClassWithInitArgs): + The Ray class with initialization arguments + **kwargs: + Additional keyword arguments to pass to the parent class + """ super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") self._megatron_global_info: DistGlobalInfo = ray.get(self.execute_rank_zero_async(method_name="get_megatron_global_info")) From 5d14578a49a02589e26f09ef9acf8d445f918282 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 4 May 2025 22:52:22 -0700 Subject: [PATCH 09/31] fix nits Signed-off-by: Hongpeng Guo --- docs/api/trainer.rst | 10 +++++----- docs/api/utils.rst | 34 +++++++++++++++++----------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst index d5174f6a887..7f408f208ee 100644 --- a/docs/api/trainer.rst +++ b/docs/api/trainer.rst @@ -1,28 +1,28 @@ PPO Trainer Interface -=========== 
+================================ This section documents the trainer utilities in the VERL library. PPO Ray Trainer ------------- +---------------------------- .. automodule:: verl.trainer.ppo.ray_trainer :members: PPO Metrics ------------- +-------------------------- .. automodule:: verl.trainer.ppo.metric_utils :members: PPO Core Algorithms --------------------- +----------------------------- .. automodule:: verl.trainer.ppo.core_algos :members: PPO Reward ------------ +--------------------------- .. automodule:: verl.trainer.ppo.reward :members: \ No newline at end of file diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 50eafe50ca8..f750bd3ff3f 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -4,25 +4,25 @@ Utilities This section documents the utility functions and classes in the VERL library. Python Functional Utilities ----------------------------- +------------------------------ .. automodule:: verl.utils.py_functional :members: File System Utilities ----------------------- +------------------------ .. automodule:: verl.utils.fs :members: Tracking Utilities -------------------- +--------------------- .. automodule:: verl.utils.tracking :members: Checkpoint Management ----------------------- +------------------------ .. automodule:: verl.utils.checkpoint.checkpoint_manager :members: @@ -31,7 +31,7 @@ Checkpoint Management :members: Dataset Utilities ------------------- +--------------------- .. automodule:: verl.utils.dataset.rl_dataset :members: @@ -40,73 +40,73 @@ Dataset Utilities :members: Torch Functional Utilities ---------------------------- +----------------------------- .. automodule:: verl.utils.torch_functional :members: Sequence Length Balancing --------------------------- +---------------------------- .. automodule:: verl.utils.seqlen_balancing :members: Ulysses Utilities ------------------- +-------------------- .. automodule:: verl.utils.ulysses :members: Model Utilities ----------------- +------------------ .. 
automodule:: verl.utils.model :members: FSDP Utilities ---------------- +------------------ .. automodule:: verl.utils.fsdp_utils :members: Reward Score Utilities ------------------------ +------------------------- .. automodule:: verl.utils.reward_score :members: Debug Utilities ----------------- +------------------- .. automodule:: verl.utils.debug :members: Ray Utilities --------------- +----------------- .. automodule:: verl.utils.ray_utils :members: Model Utilities ----------------- +------------------- .. automodule:: verl.utils.model :members: Sequence Length Balancing --------------------------- +----------------------------- .. automodule:: verl.utils.seqlen_balancing :members: Tracking Utilities --------------------- +----------------------- .. automodule:: verl.utils.tracking :members: Megatron Utilities -------------------- +---------------------- .. automodule:: verl.utils.megatron.pipeline_parallel :members: From a05690c76374cb3849740de7841e9d1dd13b9205 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 21:48:50 -0700 Subject: [PATCH 10/31] update doc Signed-off-by: Hongpeng Guo --- docs/api/single_controller.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index 1356e98a401..e680144763d 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -11,19 +11,19 @@ Core APIs ~~~~~~~~~~~~~~~~~ .. autoclass:: verl.single_controller.Worker - :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank, execute_with_func_generator, execute_func_rank_zero + :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank .. autoclass:: verl.single_controller.WorkerGroup - :members: __init__, start_worker_aliveness_check, world_size + :members: __init__, world_size .. autoclass:: verl.single_controller.ClassWithInitArgs :members: __init__, __call__ .. 
autoclass:: verl.single_controller.ResourcePool - :members: __init__, add_note, world_size, __call__, store, local_world_size_list, local_rank_list + :members: __init__, world_size, local_world_size_list, local_rank_list .. automodule:: verl.single_controller.ray - :members: RayWorkerGroup, create_colocated_worker_cls, create_colocated_worker_cls_fused + :members: RayWorkerGroup, create_colocated_worker_cls .. autoclass:: verl.single_controller.ray.megatron.NVMegatronRayWorkerGroup :members: __init__ From c5f616f73f762d7f0817bb11f6944a4e3d2304e2 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 22:20:01 -0700 Subject: [PATCH 11/31] fix some docs errors Signed-off-by: Hongpeng Guo --- docs/start/install.rst | 9 ++++----- docs/workers/sglang_worker.rst | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/start/install.rst b/docs/start/install.rst index d96c8382ad1..a4e62d72c8d 100644 --- a/docs/start/install.rst +++ b/docs/start/install.rst @@ -214,11 +214,10 @@ Install with AMD GPUs - ROCM kernel support ------------------------------------------------------------------ When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it. - If you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su `_. Find the docker for AMD ROCm: `docker/Dockerfile.rocm `_ -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: .. 
code-block:: bash @@ -267,15 +266,15 @@ Find the docker for AMD ROCm: `docker/Dockerfile.rocm Date: Sat, 17 May 2025 22:46:44 -0700 Subject: [PATCH 12/31] add doc test pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 62 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 .github/workflows/doc.yml diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml new file mode 100644 index 00000000000..520267813f6 --- /dev/null +++ b/.github/workflows/doc.yml @@ -0,0 +1,62 @@ +name: doc_test + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "docs/**" + - .github/workflows/doc.yml + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. 
+permissions: + contents: read + +jobs: + doc_test: + runs-on: ubuntu-latest + timeout-minutes: 5 # Increase this timeout value as needed + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install the current repository + run: | + pip install -e .[test] + pip install -r requirements-docs.txt + + - name: Run doc make html + run: | + cd docs + make clean + make html + + - name: Upload to GitHub Pages + uses: actions/upload-pages-artifact@v1 + with: + path: doc/_build/html + + - name: Deploy Pages + uses: actions/deploy-pages@v1 + + - name: Show published URL + run: | + echo "📖 Docs are live at: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/" \ No newline at end of file From 1b9cf918f968712eb1c9211c2b4f9b3b2eff47e6 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 22:58:04 -0700 Subject: [PATCH 13/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 520267813f6..8bd67b94f13 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -50,7 +50,7 @@ jobs: make html - name: Upload to GitHub Pages - uses: actions/upload-pages-artifact@v1 + uses: actions/upload-pages-artifact@v3 with: path: doc/_build/html From 0de7ee25762ee370918a98227ec95787a8a531be Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:00:58 -0700 Subject: [PATCH 14/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 8bd67b94f13..b333c108482 
100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -41,7 +41,7 @@ jobs: - name: Install the current repository run: | pip install -e .[test] - pip install -r requirements-docs.txt + pip install -r docs/requirements-docs.txt - name: Run doc make html run: | From 042ef6fc7b8e363293650d1fef091877d39c867a Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:04:26 -0700 Subject: [PATCH 15/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index b333c108482..b7af7efd317 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -52,7 +52,7 @@ jobs: - name: Upload to GitHub Pages uses: actions/upload-pages-artifact@v3 with: - path: doc/_build/html + path: docs/_build/html - name: Deploy Pages uses: actions/deploy-pages@v1 From f10305c6f117887efb5404ab4922b1527f212dbb Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:14:23 -0700 Subject: [PATCH 16/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index b7af7efd317..d7a4932f4e2 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -23,7 +23,9 @@ concurrency: # Declare permissions just read content. 
permissions: - contents: read + contents: read # for checkout + pages: write # for deploy-pages + id-token: write # for deploy-pages jobs: doc_test: From d9cc14eb619641d915e17694587382ab0d15c03e Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:36:16 -0700 Subject: [PATCH 17/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index d7a4932f4e2..796e7f833f7 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -57,8 +57,13 @@ jobs: path: docs/_build/html - name: Deploy Pages - uses: actions/deploy-pages@v1 + if: github.event_name == 'pull_request' + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: docs/_build/html + destination_dir: pr-${{ github.event.number }} - name: Show published URL run: | - echo "📖 Docs are live at: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/" \ No newline at end of file + echo "📖 Docs are live at: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/pr-${{ github.event.number }}/" \ No newline at end of file From 17dabae59bcddca78ae79ad0d20c5f89f6cd89f4 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:41:35 -0700 Subject: [PATCH 18/31] update pipeline Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 796e7f833f7..2f4967e160b 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -58,7 +58,7 @@ jobs: - name: Deploy Pages if: github.event_name == 'pull_request' - uses: peaceiris/actions-gh-pages@v3 + uses: actions/deploy-pages@v1 with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: docs/_build/html From 875252647e96e4148977188212833f4c9e06421a Mon 
Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 17 May 2025 23:50:45 -0700 Subject: [PATCH 19/31] remove the deply module Signed-off-by: Hongpeng Guo --- .github/workflows/doc.yml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 2f4967e160b..ff07c07fc39 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -50,20 +50,3 @@ jobs: cd docs make clean make html - - - name: Upload to GitHub Pages - uses: actions/upload-pages-artifact@v3 - with: - path: docs/_build/html - - - name: Deploy Pages - if: github.event_name == 'pull_request' - uses: actions/deploy-pages@v1 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: docs/_build/html - destination_dir: pr-${{ github.event.number }} - - - name: Show published URL - run: | - echo "📖 Docs are live at: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/pr-${{ github.event.number }}/" \ No newline at end of file From e04f4ee897c6ada3d0bcd64e15e6eadc3c7e4e87 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 18 May 2025 18:05:54 -0700 Subject: [PATCH 20/31] remove redundant apis of single controller Signed-off-by: Hongpeng Guo --- docs/api/single_controller.rst | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index e680144763d..f10b6521c87 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -23,21 +23,4 @@ Core APIs :members: __init__, world_size, local_world_size_list, local_rank_list .. automodule:: verl.single_controller.ray - :members: RayWorkerGroup, create_colocated_worker_cls - -.. autoclass:: verl.single_controller.ray.megatron.NVMegatronRayWorkerGroup - :members: __init__ - - - -Decorator APIs -~~~~~~~~~~~~~~~~~ -.. autofunction:: verl.single_controller.base.decorator.register - -.. 
autoclass:: verl.single_controller.base.decorator.Dispatch - :members: RANK_ZERO, ONE_TO_ALL, ALL_TO_ALL, MEGATRON_COMPUTE, MEGATRON_PP_AS_DP, MEGATRON_PP_ONLY, MEGATRON_COMPUTE_PROTO, MEGATRON_PP_AS_DP_PROTO, DP_COMPUTE, DP_COMPUTE_PROTO, DP_COMPUTE_PROTO_WITH_FUNC, DP_COMPUTE_METRIC, DIRECT_ROLLOUT_METHOD - :member-order: bysource - -.. autoclass:: verl.single_controller.base.decorator.Execute - :members: ALL, RANK_ZERO - :member-order: bysource \ No newline at end of file + :members: RayWorkerGroup, create_colocated_worker_cls \ No newline at end of file From 110c8f2d86a7303bd34d6b6c37ebe6a94a62d703 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 18 May 2025 23:03:18 -0700 Subject: [PATCH 21/31] fix merge errror Signed-off-by: Hongpeng Guo --- verl/single_controller/ray/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index dc5fd9bb497..c0822e3cf27 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -401,7 +401,7 @@ def from_detached( Returns: A new RayWorkerGroup instance """ - worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names) + worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names, worker_handles=worker_handles) return worker_group def spawn(self, prefix_set): @@ -417,9 +417,6 @@ def spawn(self, prefix_set): return self.spawn_fused(prefix_set) def _rebind_actor_methods(worker_group, actor_name): - """ - bind the method with actor_prefix to its original name - """ prefix: str = actor_name + "_" for method_name in dir(worker_group): if method_name.startswith(prefix): From 632189bd36a43c9f94bda5cd3189980cb19aaec3 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 18 May 2025 23:06:22 -0700 Subject: [PATCH 22/31] fix some doc strings Signed-off-by: Hongpeng Guo --- 
verl/single_controller/base/worker_group.py | 27 ++++++--------------- verl/single_controller/ray/megatron.py | 9 +++---- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py index 9275abce855..04c4f15bede 100644 --- a/verl/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -35,12 +35,9 @@ def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_p """Initialize the ResourcePool with node processes and GPU configuration. Args: - process_on_nodes (List[int], optional): - List of process counts per node. Defaults to empty list. - max_colocate_count (int, optional): - Maximum number of processes that can be colocated. Defaults to 10. - n_gpus_per_node (int, optional): - Number of GPUs available per node. Defaults to 8. + process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list. + max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10. + n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8. 
""" if process_on_nodes is None: process_on_nodes = [] @@ -49,7 +46,6 @@ def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_p self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node def add_node(self, process_count): - """Add a new node to the resource pool with the specified process count.""" self._store.append(process_count) @property @@ -58,12 +54,10 @@ def world_size(self): return sum(self._store) def __call__(self) -> Any: - """Get the list of processes count per node of the resource pool.""" return self._store @property def store(self): - """The store of the process count per node.""" return self._store def local_world_size_list(self) -> List[int]: @@ -88,12 +82,9 @@ def __init__(self, cls, *args, **kwargs) -> None: """Initialize the ClassWithInitArgs instance. Args: - cls: - The class to be instantiated later - *args: - Positional arguments for the class constructor - **kwargs: - Keyword arguments for the class constructor + cls: The class to be instantiated later + *args: Positional arguments for the class constructor + **kwargs: Keyword arguments for the class constructor """ self.cls = cls self.args = args @@ -188,10 +179,8 @@ def _bind_worker_method(self, user_defined_cls, func_generator): """Binds worker methods to the WorkerGroup based on registered attributes. 
Args: - user_defined_cls (type): - The class containing methods to bind - func_generator (Callable): - Function that generates the bound method + user_defined_cls (type): The class containing methods to bind + func_generator (Callable): Function that generates the bound method Returns: List[str]: List of method names that were successfully bound diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py index 9c4aa5bbb45..4f56ac1bfab 100644 --- a/verl/single_controller/ray/megatron.py +++ b/verl/single_controller/ray/megatron.py @@ -34,12 +34,9 @@ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWi Initialize the NVMegatronRayWorkerGroup. Args: - resource_pool (RayResourcePool): - The resource pool containing worker resources - ray_cls_with_init (RayClassWithInitArgs): - The Ray class with initialization arguments - **kwargs: - Additional keyword arguments to pass to the parent class + resource_pool (RayResourcePool): The resource pool containing worker resources + ray_cls_with_init (RayClassWithInitArgs): The Ray class with initialization arguments + **kwargs: Additional keyword arguments to pass to the parent class """ super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") From 6514f4278df8ecfb69731bbb520fd35bcecb1c83 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 17:48:05 -0700 Subject: [PATCH 23/31] remove some methods from verl.single_controller.ray.RayWorkerGroup Signed-off-by: Hongpeng Guo --- docs/api/single_controller.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index f10b6521c87..369e59776c7 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -22,5 +22,7 @@ Core APIs .. 
autoclass:: verl.single_controller.ResourcePool :members: __init__, world_size, local_world_size_list, local_rank_list -.. automodule:: verl.single_controller.ray - :members: RayWorkerGroup, create_colocated_worker_cls \ No newline at end of file +.. autoclass:: verl.single_controller.ray.RayWorkerGroup + :members: __init__ + +.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls \ No newline at end of file From 4af35dc66a73df9198a24659e5af6a031310cee4 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 19:11:28 -0700 Subject: [PATCH 24/31] rename default_compute_score Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 3 +++ tests/sandbox/test_sandbox.py | 12 ++++++------ verl/trainer/ppo/reward.py | 6 +++--- verl/utils/reward_score/__init__.py | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 9bba4bea3df..98823866eb6 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -21,6 +21,9 @@ Tracking Utilities .. automodule:: verl.utils.tracking :members: +Metrics Utilities +--------------------- + .. 
automodule:: verl.utils.metric :members: reduce_metrics diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py index 12a1048d184..e3e0b10dba6 100644 --- a/tests/sandbox/test_sandbox.py +++ b/tests/sandbox/test_sandbox.py @@ -18,7 +18,7 @@ import pytest -from verl.utils.reward_score import _default_compute_score, prime_code, sandbox_fusion +from verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion from verl.utils.reward_score.prime_code import apps_check_correctness from verl.workers.reward_manager.prime import parallel_compute_score_async @@ -109,7 +109,7 @@ def test_parallelism(): ground_truth.extend(prime_math_gts) data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) - scores = asyncio.run(parallel_compute_score_async(_default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) + scores = asyncio.run(parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) print(scores) @@ -119,7 +119,7 @@ def test_prime_code(): """ data_source = "codecontests" for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ @@ -135,7 +135,7 @@ def test_prime_code_sandbox_fusion(): # Removed the previous 'if not sandbox_url' check block for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = _default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable + score = default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable assert float(score) == score_ @@ 
-153,7 +153,7 @@ def test_continuous_score_consistency(): prime_score, _ = sandbox_fusion.compute_score(os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True) # 2. Calculate score using sandbox_fusion with continuous=True - # Ensure the extra_info key triggers the sandbox_fusion path in _default_compute_score + # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True) # 3. Assert scores are equal (using pytest.approx for float comparison) @@ -175,5 +175,5 @@ def test_check_correctness(): def test_prime_math(): data_source = "numina_aops_forum" for completion, ground_truth in zip(prime_math_answers, prime_math_gts): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index 23d4b1e70fa..7f6910ef35f 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -19,7 +19,7 @@ import ray from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score def get_custom_reward_fn(config): @@ -87,9 +87,9 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): if sandbox_url: sandbox_manager = multiprocessing.Manager() _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) - final_compute_score = partial(_default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) + final_compute_score = partial(default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) else: - final_compute_score = _default_compute_score + final_compute_score = default_compute_score return reward_manager_cls( tokenizer=tokenizer, diff --git 
a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 36b293668a5..c55bc0140e4 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -14,7 +14,7 @@ # from . import gsm8k, math, prime_math, prime_code -def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): +def default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): """Compute the score for a given solution based on the data source. Args: From a8b520daa67b2afb4f5e86031c9d1d898f5cd705 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 21:45:26 -0700 Subject: [PATCH 25/31] fix Signed-off-by: Hongpeng Guo --- .github/workflows/utils_cpu_test.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/utils_cpu_test.yml b/.github/workflows/utils_cpu_test.yml index 058545dc679..02319847c06 100644 --- a/.github/workflows/utils_cpu_test.yml +++ b/.github/workflows/utils_cpu_test.yml @@ -41,12 +41,7 @@ jobs: - name: Install the current repository run: | pip install -e .[test] -<<<<<<< HEAD:.github/workflows/verl_unit_test.yml - pip install megatron-core - name: Running test protocol.py -======= - - name: Running test protocol.py ->>>>>>> hpguo/doc:.github/workflows/utils_cpu_test.yml run: | cd tests pytest -s -x test_protocol.py From 2f2eb7f282ef6cc070d0f25bdf0d4f6e1f3a90a6 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 21:46:17 -0700 Subject: [PATCH 26/31] fix Signed-off-by: Hongpeng Guo --- .github/workflows/utils_cpu_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/utils_cpu_test.yml b/.github/workflows/utils_cpu_test.yml index 02319847c06..e3ec220d078 100644 --- a/.github/workflows/utils_cpu_test.yml +++ b/.github/workflows/utils_cpu_test.yml @@ -41,7 +41,7 @@ jobs: - name: Install the current 
repository run: | pip install -e .[test] - - name: Running test protocol.py + - name: Running test protocol.py run: | cd tests pytest -s -x test_protocol.py From 2cfd60ae2a9aa7fe596da11caeda9e9e24dd8cbe Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 22:16:18 -0700 Subject: [PATCH 27/31] add some doc strings Signed-off-by: Hongpeng Guo --- docs/api/trainer.rst | 32 ++++++++++++++++---------------- verl/trainer/ppo/ray_trainer.py | 9 +++++++-- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst index 39ad435d3a0..cd308c44d09 100644 --- a/docs/api/trainer.rst +++ b/docs/api/trainer.rst @@ -1,28 +1,28 @@ -PPO Trainer Interface +Trainer Interface ================================ -This section documents the trainer utilities in the VERL library. +Trainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged. -PPO Ray Trainer ----------------------------- +.. autosummary:: + :nosignatures: -.. automodule:: verl.trainer.ppo.ray_trainer - :members: + verl.trainer.ppo.ray_trainer.RayPPOTrainer -PPO Metrics --------------------------- -.. automodule:: verl.trainer.ppo.metric_utils - :members: +Core APIs +~~~~~~~~~~~~~~~~~ + +.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer + :members: __init__, init_workers, fit + + +.. automodule:: verl.utils.tokenizer + :members: hf_tokenizer -PPO Core Algorithms ------------------------------ .. automodule:: verl.trainer.ppo.core_algos - :members: + :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty -PPO Reward ---------------------------- .. 
automodule:: verl.trainer.ppo.reward - :members: + :members: load_reward_manager, compute_reward, compute_reward_async diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1d021dee25a..0ec42672fe7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -341,7 +341,7 @@ def __init__( collate_fn=None, train_sampler: Optional[Sampler] = None, ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' + """Initialize distributed PPO trainer with Ray backend.""" self.tokenizer = tokenizer self.processor = processor @@ -724,7 +724,12 @@ def _validate(self): return metric_dict def init_workers(self): - """Init resource pool and worker group""" + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ self.resource_pool_manager.create_resource_pool() self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} From a1126f52009eb4ad6888ad41feef37ec88630f8d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 19 May 2025 22:30:21 -0700 Subject: [PATCH 28/31] select only used utils apis in the doc Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 62 +++++++++------------------------------------- 1 file changed, 12 insertions(+), 50 deletions(-) diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 98823866eb6..58f62d8a30a 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -7,19 +7,19 @@ Python Functional Utilities ------------------------------ .. automodule:: verl.utils.py_functional - :members: + :members: append_to_dict File System Utilities ------------------------ .. automodule:: verl.utils.fs - :members: + :members: copy_to_local Tracking Utilities --------------------- .. 
automodule:: verl.utils.tracking - :members: + :members: Tracking, ValidationGenerationsLogger Metrics Utilities --------------------- @@ -31,88 +31,50 @@ Checkpoint Management ------------------------ .. automodule:: verl.utils.checkpoint.checkpoint_manager - :members: + :members: fins_latest_ckpt_path .. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager - :members: + :members: FSDPCheckpointManager Dataset Utilities --------------------- .. automodule:: verl.utils.dataset.rl_dataset - :members: - -.. automodule:: verl.utils.dataset.sft_dataset - :members: + :members: RLHFDataset, collate_fn Torch Functional Utilities ----------------------------- .. automodule:: verl.utils.torch_functional - :members: + :members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits Sequence Length Balancing ---------------------------- .. automodule:: verl.utils.seqlen_balancing - :members: + :members: get_reverse_idx, rearrage_micro_batches Ulysses Utilities -------------------- .. automodule:: verl.utils.ulysses - :members: + :members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs Model Utilities ------------------ .. automodule:: verl.utils.model - :members: + :members: print_model_size FSDP Utilities ------------------ .. automodule:: verl.utils.fsdp_utils - :members: - -Reward Score Utilities -------------------------- - -.. automodule:: verl.utils.reward_score - :members: + :members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, Debug Utilities ------------------- .. automodule:: verl.utils.debug - :members: - -Ray Utilities ------------------ - -.. automodule:: verl.utils.ray_utils - :members: - -Model Utilities -------------------- - -.. automodule:: verl.utils.model - :members: - -Sequence Length Balancing ------------------------------ - -.. 
automodule:: verl.utils.seqlen_balancing - :members: - -Megatron Utilities ----------------------- - -.. automodule:: verl.utils.megatron.pipeline_parallel - :members: - -.. automodule:: verl.utils.megatron.sequence_parallel - :members: + :members: log_gpu_memory_usage, GPUMemoryLogger -.. automodule:: verl.utils.megatron.tensor_parallel - :members: From 76224fa0da0f8661887b28df24da955c863420b9 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 22 May 2025 13:28:58 -0700 Subject: [PATCH 29/31] small fix Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 58f62d8a30a..caeae4f5c21 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -60,12 +60,6 @@ Ulysses Utilities .. automodule:: verl.utils.ulysses :members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs -Model Utilities ------------------- - -.. automodule:: verl.utils.model - :members: print_model_size - FSDP Utilities ------------------ From 24dcc64a374d67e09077e3f744aae89eaf5db2a8 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 22 May 2025 13:29:20 -0700 Subject: [PATCH 30/31] small fix Signed-off-by: Hongpeng Guo --- verl/utils/debug/performance.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index f461b3f47a8..783e2b9754e 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -53,17 +53,12 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging class GPUMemoryLogger(DecoratorLoggerBase): """A decorator class to log GPU memory usage. 
- Usage: - For example, in actor function, we initialize a GPUMemoryLogger - - ``` - from verl.utils.debug.performance import GPUMemoryLogger - @GPUMemoryLogger(role="actor") - def update_actor(self, batch): - # do something - return - ``` - + Example: + >>> from verl.utils.debug.performance import GPUMemoryLogger + >>> @GPUMemoryLogger(role="actor") + ... def update_actor(self, batch): + ... # real actor update logic + ... return """ def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): From 5a98b610899b99b4b27fdedb088214c02a7e520c Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 26 May 2025 19:45:20 -0700 Subject: [PATCH 31/31] make doc strings better, use google style doc string Signed-off-by: Hongpeng Guo --- docs/api/utils.rst | 6 +- docs/conf.py | 4 ++ verl/utils/checkpoint/checkpoint_manager.py | 15 ++++- .../checkpoint/fsdp_checkpoint_manager.py | 56 +++++++++++++++---- verl/utils/dataset/rl_dataset.py | 30 ++++++++-- verl/utils/seqlen_balancing.py | 9 +++ verl/utils/torch_functional.py | 51 ++++++++++++++++- verl/utils/ulysses.py | 15 +++++ 8 files changed, 164 insertions(+), 22 deletions(-) diff --git a/docs/api/utils.rst b/docs/api/utils.rst index caeae4f5c21..3ac4380b039 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -19,7 +19,7 @@ Tracking Utilities --------------------- .. automodule:: verl.utils.tracking - :members: Tracking, ValidationGenerationsLogger + :members: Tracking Metrics Utilities --------------------- @@ -31,7 +31,7 @@ Checkpoint Management ------------------------ .. automodule:: verl.utils.checkpoint.checkpoint_manager - :members: fins_latest_ckpt_path + :members: find_latest_ckpt_path .. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager :members: FSDPCheckpointManager @@ -52,7 +52,7 @@ Sequence Length Balancing ---------------------------- ..
automodule:: verl.utils.seqlen_balancing - :members: get_reverse_idx, rearrage_micro_batches + :members: get_reverse_idx, rearrange_micro_batches Ulysses Utilities -------------------- diff --git a/docs/conf.py b/docs/conf.py index fe8cf2a5dbf..829a5ed8e71 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,7 +48,11 @@ "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.autosectionlabel", + "sphinx.ext.napoleon", ] +# Use Google style docstrings instead of NumPy docstrings. +napoleon_google_docstring = True +napoleon_numpy_docstring = False # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index c9ac414f370..076a319bbca 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -22,6 +22,7 @@ import torch.distributed from filelock import FileLock from transformers import PreTrainedTokenizer, ProcessorMixin + from verl.utils.device import is_cuda_available, is_npu_available @@ -124,7 +125,7 @@ def load_rng_state(rng_state): torch.set_rng_state(rng_state["cpu"]) np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) - + if is_cuda_available: torch.cuda.set_rng_state(rng_state["cuda"]) elif is_npu_available: @@ -132,6 +133,18 @@ def load_rng_state(rng_state): def find_latest_ckpt_path(path, directory_format="global_step_{}"): + """ + Return the most recent checkpoint directory based on a tracker file. + + Args: + path (str): Base directory containing the checkpoint tracker. + directory_format (str): Template for checkpoint subfolders with one + placeholder for the iteration number (default "global_step_{}"). + + Returns: + str or None: Full path to the latest checkpoint directory, or + None if the tracker or checkpoint folder is missing. 
+ """ if path is None: return None diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index b556f298412..f5980129e91 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -23,8 +23,8 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin -from verl.utils.fs import copy_to_local, is_non_local from verl.utils.device import is_cuda_available +from verl.utils.fs import copy_to_local, is_non_local from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx from .checkpoint_manager import BaseCheckpointManager @@ -32,17 +32,20 @@ class FSDPCheckpointManager(BaseCheckpointManager): """ - A checkpoint manager that saves and loads - - model - - optimizer - - lr_scheduler - - extra_states - in a SPMD way. - - We save - - sharded model states and optimizer states - - full lr_scheduler states - - huggingface tokenizer/processor and config for ckpt merge + Manage FSDP checkpointing in SPMD training. + + - Saves/loads per-rank sharded model & optimizer states + - Persists full lr_scheduler and RNG state + - Stores HF tokenizer/processor and model/config for unified restore + + Args: + model (FSDP): Wrapped model instance. + optimizer (Optimizer): Training optimizer. + lr_scheduler (LRScheduler): Learning-rate scheduler. + processing_class (PreTrainedTokenizer or ProcessorMixin, optional): + Pre-/post-processing artifact handler. + checkpoint_contents (list[str], optional): + Components to include; must contain 'model', 'optimizer', 'extra'. """ def __init__( @@ -71,6 +74,18 @@ def __init__( ) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + """ + Load an FSDP checkpoint for this rank. 
+ + Downloads and loads: + - model and optimizer shards + - extra state dict (scheduler + RNG) + + Args: + local_path: Directory with per-rank checkpoint files. + hdfs_path: Unused (for API compatibility). + del_local_after_load: Remove local files after loading. + """ if local_path is None: return @@ -112,6 +127,23 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + """ + Save an FSDP checkpoint for this rank. + + Writes: + - model & optimizer shard files + - extra state dict (scheduler + RNG) + - HF tokenizer/processor and model/config on rank 0 + - optional full HF model under 'huggingface/' if requested + + Rotates old checkpoints, keeping at most `max_ckpt_to_keep`. + + Args: + local_path: Target directory for checkpoint files. + hdfs_path: Unused (for API compatibility). + global_step: Current training step (used for bookkeeping). + max_ckpt_to_keep: Number of recent checkpoints to retain. + """ if local_path is None: return diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index a7cd183945b..e952af5e057 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -35,7 +35,17 @@ def collate_fn(data_list: list[dict]) -> dict: - """Collate a batch of data.""" + """ + Collate a batch of sample dicts into batched tensors and arrays. + + Args: + data_list: List of dicts mapping feature names to torch.Tensor or other values. + + Returns: + Dict where tensor entries are stacked into a torch.Tensor of shape + (batch_size, *dims) and non-tensor entries are converted to + np.ndarray of dtype object with shape (batch_size,). 
+ """ tensors = defaultdict(list) non_tensors = defaultdict(list) @@ -57,7 +67,19 @@ def collate_fn(data_list: list[dict]) -> dict: class RLHFDataset(Dataset): """ - We assume the dataset contains a column that contains prompts and other information + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. """ def __init__( @@ -247,10 +269,10 @@ def __getitem__(self, item): # encode prompts without chat template if self.return_raw_chat: row_dict["raw_prompt"] = messages - + # get prompts with chat template if self.return_full_prompt: - row_dict["full_prompts"] = raw_prompt # array of strings + row_dict["full_prompts"] = raw_prompt # array of strings # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0) diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 5063cc385af..4da331858cc 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -271,6 +271,15 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_ def get_reverse_idx(idx_map): + """ + Build the inverse of an index mapping. + + Args: + idx_map (Sequence[int]): Sequence where idx_map[i] = j. + + Returns: + List[int]: Inverse mapping list such that output[j] = i for each i. 
+ """ reverse_idx_map = copy.deepcopy(idx_map) for i, idx in enumerate(idx_map): diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 9754d989344..e728758d49b 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -53,7 +53,20 @@ def gather_from_labels(data, label): def logprobs_from_logits(logits, labels, inplace_backward=True): """ + Compute per-token log-probabilities for the given labels. + + Uses a Flash-Attention–based cross-entropy (if available) for efficient backward, + otherwise falls back to a standard log-softmax+gather approach. + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + + Args: + logits (Tensor): Model outputs of shape (..., vocab_size). + labels (LongTensor): True class indices of shape matching logits.shape[:-1]. + inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place. + + Returns: + Tensor: Log-probabilities of the target labels, shape logits.shape[:-1]. """ if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: batch_dim = logits.shape[:-1] @@ -121,7 +134,18 @@ def masked_sum(values, mask, axis=None): def masked_mean(values, mask, axis=None): - """Compute mean of tensor with a masked values.""" + """ + Compute the mean of `values` over elements selected by `mask`. + + Args: + values (Tensor): Input tensor. + mask (Tensor): Boolean or numeric mask of the same shape as `values`. + axis (int or tuple of int, optional): Dimension(s) along which to compute the mean. + Defaults to None (over all elements). + + Returns: + Tensor: Masked mean, with shape equal to `values` reduced over `axis`. + """ return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8) @@ -144,7 +168,18 @@ def masked_var(values, mask, unbiased=True): def masked_whiten(values, mask, shift_mean=True): - """Whiten values with masked values.""" + """ + Whiten `values` by normalizing with mean and variance computed over `mask`.
+ + Args: + values (torch.Tensor): Input tensor. + mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. + shift_mean (bool): If True (default), output is zero-mean; + if False, the original mean is re-added after scaling. + + Returns: + torch.Tensor: Whitened tensor of same shape as `values`. + """ mean, var = masked_mean(values, mask), masked_var(values, mask) whitened = (values - mean) * torch.rsqrt(var + 1e-8) if not shift_mean: @@ -472,6 +507,18 @@ def get_constant_schedule_with_warmup( num_warmup_steps: int, last_epoch: int = -1, ): + """ + Create a constant LR schedule with a linear warmup phase. + + Args: + optimizer (Optimizer): Wrapped optimizer. + num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value. + last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1. + + Returns: + LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant. + """ + def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py index bf587081d32..a33293364f1 100644 --- a/verl/utils/ulysses.py +++ b/verl/utils/ulysses.py @@ -242,6 +242,21 @@ def gather_outpus_and_unpad( grad_scaler: bool = True, group: Optional[dist.ProcessGroup] = None, ): + """ + Gather a tensor across a process group and optionally unpad its padded elements. + + Args: + x (Tensor): Input tensor to gather. + gather_dim (int): Dimension along which to gather across ranks. + unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding. + padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0. + grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. + group (ProcessGroup, optional): Process group for gathering. If None, uses + `get_ulysses_sequence_parallel_group()`. 
If still None, returns `x` unchanged. + + Returns: + Tensor: The gathered tensor, with padding removed if requested. + """ group = get_ulysses_sequence_parallel_group() if group is None else group if group is None: return x