
Commit 61e4397

[misc] refactor: deprecate sharding manager (part 1) (volcengine#2912)
### What does this PR do?

- Since the device_mesh is now registered inside the worker, the sharding manager is no longer needed. We will remove the usage of the sharding manager gradually in the main branch.
- This PR removes the sharding manager usage inside fsdp_workers.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
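As a rough sketch of the pattern introduced here (assembled from the diffs below; the class name, device type, and mesh sizes are illustrative assumptions, not verl defaults): a worker builds its device mesh once during model init, registers it under a mesh name, and decorates its compute methods with a mesh-aware dispatch function instead of `Dispatch.DP_COMPUTE_PROTO`, so no sharding manager is needed to route `DataProto` batches.

```python
# Illustrative sketch only; "MyRolloutWorker", the "cuda" device type and the
# mesh sizes are hypothetical. The pattern mirrors recipe/*/fsdp_workers.py below.
from torch.distributed.device_mesh import init_device_mesh

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register


class MyRolloutWorker(Worker):
    def init_model(self):
        # Build the inference mesh, e.g. dp=4, infer_tp=2 (hypothetical sizes).
        rollout_device_mesh = init_device_mesh(
            "cuda", mesh_shape=(4, 2), mesh_dim_names=["dp", "infer_tp"]
        )
        # Register the mesh on the Worker base class: each rank records its dp rank,
        # and only tp-rank 0 of every dp group collects output.
        is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
        self._register_dispatch_collect_info(
            "rollout",
            dp_rank=rollout_device_mesh["dp"].get_local_rank(),
            is_collect=is_collect,
        )

    # Dispatch and collect now follow the registered "rollout" mesh instead of
    # Dispatch.DP_COMPUTE_PROTO.
    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
    def generate_sequences(self, prompts):
        ...
```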
1 parent 7e3bd6a commit 61e4397

File tree

7 files changed: +126 -84 lines changed

recipe/one_step_off_policy/fsdp_workers.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -24,7 +24,7 @@
 from transformers import AutoConfig
 
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
 from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
 from verl.utils.device import (
@@ -184,6 +184,12 @@ def init_model(self):
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
         )
+
+        is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+        self._register_dispatch_collect_info(
+            "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+        )
+
         rollout_name = self.config.rollout.name
         assert rollout_name == "vllm"
 
@@ -214,7 +220,7 @@
         self.rollout = rollout
         self.rollout_sharding_manager = rollout_sharding_manager
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
     def async_generate_sequences(self, *args, **kwargs):
         return super().generate_sequences(*args, **kwargs)
 
```

recipe/spin/fsdp_workers.py

Lines changed: 13 additions & 12 deletions

```diff
@@ -29,7 +29,7 @@
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_tokenizer
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
 from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
@@ -167,7 +167,7 @@ def init_model(self):
             checkpoint_config=self.config.actor.checkpoint,
         )
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref
 
@@ -180,10 +180,8 @@ def compute_ref_log_prob(self, data: DataProto):
         data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
         data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.ref_policy.compute_log_prob(data=data)
             output = DataProto.from_dict(tensors={"ref_log_prob": output})
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -194,7 +192,7 @@ def compute_ref_log_prob(self, data: DataProto):
 
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def compute_log_prob(self, data: DataProto):
         assert self._is_actor
         if self._is_offload_param:
@@ -209,12 +207,10 @@ def compute_log_prob(self, data: DataProto):
         data.meta_info["temperature"] = self.config.rollout.temperature
         # perform recompute log_prob
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.actor.compute_log_prob(data=data)
             output = DataProto.from_dict(
                 tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
             )
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -229,7 +225,7 @@ def compute_log_prob(self, data: DataProto):
         log_gpu_memory_usage("After compute_log_prob", logger=logger)
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def update_actor_dpo(self, data: DataProto):
         """
         Wrapper for actor update step. Handles FSDP state management.
@@ -253,8 +249,6 @@ def update_actor_dpo(self, data: DataProto):
 
         # --- Ulysses Sharding (if used) ---
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
-
             # --- Call the core update method (now containing DPO logic) ---
             with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer:  # Use a distinct timer name
                 # Calls the modified update_policy method
@@ -282,7 +276,6 @@ def update_actor_dpo(self, data: DataProto):
 
             # --- Prepare Output ---
            output = DataProto(meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to("cpu")
 
            # --- FSDP State Management (Offload) ---
@@ -323,6 +316,14 @@ def __init__(self, config):
             get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
         )
 
+        if self.ulysses_device_mesh is not None:
+            is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+        else:
+            self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True)
+
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
 
         self.use_remove_padding = self.config.model.get("use_remove_padding", False)
@@ -539,7 +540,7 @@ def _switch_chat_template(self, data: DataProto):
 
         return DataProto.from_dict(rm_inputs)
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
     def compute_rm_score(self, data: DataProto):
         import itertools
 
```

verl/single_controller/base/worker.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -70,8 +70,6 @@ class Worker(WorkerHelper):
     """
 
     fused_worker_attr_name = "fused_worker_dict"
-    __dispatch_dp_rank = {}
-    __collect_dp_rank = {}
 
     def __new__(cls, *args, **kwargs):
         """Create a new Worker instance with proper initialization based on environment settings."""
@@ -102,6 +100,8 @@ def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_colle
             is_collect (bool):
                 Whether the dp_rank is used for collect.
         """
+        if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank:
+            raise ValueError(f"mesh_name {mesh_name} has been registered")
         self.__dispatch_dp_rank[mesh_name] = dp_rank
         self.__collect_dp_rank[mesh_name] = is_collect
 
@@ -117,7 +117,7 @@ def _query_dispatch_info(self, mesh_name: str):
             int:
                 The dp_rank for the given mesh name.
         """
-        assert mesh_name in self.__dispatch_dp_rank
+        assert mesh_name in self.__dispatch_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
         # note that each rank store its own dp_rank
         return self.__dispatch_dp_rank[mesh_name]
 
@@ -133,7 +133,7 @@ def _query_collect_info(self, mesh_name: str):
             bool:
                 Whether the dp_rank is used for collect.
         """
-        assert mesh_name in self.__collect_dp_rank
+        assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
         return self.__collect_dp_rank[mesh_name]
 
     def _configure_before_init(self, register_center_name: str, rank: int):
@@ -219,6 +219,8 @@ def __init__(self, cuda_visible_devices=None) -> None:
             self._configure_with_store(store=store)
 
         self.fused_worker_dict = {}
+        self.__dispatch_dp_rank = {}
+        self.__collect_dp_rank = {}
 
     def get_fused_worker_by_name(self, worker_name: str):
         """Get a fused worker by its name.
```

verl/workers/actor/dp_actor.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -410,6 +410,11 @@ def update_policy(self, data: DataProto):
                 entropy_coeff = self.config.entropy_coeff
                 loss_agg_mode = self.config.loss_agg_mode
 
+                if self.config.use_dynamic_bsz:
+                    loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                else:
+                    loss_scale_factor = 1 / self.gradient_accumulation
+
                 # all return: (bsz, response_length)
                 calculate_entropy = False
                 if entropy_coeff != 0:
@@ -449,19 +454,19 @@
                     kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
                     policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                    micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
+                    micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
                     micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
-                    loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
+                    loss = policy_loss * loss_scale_factor
                 else:
-                    loss = policy_loss / self.gradient_accumulation
+                    loss = policy_loss * loss_scale_factor
                 loss.backward()
 
                 micro_batch_metrics.update(
                     {
-                        "actor/pg_loss": pg_loss.detach().item(),
+                        "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
                         "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                         "actor/ppo_kl": ppo_kl.detach().item(),
                         "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
```

verl/workers/critic/dp_critic.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -238,15 +238,17 @@ def update_critic(self, data: DataProto):
                 )
                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
-                    loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
+                    loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                    loss = vf_loss * loss_scale_factor
                 else:
-                    loss = vf_loss / self.gradient_accumulation
+                    loss_scale_factor = 1 / self.gradient_accumulation
+                    loss = vf_loss * loss_scale_factor
 
                 loss.backward()
 
                 micro_batch_metrics.update(
                     {
-                        "critic/vf_loss": vf_loss.detach().item(),
+                        "critic/vf_loss": vf_loss.detach().item() * loss_scale_factor,
                         "critic/vf_clipfrac": vf_clipfrac.detach().item(),
                         "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
                     }
```
