Skip to content

Commit de7f3f8

Browse files
authored
[misc] sync feat from upstream (#113)
1 parent 4a640d9 commit de7f3f8

15 files changed

Lines changed: 394 additions & 630 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ indent-width = 4
1414

1515
[tool.ruff.lint]
1616
ignore = ["C901", "E501", "E741", "W605", "C408"]
17-
select = ["C", "E", "F", "I", "W"]
17+
select = ["C", "E", "F", "I", "W", "RUF022"]
1818

1919
[tool.ruff.lint.per-file-ignores]
2020
"__init__.py" = ["E402", "F401", "F403", "F811"]

verl/protocol.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
import torch
2929
from numpy.typing import NDArray
3030
from tensordict import TensorDict
31+
from torch.distributed import ProcessGroup
3132
from torch.utils.data import DataLoader
3233

3334
from .utils.py_functional import union_two_dict
35+
from .utils.torch_functional import allgather_dict_tensors
3436

3537

3638
try:
@@ -620,3 +622,15 @@ def get(self):
620622
if self.dispatch_fn is not None:
621623
output = self.dispatch_fn(output) # split in batch dim, select using dp
622624
return output
625+
626+
627+
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
    """Gather a ``DataProto`` from every rank of ``group`` into ``data`` itself.

    Like ``torch.distributed.all_gather`` this works in place: nothing is
    returned, and both ``data.batch`` and ``data.non_tensor_batch`` are
    overwritten with the gathered result.

    Args:
        data: The local shard; mutated in place.
        size: Number of participants; used to size the object-gather buffer and
            forwarded to ``allgather_dict_tensors``.
        group: Process group the collectives run on.
    """
    # Remember where the tensors live so they can be moved back afterwards.
    original_device = data.batch.device

    # The tensor gather is run on the current CUDA device.
    on_gpu = data.batch.cuda(device=torch.cuda.current_device())
    # NOTE(review): allgather_dict_tensors presumably concatenates along dim 0 - confirm in torch_functional.
    gathered = allgather_dict_tensors(on_gpu.contiguous(), size=size, group=group, dim=0)
    data.batch = gathered.to(original_device)

    # Non-tensor payloads are exchanged as Python objects, one slot per rank.
    per_rank = [None] * size
    torch.distributed.all_gather_object(per_rank, data.non_tensor_batch, group=group)

    merged = {}
    for key in data.non_tensor_batch:
        merged[key] = np.concatenate([rank_batch[key] for rank_batch in per_rank])
    data.non_tensor_batch = merged

verl/single_controller/base/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
1717

1818

19-
__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"]
19+
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]

verl/single_controller/ray/base.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,26 @@ def func(*args, **kwargs):
4949
return func
5050

5151

52+
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
    """
    Return the placement groups ordered by the IP address of the node hosting them.

    Every bundle of a placement group is expected to live on the same node, so
    the node of the first bundle is taken as the node of the whole group.

    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires
    RANK to be consistent across nodes when resuming from a checkpoint.

    With this ordering, if there is only one resource pool and the set of nodes does not change, RANK should
    stay consistent across nodes over multiple ray jobs, even if the whole ray cluster is restarted.
    """
    # Map every known node id to the address of its node manager.
    ip_of_node = {}
    for node in ray.nodes():
        ip_of_node[node["NodeID"]] = node["NodeManagerAddress"]

    def _host_ip(pg):
        table = ray._private.state.state.placement_group_table(pg.id)
        # all bundles should be on the same node, so the first one is representative
        first_bundle_node = table["bundles_to_node_id"][0]
        return ip_of_node[first_bundle_node]

    return sorted(pgs, key=_host_ip)
70+
71+
5272
class RayResourcePool(ResourcePool):
5373
def __init__(
5474
self,
@@ -231,8 +251,8 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d
231251
num_gpus = 1 / resource_pool.max_collocate_count
232252

233253
rank = -1
234-
for pg_idx, local_world_size in enumerate(resource_pool.store):
235-
pg = pgs[pg_idx]
254+
local_world_size = resource_pool.store[0]
255+
for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
236256
assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
237257
for local_rank in range(local_world_size):
238258
rank += 1

verl/trainer/main.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,13 @@
2828
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
2929

3030

31-
def main():
    """Build the PPO configuration and run ``main_task`` on the ray cluster.

    Configuration sources are merged in this order, later ones overriding
    earlier ones: the ``PPOConfig`` defaults, the file named by the optional
    ``config=<path>`` CLI argument, and the remaining CLI arguments.

    The ``config`` argument may be omitted; the defaults plus CLI overrides are
    then used instead of failing while trying to load a missing path.
    """
    cli_args = OmegaConf.from_cli()

    # Sources are collected in increasing priority order.
    sources = [OmegaConf.structured(PPOConfig())]
    if "config" in cli_args:
        sources.append(OmegaConf.load(cli_args.config))
        # "config" only names the file; it is not a PPOConfig field.
        del cli_args.config
    sources.append(cli_args)

    ppo_config = OmegaConf.merge(*sources)
    ppo_config = OmegaConf.to_object(ppo_config)

    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})

    # Block until the remote training task has finished.
    ray.get(main_task.remote(ppo_config))
45-
46-
47-
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
31+
@ray.remote(num_cpus=1)
4832
def main_task(config: PPOConfig):
33+
# please make sure main_task is not scheduled on head
34+
# print config
4935
config.deep_post_init()
5036
print(json.dumps(config.to_dict(), indent=2))
37+
5138
# instantiate tokenizer
5239
tokenizer = get_tokenizer(
5340
config.worker.actor.model.model_path,
@@ -67,7 +54,6 @@ def main_task(config: PPOConfig):
6754
Role.Critic: ray.remote(FSDPWorker),
6855
Role.RefPolicy: ray.remote(FSDPWorker),
6956
}
70-
7157
global_pool_id = "global_pool"
7258
resource_pool_spec = {
7359
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
@@ -77,6 +63,7 @@ def main_task(config: PPOConfig):
7763
Role.Critic: global_pool_id,
7864
Role.RefPolicy: global_pool_id,
7965
}
66+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
8067

8168
reward_fn = CustomRewardManager(
8269
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
@@ -85,8 +72,6 @@ def main_task(config: PPOConfig):
8572
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
8673
)
8774

88-
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
89-
9075
trainer = RayPPOTrainer(
9176
config=config,
9277
tokenizer=tokenizer,
@@ -101,5 +86,21 @@ def main_task(config: PPOConfig):
10186
trainer.fit()
10287

10388

89+
def main():
    """Build the PPO configuration and run ``main_task`` on the ray cluster.

    Configuration sources are merged in this order, later ones overriding
    earlier ones: the ``PPOConfig`` defaults, the file named by the optional
    ``config=<path>`` CLI argument, and the remaining CLI arguments.

    The ``config`` argument may be omitted; the defaults plus CLI overrides are
    then used instead of failing while trying to load a missing path.
    """
    cli_args = OmegaConf.from_cli()

    # "config" only names the file to load; it is not a PPOConfig field, so it
    # is removed from the CLI arguments before they are merged.
    config_path = cli_args.pop("config", None)

    # Sources are collected in increasing priority order.
    sources = [OmegaConf.structured(PPOConfig())]
    if config_path is not None:
        sources.append(OmegaConf.load(config_path))
    sources.append(cli_args)

    ppo_config = OmegaConf.merge(*sources)
    ppo_config = OmegaConf.to_object(ppo_config)

    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})

    # Block until the remote training task has finished.
    ray.get(main_task.remote(ppo_config))
103+
104+
104105
if __name__ == "__main__":
105106
main()

verl/trainer/metrics.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, List
16+
17+
import numpy as np
18+
import torch
19+
20+
from ..protocol import DataProto
21+
22+
23+
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
24+
response_length = batch.batch["responses"].shape[-1]
25+
prompt_mask = batch.batch["attention_mask"][:, :-response_length]
26+
response_mask = batch.batch["attention_mask"][:, -response_length:]
27+
prompt_length = prompt_mask.sum(-1).float()
28+
response_length = response_mask.sum(-1).float() # (batch_size,)
29+
return dict(
30+
response_mask=response_mask,
31+
prompt_length=prompt_length,
32+
response_length=response_length,
33+
)
34+
35+
36+
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Collapse every metric's list of values into its arithmetic mean.

    Args:
        metrics: Mapping from metric name to the collected values.

    Returns:
        A new dict with the same keys, each mapped to ``np.mean`` of its values.
    """
    averaged = {}
    for name, samples in metrics.items():
        averaged[name] = np.mean(samples)
    return averaged
38+
39+
40+
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
    """Summarise scores, rewards, advantages, returns and lengths of a batch.

    Scores and rewards are summed over the last dimension before taking
    mean/max/min. Advantages, returns and (optionally) values are restricted
    to the positions selected by the response part of the attention mask.

    Args:
        batch: Batch whose ``batch`` mapping holds ``token_level_scores``,
            ``token_level_rewards``, ``advantages``, ``returns``, ``responses``
            and ``attention_mask`` (plus ``values`` when ``use_critic`` is set).
        use_critic: When true, value statistics and the value-function
            explained variance are added.

    Returns:
        A flat dict of Python scalars keyed by metric name.
    """
    data = batch.batch

    seq_score = data["token_level_scores"].sum(-1)
    seq_reward = data["token_level_rewards"].sum(-1)
    advantages = data["advantages"]
    returns = data["returns"]

    # The response occupies the trailing columns of the attention mask.
    max_response_length = data["responses"].size(-1)
    attention_mask = data["attention_mask"]
    response_mask = attention_mask[:, -max_response_length:].bool()
    max_prompt_length = attention_mask[:, :-max_response_length].size(-1)

    lengths = _compute_response_info(batch)
    prompt_length = lengths["prompt_length"]
    response_length = lengths["response_length"]

    # Only positions inside the response mask contribute to these statistics.
    valid_adv = advantages.masked_select(response_mask)
    valid_returns = returns.masked_select(response_mask)

    metrics = {
        # score
        "critic/score/mean": seq_score.mean().detach().item(),
        "critic/score/max": seq_score.max().detach().item(),
        "critic/score/min": seq_score.min().detach().item(),
        # reward
        "critic/rewards/mean": seq_reward.mean().detach().item(),
        "critic/rewards/max": seq_reward.max().detach().item(),
        "critic/rewards/min": seq_reward.min().detach().item(),
        # adv
        "critic/advantages/mean": valid_adv.mean().detach().item(),
        "critic/advantages/max": valid_adv.max().detach().item(),
        "critic/advantages/min": valid_adv.min().detach().item(),
        # returns
        "critic/returns/mean": valid_returns.mean().detach().item(),
        "critic/returns/max": valid_returns.max().detach().item(),
        "critic/returns/min": valid_returns.min().detach().item(),
    }

    if use_critic:
        valid_values = data["values"].masked_select(response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)
        # values
        metrics["critic/values/mean"] = valid_values.mean().detach().item()
        metrics["critic/values/max"] = valid_values.max().detach().item()
        metrics["critic/values/min"] = valid_values.min().detach().item()
        # vf explained var (the small constant keeps the denominator away from zero)
        metrics["critic/vf_explained_var"] = (1.0 - return_diff_var / (return_var + 1e-5)).detach().item()

    # clip_ratio is the fraction of rows whose length equals the maximum width.
    response_clipped = torch.eq(response_length, max_response_length).float()
    prompt_clipped = torch.eq(prompt_length, max_prompt_length).float()

    # response length
    metrics["response_length/mean"] = response_length.mean().detach().item()
    metrics["response_length/max"] = response_length.max().detach().item()
    metrics["response_length/min"] = response_length.min().detach().item()
    metrics["response_length/clip_ratio"] = response_clipped.mean().detach().item()
    # prompt length
    metrics["prompt_length/mean"] = prompt_length.mean().detach().item()
    metrics["prompt_length/max"] = prompt_length.max().detach().item()
    metrics["prompt_length/min"] = prompt_length.min().detach().item()
    metrics["prompt_length/clip_ratio"] = prompt_clipped.mean().detach().item()
    return metrics
110+
111+
112+
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
    """Report raw section timings plus per-token timings for known sections.

    Every entry of ``timing_raw`` is echoed under ``timing_s/<name>``. For the
    sections ``gen``, ``ref``, ``values``, ``adv``, ``update_critic`` and
    ``update_actor`` that appear in ``timing_raw``, the time is additionally
    converted to milliseconds per token under ``timing_per_token_ms/<name>``.
    ``gen`` is normalised by the number of response tokens, the other sections
    by prompt plus response tokens.

    Args:
        batch: Batch from which prompt/response token counts are derived.
        timing_raw: Mapping from section name to its timing value.
    """
    lengths = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(lengths["prompt_length"]).item()
    num_response_tokens = torch.sum(lengths["response_length"]).item()
    num_overall_tokens = num_prompt_tokens + num_response_tokens

    # Token count used as the denominator for each section.
    num_tokens_of_section = {"gen": num_response_tokens}
    for name in ["ref", "values", "adv", "update_critic", "update_actor"]:
        num_tokens_of_section[name] = num_overall_tokens

    timing_metrics = {}
    for name, value in timing_raw.items():
        timing_metrics[f"timing_s/{name}"] = value

    # Per-token figures exist only for sections that were both timed and counted.
    for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()):
        timing_metrics[f"timing_per_token_ms/{name}"] = timing_raw[name] * 1000 / num_tokens_of_section[name]
    return timing_metrics
128+
129+
130+
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
    """Compute token throughput for one step.

    Args:
        batch: Batch whose ``meta_info["global_token_num"]`` holds the token
            counts to be summed.
        timing_raw: Timing mapping; only the ``"step"`` entry is read.
        n_gpus: Number of GPUs the throughput is normalised by.

    Returns:
        The total token count, the step time, and tokens per unit of step time
        per GPU.
    """
    # NOTE(review): the function name presumably means "throughput"; kept unchanged for existing callers.
    token_total = sum(batch.meta_info["global_token_num"])
    step_time = timing_raw["step"]

    throughput_metrics = {}
    throughput_metrics["perf/total_num_tokens"] = token_total
    throughput_metrics["perf/time_per_step"] = step_time
    throughput_metrics["perf/throughput"] = token_total / (step_time * n_gpus)
    return throughput_metrics

0 commit comments

Comments
 (0)