10 changes: 8 additions & 2 deletions recipe/one_step_off_policy/fsdp_workers.py
@@ -24,7 +24,7 @@
from transformers import AutoConfig

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
from verl.utils.device import (
@@ -184,6 +184,12 @@ def init_model(self):
rollout_device_mesh = init_device_mesh(
device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)

is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
self._register_dispatch_collect_info(
"rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)

rollout_name = self.config.rollout.name
assert rollout_name == "vllm"

@@ -214,7 +220,7 @@ def init_model(self):
self.rollout = rollout
self.rollout_sharding_manager = rollout_sharding_manager

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
def async_generate_sequences(self, *args, **kwargs):
return super().generate_sequences(*args, **kwargs)

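For orientation, the hunk above pairs a mesh registration with a dispatch-mode change: the worker builds a 2-D (dp, infer_tp) device mesh, records its dp coordinate under the name "rollout", and marks only the infer_tp-local rank 0 as the collecting rank; the decorator then dispatches by that registered mesh instead of the generic DP_COMPUTE_PROTO mode. The sketch below is standalone Python, not part of the diff; the 8-GPU / tp=2 sizes and the row-major rank-to-mesh mapping are assumptions, and it only works out which ranks would register is_collect=True under that scheme.

world_size, infer_tp = 8, 2          # assumed sizes, matching mesh_shape=(dp, infer_tp)
dp = world_size // infer_tp          # 4 data-parallel groups

for rank in range(world_size):
    dp_rank = rank // infer_tp       # local rank along the "dp" dim (row-major layout assumed)
    tp_rank = rank % infer_tp        # local rank along the "infer_tp" dim
    is_collect = tp_rank == 0        # one collecting rank per inference TP group
    print(f"rank={rank}: dp_rank={dp_rank}, is_collect={is_collect}")

Presumably, make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout") then splits the incoming DataProto across the registered dp ranks and gathers results only from ranks that registered is_collect=True.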
25 changes: 13 additions & 12 deletions recipe/spin/fsdp_workers.py
@@ -29,7 +29,7 @@
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.utils import hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
@@ -167,7 +167,7 @@ def init_model(self):
checkpoint_config=self.config.actor.checkpoint,
)

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref

@@ -180,10 +180,8 @@ def compute_ref_log_prob(self, data: DataProto):
data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)

output = output.to("cpu")

@@ -194,7 +192,7 @@

return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
@@ -209,12 +207,10 @@ def compute_log_prob(self, data: DataProto):
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(
tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
)
output = self.ulysses_sharding_manager.postprocess_data(output)

output = output.to("cpu")

@@ -229,7 +225,7 @@
log_gpu_memory_usage("After compute_log_prob", logger=logger)
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
def update_actor_dpo(self, data: DataProto):
"""
Wrapper for actor update step. Handles FSDP state management.
@@ -253,8 +249,6 @@ def update_actor_dpo(self, data: DataProto):

# --- Ulysses Sharding (if used) ---
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)

# --- Call the core update method (now containing DPO logic) ---
with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name
# Calls the modified update_policy method
@@ -282,7 +276,6 @@ def update_actor_dpo(self, data: DataProto):

# --- Prepare Output ---
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")

# --- FSDP State Management (Offload) ---
@@ -323,6 +316,14 @@ def __init__(self, config):
get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)

if self.ulysses_device_mesh is not None:
is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
self._register_dispatch_collect_info(
"reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)
else:
self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True)

self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

self.use_remove_padding = self.config.model.get("use_remove_padding", False)
@@ -539,7 +540,7 @@ def _switch_chat_template(self, data: DataProto):

return DataProto.from_dict(rm_inputs)

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
def compute_rm_score(self, data: DataProto):
import itertools

10 changes: 6 additions & 4 deletions verl/single_controller/base/worker.py
@@ -70,8 +70,6 @@ class Worker(WorkerHelper):
"""

fused_worker_attr_name = "fused_worker_dict"
__dispatch_dp_rank = {}
__collect_dp_rank = {}

def __new__(cls, *args, **kwargs):
"""Create a new Worker instance with proper initialization based on environment settings."""
@@ -102,6 +100,8 @@ def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_colle
is_collect (bool):
Whether the dp_rank is used for collect.
"""
if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank:
raise ValueError(f"mesh_name {mesh_name} has been registered")
Comment on lines +103 to +104
Contributor

Severity: critical

This check is a good addition for safety. However, it highlights a more fundamental issue with the current implementation. __dispatch_dp_rank and __collect_dp_rank are defined as class attributes on the Worker class (lines 73-74), which means they are shared across all Worker instances within the same process.

This will lead to issues when multiple workers are instantiated in the same process (e.g., an actor worker and a critic worker), as they will attempt to write to the same shared dictionaries. This will either raise a ValueError due to this new check or, worse, lead to silent bugs from overwritten dispatch information.

These attributes should be instance-specific. The correct fix is to initialize them as instance attributes in Worker.__init__:

# In Worker.__init__
self.__dispatch_dp_rank = {}
self.__collect_dp_rank = {}

And remove the class-level definitions. Since __init__ and the class attribute definitions are not part of this diff, I cannot suggest the change directly, but this is a critical issue that needs to be addressed to ensure correctness.
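
To make the failure mode concrete, here is a minimal, self-contained Python sketch (independent of verl) of the difference between a class-level mutable dict and a per-instance dict created in __init__:

class SharedRegistry:
    registry = {}  # class attribute: one dict shared by every instance

    def register(self, name, value):
        if name in self.registry:
            raise ValueError(f"mesh_name {name} has been registered")
        self.registry[name] = value


class PerInstanceRegistry:
    def __init__(self):
        self.registry = {}  # instance attribute: a fresh dict per worker

    def register(self, name, value):
        if name in self.registry:
            raise ValueError(f"mesh_name {name} has been registered")
        self.registry[name] = value


a, b = SharedRegistry(), SharedRegistry()
a.register("actor", 0)
# b.register("actor", 1)  # raises ValueError: the dict lives on the class, so b sees a's entry

c, d = PerInstanceRegistry(), PerInstanceRegistry()
c.register("actor", 0)
d.register("actor", 1)  # fine: each instance has its own dict

The diff's move of __dispatch_dp_rank / __collect_dp_rank into Worker.__init__ (further down in this file) follows the per-instance pattern.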

self.__dispatch_dp_rank[mesh_name] = dp_rank
self.__collect_dp_rank[mesh_name] = is_collect

@@ -117,7 +117,7 @@ def _query_dispatch_info(self, mesh_name: str):
int:
The dp_rank for the given mesh name.
"""
assert mesh_name in self.__dispatch_dp_rank
assert mesh_name in self.__dispatch_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
# note that each rank store its own dp_rank
return self.__dispatch_dp_rank[mesh_name]

@@ -133,7 +133,7 @@ def _query_collect_info(self, mesh_name: str):
bool:
Whether the dp_rank is used for collect.
"""
assert mesh_name in self.__collect_dp_rank
assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
return self.__collect_dp_rank[mesh_name]

def _configure_before_init(self, register_center_name: str, rank: int):
@@ -219,6 +219,8 @@ def __init__(self, cuda_visible_devices=None) -> None:
self._configure_with_store(store=store)

self.fused_worker_dict = {}
self.__dispatch_dp_rank = {}
self.__collect_dp_rank = {}

def get_fused_worker_by_name(self, worker_name: str):
"""Get a fused worker by its name.
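Putting the worker.py pieces together with the recipe changes above, the intended usage pattern looks roughly like the following sketch. It is schematic and hedged: the worker class, mesh sizes, and device string are placeholders, and it assumes an initialized torch.distributed process group; only _register_dispatch_collect_info, register, and make_nd_compute_dataproto_dispatch_fn are taken from the diff.

from torch.distributed.device_mesh import init_device_mesh

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register


class MyRolloutWorker(Worker):
    def init_model(self):
        # Hypothetical 4x2 mesh; in the diff, dp and infer_tp come from the rollout config.
        mesh = init_device_mesh("cuda", mesh_shape=(4, 2), mesh_dim_names=["dp", "infer_tp"])
        # Register once per mesh name; re-registering the same name now raises ValueError.
        self._register_dispatch_collect_info(
            "rollout",
            dp_rank=mesh["dp"].get_local_rank(),
            is_collect=mesh["infer_tp"].get_local_rank() == 0,
        )

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
    def generate_sequences(self, data):
        # Dispatched by the registered "rollout" dp rank; results collected from is_collect ranks.
        ...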
13 changes: 9 additions & 4 deletions verl/workers/actor/dp_actor.py
@@ -410,6 +410,11 @@ def update_policy(self, data: DataProto):
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

if self.config.use_dynamic_bsz:
loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
else:
loss_scale_factor = 1 / self.gradient_accumulation

# all return: (bsz, response_length)
calculate_entropy = False
if entropy_coeff != 0:
@@ -449,19 +454,19 @@
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
loss = policy_loss * loss_scale_factor
else:
loss = policy_loss / self.gradient_accumulation
loss = policy_loss * loss_scale_factor
Collaborator

Just a question: policy_loss is already the mean loss over all tokens of the micro-batch samples, so why do we need loss_scale_factor here?

Collaborator Author

It still needs to be divided by the number of gradient accumulation steps.
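
To make that concrete, here is a small standalone PyTorch sketch (not verl code) of why each micro-batch loss is scaled by 1 / gradient_accumulation: the accumulated gradient then matches the gradient of the mean loss over the whole mini-batch. The dynamic-batch-size branch plays the same role, using loss_scale_factor = micro_batch_size / ppo_mini_batch_size instead.

import torch

# One mini-batch of per-sample losses, split into two micro-batches.
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
micro_batches = [x[:2], x[2:]]
gradient_accumulation = len(micro_batches)

for mb in micro_batches:
    loss = mb.mean() * (1 / gradient_accumulation)  # loss_scale_factor = 1 / gradient_accumulation
    loss.backward()                                  # gradients accumulate into x.grad

# Accumulated gradient equals the gradient of the full mini-batch mean loss.
expected = torch.autograd.grad(x.mean(), x)[0]
assert torch.allclose(x.grad, expected)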

loss.backward()

micro_batch_metrics.update(
{
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
Contributor

Severity: high

You've correctly scaled the pg_loss metric by loss_scale_factor to be consistent with the loss value used for backpropagation. This is a good improvement for metric correctness.

However, there's an inconsistency. If self.config.use_kl_loss is true, kl_loss is also a component of the total policy_loss, but it is logged without being scaled by loss_scale_factor (on line 452). For consistent and interpretable metrics, all reported loss components should be scaled in the same way as the total loss. Please consider scaling kl_loss as well when it's added to micro_batch_metrics.

"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
8 changes: 5 additions & 3 deletions verl/workers/critic/dp_critic.py
@@ -238,15 +238,17 @@ def update_critic(self, data: DataProto):
)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
loss = vf_loss * loss_scale_factor
else:
loss = vf_loss / self.gradient_accumulation
loss_scale_factor = 1 / self.gradient_accumulation
loss = vf_loss * loss_scale_factor

loss.backward()

micro_batch_metrics.update(
{
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_loss": vf_loss.detach().item() * loss_scale_factor,
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}