Skip to content

Commit 22c878d

Browse files
[trainer] refactor: make main_ppo TaskRunner more modular (volcengine#2885)
### What does this PR do?

- Added `__init__()` method to initialize `self.role_worker_mapping = {}`
- Extracted worker setup logic into dedicated methods:
  - `add_actor_rollout_worker()` - handles strategy-specific worker imports and setup (lines 130-153)
  - `add_critic_worker()` - sets up critic worker role mapping (lines 170-176)
  - `init_resource_pool_mgr()` - creates resource pool specifications (lines 178-187)
  - `add_reward_model_worker()` - conditionally adds reward model workers (lines 195-203)
  - `add_ref_policy_worker()` - conditionally adds reference policy workers (lines 205-208)

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

relying on existing unit tests

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent f930cab commit 22c878d

File tree

1 file changed

+108
-73
lines changed

1 file changed

+108
-73
lines changed

verl/trainer/main_ppo.py

Lines changed: 108 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,109 @@ class TaskRunner:
8989
9090
This class encapsulates the main training logic and runs as a Ray remote actor
9191
to enable distributed execution across multiple nodes and GPUs.
92+
93+
Attributes:
94+
role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
95+
mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
9296
"""
9397

98+
def __init__(self):
99+
self.role_worker_mapping = {}
100+
self.mapping = {}
101+
102+
def add_actor_rollout_worker(self, config):
103+
"""Add actor rollout worker based on the actor strategy."""
104+
from verl.single_controller.ray import RayWorkerGroup
105+
106+
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
107+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
108+
109+
actor_rollout_cls = (
110+
AsyncActorRolloutRefWorker
111+
if config.actor_rollout_ref.rollout.mode == "async"
112+
else ActorRolloutRefWorker
113+
)
114+
ray_worker_group_cls = RayWorkerGroup
115+
116+
elif config.actor_rollout_ref.actor.strategy == "megatron":
117+
from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
118+
119+
actor_rollout_cls = (
120+
AsyncActorRolloutRefWorker
121+
if config.actor_rollout_ref.rollout.mode == "async"
122+
else ActorRolloutRefWorker
123+
)
124+
ray_worker_group_cls = RayWorkerGroup
125+
126+
else:
127+
raise NotImplementedError
128+
129+
from verl.trainer.ppo.ray_trainer import Role
130+
131+
self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
132+
133+
return actor_rollout_cls, ray_worker_group_cls
134+
135+
def add_critic_worker(self, config):
136+
"""Add critic worker to role mapping."""
137+
if config.critic.strategy in {"fsdp", "fsdp2"}:
138+
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
139+
if use_legacy_worker_impl in ["auto", "enable"]:
140+
from verl.workers.fsdp_workers import CriticWorker
141+
elif use_legacy_worker_impl == "disable":
142+
from verl.workers.roles import CriticWorker
143+
144+
print("Using new worker implementation")
145+
else:
146+
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
147+
148+
elif config.critic.strategy == "megatron":
149+
from verl.workers.megatron_workers import CriticWorker
150+
151+
else:
152+
raise NotImplementedError
153+
154+
from verl.trainer.ppo.ray_trainer import Role
155+
156+
self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
157+
158+
def init_resource_pool_mgr(self, config):
159+
"""Initialize resource pool manager."""
160+
from verl.trainer.ppo.ray_trainer import Role
161+
162+
global_pool_id = "global_pool"
163+
resource_pool_spec = {
164+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
165+
}
166+
self.mapping[Role.ActorRollout] = global_pool_id
167+
self.mapping[Role.Critic] = global_pool_id
168+
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
169+
170+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
171+
return resource_pool_manager
172+
173+
def add_reward_model_worker(self, config):
174+
"""Add reward model worker if enabled."""
175+
from verl.trainer.ppo.ray_trainer import Role
176+
177+
if config.reward_model.enable:
178+
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
179+
from verl.workers.fsdp_workers import RewardModelWorker
180+
elif config.reward_model.strategy == "megatron":
181+
from verl.workers.megatron_workers import RewardModelWorker
182+
else:
183+
raise NotImplementedError
184+
self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
185+
self.mapping[Role.RewardModel] = "global_pool"
186+
187+
def add_ref_policy_worker(self, config, ref_policy_cls):
188+
"""Add reference policy worker if KL loss or KL reward is used."""
189+
from verl.trainer.ppo.ray_trainer import Role
190+
191+
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
192+
self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
193+
self.mapping[Role.RefPolicy] = "global_pool"
194+
94195
def run(self, config):
95196
"""Execute the main PPO training workflow.
96197
@@ -126,86 +227,19 @@ def run(self, config):
126227
# Used for multimodal LLM, could be None
127228
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
128229

129-
# Define worker classes based on the actor strategy.
130-
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
131-
assert config.critic.strategy in {"fsdp", "fsdp2"}
132-
from verl.single_controller.ray import RayWorkerGroup
133-
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
134-
135-
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
136-
if use_legacy_worker_impl in ["auto", "enable"]:
137-
# import warnings
138-
# warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
139-
# Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
140-
from verl.workers.fsdp_workers import CriticWorker
141-
elif use_legacy_worker_impl == "disable":
142-
from verl.workers.roles import CriticWorker
143-
144-
print("Using new worker implementation")
145-
else:
146-
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
147-
148-
actor_rollout_cls = (
149-
AsyncActorRolloutRefWorker
150-
if config.actor_rollout_ref.rollout.mode == "async"
151-
else ActorRolloutRefWorker
152-
)
153-
ray_worker_group_cls = RayWorkerGroup
154-
155-
elif config.actor_rollout_ref.actor.strategy == "megatron":
156-
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
157-
from verl.single_controller.ray import RayWorkerGroup
158-
from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
159-
160-
actor_rollout_cls = (
161-
AsyncActorRolloutRefWorker
162-
if config.actor_rollout_ref.rollout.mode == "async"
163-
else ActorRolloutRefWorker
164-
)
165-
ray_worker_group_cls = RayWorkerGroup
166-
167-
else:
168-
raise NotImplementedError
169-
170-
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
171-
172-
# Map roles to their corresponding remote worker classes.
173-
role_worker_mapping = {
174-
Role.ActorRollout: ray.remote(actor_rollout_cls),
175-
Role.Critic: ray.remote(CriticWorker),
176-
}
177-
178-
# Define the resource pool specification.
179-
# Map roles to the resource pool.
180-
global_pool_id = "global_pool"
181-
resource_pool_spec = {
182-
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
183-
}
184-
mapping = {
185-
Role.ActorRollout: global_pool_id,
186-
Role.Critic: global_pool_id,
187-
}
230+
actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
231+
self.add_critic_worker(config)
188232

189233
# We should adopt a multi-source reward function here:
190234
# - for rule-based rm, we directly call a reward score
191235
# - for model-based rm, we call a model
192236
# - for code related prompt, we send to a sandbox if there are test cases
193237
# finally, we combine all the rewards together
194238
# The reward type depends on the tag of the data
195-
if config.reward_model.enable:
196-
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
197-
from verl.workers.fsdp_workers import RewardModelWorker
198-
elif config.reward_model.strategy == "megatron":
199-
from verl.workers.megatron_workers import RewardModelWorker
200-
else:
201-
raise NotImplementedError
202-
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
203-
mapping[Role.RewardModel] = global_pool_id
239+
self.add_reward_model_worker(config)
204240

205241
# Add a reference policy worker if KL loss or KL reward is used.
206-
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
207-
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
208-
mapping[Role.RefPolicy] = global_pool_id
242+
self.add_ref_policy_worker(config, actor_rollout_cls)
209243

210244
# Load the reward manager for training and validation.
211245
reward_fn = load_reward_manager(
@@ -214,7 +248,8 @@ def run(self, config):
214248
val_reward_fn = load_reward_manager(
215249
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
216250
)
217-
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
251+
252+
resource_pool_manager = self.init_resource_pool_mgr(config)
218253

219254
from verl.utils.dataset.rl_dataset import collate_fn
220255

@@ -228,7 +263,7 @@ def run(self, config):
228263
config=config,
229264
tokenizer=tokenizer,
230265
processor=processor,
231-
role_worker_mapping=role_worker_mapping,
266+
role_worker_mapping=self.role_worker_mapping,
232267
resource_pool_manager=resource_pool_manager,
233268
ray_worker_group_cls=ray_worker_group_cls,
234269
reward_fn=reward_fn,

0 commit comments

Comments (0)