
Commit 34e409b

[docs] refactor: Adding doc strings and doc pages for public methods in trainer and utils (#1397)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

* This PR adds docstrings for the public methods in the `trainer` and `utils` modules, so that these methods can be reused and referenced more easily.
* Two new doc pages, `PPO Trainer Interface` and `Utilities`, are added under the API Reference section.
* Renames the function `verl.utils._default_compute_score` to `verl.utils.default_compute_score`, because it is used externally by other modules, i.e., trainer and recipe.

### TODO

This is the second in a series of PRs to improve and stabilize the docs and API. It is stacked on top of #1396. Remaining work includes adding more useful utility functions to the docs with improved docstrings.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Signed-off-by: Hongpeng Guo <[email protected]>
Co-authored-by: H <[email protected]>
1 parent 4d3ca21 commit 34e409b

28 files changed, +580 -98 lines changed

docs/api/single_controller.rst

Lines changed: 4 additions & 2 deletions
@@ -22,5 +22,7 @@ Core APIs
2222
.. autoclass:: verl.single_controller.ResourcePool
2323
:members: __init__, world_size, local_world_size_list, local_rank_list
2424

25-
.. automodule:: verl.single_controller.ray
26-
:members: RayWorkerGroup, create_colocated_worker_cls
25+
.. autoclass:: verl.single_controller.ray.RayWorkerGroup
26+
:members: __init__
27+
28+
.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls

docs/api/trainer.rst

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Trainers
2-
=========================
1+
Trainer Interface
2+
================================
33

44
Trainers drive the training loop. Introducing new trainer classes for new training paradigms is encouraged.
55

@@ -13,9 +13,16 @@ Core APIs
1313
~~~~~~~~~~~~~~~~~
1414

1515
.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer
16+
:members: __init__, init_workers, fit
17+
1618

1719
.. automodule:: verl.utils.tokenizer
1820
:members: hf_tokenizer
1921

20-
.. automodule:: verl.single_controller
21-
:members: Worker, WorkerGroup, ClassWithInitArgs, ResourcePool
22+
23+
.. automodule:: verl.trainer.ppo.core_algos
24+
:members: agg_loss, kl_penalty, compute_policy_loss
25+
26+
27+
.. automodule:: verl.trainer.ppo.reward
28+
:members: load_reward_manager, compute_reward, compute_reward_async

docs/api/utils.rst

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,74 @@
1-
Training utils
2-
=========================
1+
Utilities
2+
============
33

4-
Core APIs
5-
~~~~~~~~~~~~~~~~~
4+
This section documents the utility functions and classes in the VERL library.
5+
6+
Python Functional Utilities
7+
------------------------------
8+
9+
.. automodule:: verl.utils.py_functional
10+
:members: append_to_dict
11+
12+
File System Utilities
13+
------------------------
14+
15+
.. automodule:: verl.utils.fs
16+
:members: copy_to_local
17+
18+
Tracking Utilities
19+
---------------------
20+
21+
.. automodule:: verl.utils.tracking
22+
:members: Tracking
23+
24+
Metrics Utilities
25+
---------------------
626

727
.. automodule:: verl.utils.metric
828
:members: reduce_metrics
29+
30+
Checkpoint Management
31+
------------------------
32+
33+
.. automodule:: verl.utils.checkpoint.checkpoint_manager
34+
:members: find_latest_ckpt_path
35+
36+
.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager
37+
:members: FSDPCheckpointManager
38+
39+
Dataset Utilities
40+
---------------------
41+
42+
.. automodule:: verl.utils.dataset.rl_dataset
43+
:members: RLHFDataset, collate_fn
44+
45+
Torch Functional Utilities
46+
-----------------------------
47+
48+
.. automodule:: verl.utils.torch_functional
49+
:members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits
50+
51+
Sequence Length Balancing
52+
----------------------------
53+
54+
.. automodule:: verl.utils.seqlen_balancing
55+
:members: get_reverse_idx, rearrange_micro_batches
56+
57+
Ulysses Utilities
58+
--------------------
59+
60+
.. automodule:: verl.utils.ulysses
61+
:members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
62+
63+
FSDP Utilities
64+
------------------
65+
66+
.. automodule:: verl.utils.fsdp_utils
67+
:members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer
68+
69+
Debug Utilities
70+
-------------------
71+
72+
.. automodule:: verl.utils.debug
73+
:members: log_gpu_memory_usage, GPUMemoryLogger
74+

docs/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@
4848
"sphinx.ext.autodoc",
4949
"sphinx.ext.autosummary",
5050
"sphinx.ext.autosectionlabel",
51+
"sphinx.ext.napoleon",
5152
]
53+
# Use Google style docstrings instead of NumPy docstrings.
54+
napoleon_google_docstring = True
55+
napoleon_numpy_docstring = False
5256

5357
# The suffix(es) of source filenames.
5458
# You can specify multiple suffix as a list of string:
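
With `sphinx.ext.napoleon` enabled and `napoleon_google_docstring = True`, autodoc can render docstrings that use Google-style `Args:` and `Returns:` sections. The block below is a minimal sketch of that layout; `scale_rewards` is a hypothetical helper written only for illustration and is not part of verl.

```python
def scale_rewards(rewards, factor=1.0):
    """Scale every reward by a constant factor.

    Hypothetical example function, used only to show the docstring layout.

    Args:
        rewards (list[float]): Per-sample rewards.
        factor (float): Multiplier applied to each reward.

    Returns:
        list[float]: The scaled rewards, in the original order.
    """
    return [r * factor for r in rewards]
```

Napoleon converts the `Args:` and `Returns:` sections into reStructuredText field lists before autodoc renders them, so the section names and indentation need to follow the Google convention.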

docs/index.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ verl is fast with:
108108
:caption: API References
109109

110110
api/data
111-
api/utils
112111
api/single_controller.rst
112+
api/trainer.rst
113+
api/utils.rst
113114

114115

115116
.. toctree::

tests/ray_cpu/test_ray_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import ray
17+
18+
from verl.utils.ray_utils import parallel_put
19+
20+
21+
# Initialize Ray for testing if not already done globally
22+
@pytest.fixture()
23+
def init_ray():
24+
ray.init(num_cpus=4)
25+
yield
26+
ray.shutdown()
27+
28+
29+
def test_parallel_put_basic(init_ray):
30+
data = [1, "hello", {"a": 2}, [3, 4]]
31+
refs = parallel_put(data)
32+
assert len(refs) == len(data)
33+
retrieved_data = [ray.get(ref) for ref in refs]
34+
assert retrieved_data == data
35+
36+
37+
def test_parallel_put_empty(init_ray):
38+
data = []
39+
refs = parallel_put(data)
40+
assert len(refs) == 0
41+
42+
43+
def test_parallel_put_workers(init_ray):
44+
data = list(range(20))
45+
# Test with specific number of workers
46+
refs = parallel_put(data, max_workers=4)
47+
assert len(refs) == len(data)
48+
retrieved_data = [ray.get(ref) for ref in refs]
49+
assert retrieved_data == data
50+
# Test with default workers (should cap)
51+
refs_default = parallel_put(data)
52+
assert len(refs_default) == len(data)
53+
retrieved_data_default = [ray.get(ref) for ref in refs_default]
54+
assert retrieved_data_default == data

tests/sandbox/test_sandbox.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import pytest
2020

21-
from verl.utils.reward_score import _default_compute_score, prime_code, sandbox_fusion
21+
from verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion
2222
from verl.utils.reward_score.prime_code import apps_check_correctness
2323
from verl.workers.reward_manager.prime import parallel_compute_score_async
2424

@@ -109,7 +109,7 @@ def test_parallelism():
109109
ground_truth.extend(prime_math_gts)
110110
data_sources.extend(["numina_aops_forum"] * len(prime_math_answers))
111111

112-
scores = asyncio.run(parallel_compute_score_async(_default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16))
112+
scores = asyncio.run(parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16))
113113
print(scores)
114114

115115

@@ -119,7 +119,7 @@ def test_prime_code():
119119
"""
120120
data_source = "codecontests"
121121
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
122-
score = _default_compute_score(data_source, completion, ground_truth)
122+
score = default_compute_score(data_source, completion, ground_truth)
123123
assert float(score) == score_
124124

125125

@@ -135,7 +135,7 @@ def test_prime_code_sandbox_fusion():
135135
# Removed the previous 'if not sandbox_url' check block
136136

137137
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
138-
score = _default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable
138+
score = default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable
139139
assert float(score) == score_
140140

141141

@@ -153,7 +153,7 @@ def test_continuous_score_consistency():
153153
prime_score, _ = sandbox_fusion.compute_score(os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True)
154154

155155
# 2. Calculate score using sandbox_fusion with continuous=True
156-
# Ensure the extra_info key triggers the sandbox_fusion path in _default_compute_score
156+
# Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score
157157
fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True)
158158

159159
# 3. Assert scores are equal (using pytest.approx for float comparison)
@@ -175,5 +175,5 @@ def test_check_correctness():
175175
def test_prime_math():
176176
data_source = "numina_aops_forum"
177177
for completion, ground_truth in zip(prime_math_answers, prime_math_gts):
178-
score = _default_compute_score(data_source, completion, ground_truth)
178+
score = default_compute_score(data_source, completion, ground_truth)
179179
assert float(score) == 1.0
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from verl.utils.megatron.pipeline_parallel import make_batch_generator
16+
17+
18+
def test_make_batch_generator_no_vpp():
19+
batches = [1, 2, 3]
20+
vpp_size = 1
21+
generator = make_batch_generator(batches, vpp_size)
22+
assert list(generator) == batches
23+
24+
25+
def test_make_batch_generator_with_vpp():
26+
batches = [{"data": 1}, {"data": 2}]
27+
vpp_size = 2
28+
generators = make_batch_generator(batches, vpp_size)
29+
assert isinstance(generators, list)
30+
assert len(generators) == vpp_size
31+
32+
# Check each generator yields the original batches
33+
for gen in generators:
34+
assert list(gen) == batches
35+
36+
37+
def test_make_batch_generator_empty():
38+
batches = []
39+
vpp_size = 1
40+
generator = make_batch_generator(batches, vpp_size)
41+
assert list(generator) == []
42+
43+
vpp_size = 3
44+
generators = make_batch_generator(batches, vpp_size)
45+
assert len(generators) == vpp_size
46+
for gen in generators:
47+
assert list(gen) == []
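
The tests above pin down the observable behavior of `make_batch_generator`: with `vpp_size == 1` the result is iterated directly, and with a larger `vpp_size` it is a list of that many iterators, each yielding every batch. A minimal sketch that would satisfy those assertions is shown below. It is inferred from the tests only; `make_batch_generator_sketch` is a hypothetical stand-in, and the real function in `verl.utils.megatron.pipeline_parallel` may differ.

```python
def make_batch_generator_sketch(batches, vpp_size):
    """Hypothetical stand-in inferred from the tests, not verl's code."""
    if vpp_size > 1:
        # One independent iterator per virtual pipeline stage, so each
        # stage can consume the full sequence of batches on its own.
        return [iter(batches) for _ in range(vpp_size)]
    # Without virtual pipeline parallelism, a single iterator suffices.
    return iter(batches)


batches = [{"data": 1}, {"data": 2}]
generators = make_batch_generator_sketch(batches, vpp_size=2)
for gen in generators:
    print(list(gen))
```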

verl/trainer/ppo/core_algos.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ def compute_gae_advantage_return(
7575
7676
Args:
7777
token_level_rewards: `(torch.Tensor)`
78-
shape: (bs, response_length)
78+
shape is (bs, response_length)
7979
values: `(torch.Tensor)`
80-
shape: (bs, response_length)
80+
shape is (bs, response_length)
8181
response_mask: `(torch.Tensor)`
82-
shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
83-
gamma: `(float)`
82+
shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
83+
gamma: `(float)`
8484
discounted factor used in RL
8585
lam: `(float)`
8686
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
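
To make the roles of `gamma` and `lam` concrete, here is a minimal pure-Python sketch of the Generalized Advantage Estimation recursion for a single response. It is an illustration under simplifying assumptions, not the code in `core_algos.py`: it handles one sequence rather than a `(bs, response_length)` tensor, and it ignores `response_mask` and any advantage normalization. The name `gae_single_sequence` is hypothetical.

```python
def gae_single_sequence(rewards, values, gamma, lam):
    """Sketch of GAE for one sequence; returns (advantages, returns)."""
    advantages = [0.0] * len(rewards)
    next_value = 0.0  # value after the final token is taken to be zero
    running = 0.0     # discounted sum of TD errors accumulated so far
    for t in reversed(range(len(rewards))):
        # One-step temporal-difference error at position t.
        delta = rewards[t] + gamma * next_value - values[t]
        # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}.
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    # Returns are advantages plus the value baseline.
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


adv, ret = gae_single_sequence([0.0, 0.0, 1.0], [0.0, 0.0, 0.0], gamma=0.5, lam=1.0)
print(adv, ret)
```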
@@ -122,19 +122,19 @@ def compute_grpo_outcome_advantage(
122122
(with only one scalar reward for each response).
123123
Args:
124124
token_level_rewards: `(torch.Tensor)`
125-
shape: (bs, response_length)
125+
shape is (bs, response_length)
126126
response_mask: `(torch.Tensor)`
127-
shape: (bs, response_length)
127+
shape is (bs, response_length)
128128
norm_adv_by_std_in_grpo: (bool)
129129
whether to scale the GRPO advantage.
130130
If True, the advantage is scaled by the std, as in the original GRPO.
131131
If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
132132
133133
Returns:
134134
advantages: `(torch.Tensor)`
135-
shape: (bs, response_length)
135+
shape is (bs, response_length)
136136
Returns: `(torch.Tensor)`
137-
shape: (bs, response_length)
137+
shape is (bs, response_length)
138138
"""
139139
scores = token_level_rewards.sum(dim=-1)
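
The effect of `norm_adv_by_std_in_grpo` can be sketched on scalar scores, one per response. The block below is a hedged illustration, not verl's implementation: `grpo_outcome_advantage_sketch`, the `group_ids` argument (which responses share a prompt), the `eps` constant, and the use of the population standard deviation are all assumptions made for the example. In the real function the scalar is then spread over the response tokens.

```python
from collections import defaultdict
from statistics import mean, pstdev


def grpo_outcome_advantage_sketch(scores, group_ids, norm_by_std=True, eps=1e-6):
    """Center each score on its group mean; optionally divide by the group std."""
    groups = defaultdict(list)
    for score, gid in zip(scores, group_ids):
        groups[gid].append(score)

    advantages = []
    for score, gid in zip(scores, group_ids):
        centered = score - mean(groups[gid])
        if norm_by_std:
            # Scaling by the std, as described for the original GRPO.
            centered /= pstdev(groups[gid]) + eps
        # With norm_by_std=False the centered score is used as-is,
        # which is the Dr.GRPO-style variant mentioned in the docstring.
        advantages.append(centered)
    return advantages


print(grpo_outcome_advantage_sketch([1.0, 0.0, 1.0, 1.0], ["a", "a", "b", "b"], norm_by_std=False))
```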
140140

@@ -371,15 +371,12 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
371371
"""
372372
Aggregate the loss matrix into a scalar.
373373
Args:
374-
loss_mat: `(torch.Tensor)`
374+
loss_mat: `(torch.Tensor)`:
375375
shape: (bs, response_length)
376-
loss_mask: `(torch.Tensor)`
376+
loss_mask: `(torch.Tensor)`:
377377
shape: (bs, response_length)
378-
loss_agg_mode: (str) choices: "token-mean" /
379-
"seq-mean-token-sum" /
380-
"seq-mean-token-mean" /
381-
"seq-mean-token-sum-norm" /
382-
"token-mean" is the default behavior
378+
loss_agg_mode: (str)
379+
method to aggregate the loss matrix into a scalar.
383380
Returns:
384381
loss: `a scalar torch.Tensor`
385382
aggregated loss
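
The aggregation modes named in the earlier version of this docstring differ in where the averaging happens. The pure-Python sketch below illustrates three of them on nested lists; it is an interpretation based on the mode names, not the torch implementation in `core_algos.py`, and `agg_loss_sketch` is a hypothetical name. The fourth mode, "seq-mean-token-sum-norm", is omitted because its normalization is not described in the text.

```python
def agg_loss_sketch(loss_mat, loss_mask, loss_agg_mode="token-mean"):
    """Reduce a (bs, response_length) loss matrix to a scalar (sketch only)."""
    # Zero out positions that the mask excludes.
    masked = [[l * m for l, m in zip(row, mrow)]
              for row, mrow in zip(loss_mat, loss_mask)]

    if loss_agg_mode == "token-mean":
        # Average over every unmasked token in the batch.
        total_tokens = sum(sum(mrow) for mrow in loss_mask)
        return sum(sum(row) for row in masked) / total_tokens

    seq_sums = [sum(row) for row in masked]
    if loss_agg_mode == "seq-mean-token-sum":
        # Sum tokens within each sequence, then average across sequences.
        return sum(seq_sums) / len(seq_sums)
    if loss_agg_mode == "seq-mean-token-mean":
        # Average tokens within each sequence, then average across sequences.
        seq_means = [s / sum(mrow) for s, mrow in zip(seq_sums, loss_mask)]
        return sum(seq_means) / len(seq_means)
    raise ValueError(f"unsupported loss_agg_mode: {loss_agg_mode}")


loss_mat = [[1.0, 3.0], [2.0, 100.0]]
loss_mask = [[1, 1], [1, 0]]
for mode in ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean"):
    print(mode, agg_loss_sketch(loss_mat, loss_mask, mode))
```

Because "token-mean" weights every token equally, longer responses contribute more to the loss than shorter ones; the two sequence-level modes give each response equal weight instead.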
