diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py
index 72036d6057c..a16b7b74345 100644
--- a/recipe/one_step_off_policy/fsdp_workers.py
+++ b/recipe/one_step_off_policy/fsdp_workers.py
@@ -24,7 +24,7 @@
 from transformers import AutoConfig
 
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
 from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
 from verl.utils.device import (
@@ -184,6 +184,12 @@ def init_model(self):
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
         )
+
+        is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+        self._register_dispatch_collect_info(
+            "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+        )
+
         rollout_name = self.config.rollout.name
         assert rollout_name == "vllm"
 
@@ -214,7 +220,7 @@ def init_model(self):
         self.rollout = rollout
         self.rollout_sharding_manager = rollout_sharding_manager
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
     def async_generate_sequences(self, *args, **kwargs):
         return super().generate_sequences(*args, **kwargs)
 
diff --git a/recipe/spin/fsdp_workers.py b/recipe/spin/fsdp_workers.py
index 1ecb1e000af..0960a8abcc0 100644
--- a/recipe/spin/fsdp_workers.py
+++ b/recipe/spin/fsdp_workers.py
@@ -29,7 +29,7 @@
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_tokenizer
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
 from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
@@ -167,7 +167,7 @@ def init_model(self):
             checkpoint_config=self.config.actor.checkpoint,
         )
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref
 
@@ -180,10 +180,8 @@ def compute_ref_log_prob(self, data: DataProto):
         data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
         data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.ref_policy.compute_log_prob(data=data)
             output = DataProto.from_dict(tensors={"ref_log_prob": output})
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -194,7 +192,7 @@ def compute_ref_log_prob(self, data: DataProto):
 
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def compute_log_prob(self, data: DataProto):
         assert self._is_actor
         if self._is_offload_param:
@@ -209,12 +207,10 @@ def compute_log_prob(self, data: DataProto):
         data.meta_info["temperature"] = self.config.rollout.temperature
         # perform recompute log_prob
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.actor.compute_log_prob(data=data)
             output = DataProto.from_dict(
                 tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
             )
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -229,7 +225,7 @@ def compute_log_prob(self, data: DataProto):
         log_gpu_memory_usage("After compute_log_prob", logger=logger)
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     def update_actor_dpo(self, data: DataProto):
         """
         Wrapper for actor update step. Handles FSDP state management.
@@ -253,8 +249,6 @@ def update_actor_dpo(self, data: DataProto):
 
         # --- Ulysses Sharding (if used) ---
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
-
             # --- Call the core update method (now containing DPO logic) ---
             with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer:  # Use a distinct timer name
                 # Calls the modified update_policy method
@@ -282,7 +276,6 @@ def update_actor_dpo(self, data: DataProto):
 
         # --- Prepare Output ---
         output = DataProto(meta_info={"metrics": metrics})
-        output = self.ulysses_sharding_manager.postprocess_data(data=output)
         output = output.to("cpu")
 
         # --- FSDP State Management (Offload) ---
@@ -323,6 +316,14 @@ def __init__(self, config):
                 get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
             )
 
+        if self.ulysses_device_mesh is not None:
+            is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+        else:
+            self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True)
+
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
 
         self.use_remove_padding = self.config.model.get("use_remove_padding", False)
@@ -539,7 +540,7 @@ def _switch_chat_template(self, data: DataProto):
 
         return DataProto.from_dict(rm_inputs)
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
     def compute_rm_score(self, data: DataProto):
         import itertools
 
diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py
index f02931fd717..3802f10655a 100644
--- a/verl/single_controller/base/worker.py
+++ b/verl/single_controller/base/worker.py
@@ -70,8 +70,6 @@ class Worker(WorkerHelper):
     """
 
     fused_worker_attr_name = "fused_worker_dict"
-    __dispatch_dp_rank = {}
-    __collect_dp_rank = {}
 
     def __new__(cls, *args, **kwargs):
         """Create a new Worker instance with proper initialization based on environment settings."""
@@ -102,6 +100,8 @@ def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_colle
 
             is_collect (bool): Whether the dp_rank is used for collect.
         """
+        if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank:
+            raise ValueError(f"mesh_name {mesh_name} has been registered")
         self.__dispatch_dp_rank[mesh_name] = dp_rank
         self.__collect_dp_rank[mesh_name] = is_collect
 
@@ -117,7 +117,7 @@ def _query_dispatch_info(self, mesh_name: str):
 
             int: The dp_rank for the given mesh name.
""" - assert mesh_name in self.__dispatch_dp_rank + assert mesh_name in self.__dispatch_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}" # note that each rank store its own dp_rank return self.__dispatch_dp_rank[mesh_name] @@ -133,7 +133,7 @@ def _query_collect_info(self, mesh_name: str): bool: Whether the dp_rank is used for collect. """ - assert mesh_name in self.__collect_dp_rank + assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}" return self.__collect_dp_rank[mesh_name] def _configure_before_init(self, register_center_name: str, rank: int): @@ -219,6 +219,8 @@ def __init__(self, cuda_visible_devices=None) -> None: self._configure_with_store(store=store) self.fused_worker_dict = {} + self.__dispatch_dp_rank = {} + self.__collect_dp_rank = {} def get_fused_worker_by_name(self, worker_name: str): """Get a fused worker by its name. diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index e5c33082070..ce84ddf5053 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -410,6 +410,11 @@ def update_policy(self, data: DataProto): entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode + if self.config.use_dynamic_bsz: + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + else: + loss_scale_factor = 1 / self.gradient_accumulation + # all return: (bsz, response_length) calculate_entropy = False if entropy_coeff != 0: @@ -449,19 +454,19 @@ def update_policy(self, data: DataProto): kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() + micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) + loss = policy_loss * loss_scale_factor else: - loss = policy_loss / self.gradient_accumulation + loss = policy_loss * loss_scale_factor loss.backward() micro_batch_metrics.update( { - "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor, "actor/pg_clipfrac": pg_clipfrac.detach().item(), "actor/ppo_kl": ppo_kl.detach().item(), "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index 43656dd2b41..c369fd39663 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -238,15 +238,17 @@ def update_critic(self, data: DataProto): ) if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + loss = vf_loss * loss_scale_factor else: - loss = vf_loss / self.gradient_accumulation + loss_scale_factor = 1 / self.gradient_accumulation + loss = vf_loss * loss_scale_factor loss.backward() micro_batch_metrics.update( { - "critic/vf_loss": vf_loss.detach().item(), + "critic/vf_loss": vf_loss.detach().item() * loss_scale_factor, "critic/vf_clipfrac": vf_clipfrac.detach().item(), "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), } diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py 
index 25824fc3c00..e7ae738c7f9 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -39,7 +39,7 @@
 from verl import DataProto
 from verl.models.transformers.monkey_patch import apply_monkey_patch
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_processor, hf_tokenizer
 from verl.utils.activation_offload import enable_activation_offloading
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
@@ -144,6 +144,15 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
                 device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
             )
 
+        # create training dispatch
+        if self.ulysses_device_mesh is not None:
+            is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+        else:
+            self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True)
+
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
         self._lora_rank = self.config.model.get("lora_rank", 0)
         self._is_lora = self._lora_rank > 0
@@ -475,6 +484,15 @@ def _build_rollout(self, trust_remote_code=False):
             device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
         )
         rollout_name = self.config.rollout.name
+
+        if rollout_name == "hf":
+            self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True)
+        else:
+            is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+
         if rollout_name == "hf":
             from verl.workers.rollout import HFRollout
             from verl.workers.sharding_manager.base import BaseShardingManager
@@ -678,7 +696,7 @@ def init_model(self):
             checkpoint_config=checkpoint_contents,
         )
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
    @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
         # Support all hardwares
@@ -691,7 +709,6 @@ def update_actor(self, data: DataProto):
             load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
 
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
             # perform training
             with Timer(name="update_policy", logger=None) as timer:
                 metrics = self.actor.update_policy(data=data)
@@ -712,7 +729,6 @@ def update_actor(self, data: DataProto):
 
             # TODO: here, we should return all metrics
             output = DataProto(meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
         output = output.to("cpu")
 
         if self._is_offload_param:
@@ -724,7 +740,7 @@ def update_actor(self, data: DataProto):
 
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
     @DistProfiler.annotate(color="red", role="rollout_generate")
     def generate_sequences(self, prompts: DataProto):
         # Support all hardwares
@@ -745,14 +761,11 @@ def generate_sequences(self, prompts: DataProto):
         with self.rollout_sharding_manager:
             log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
 
-            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
             with simple_timer("generate_sequences", timing_generate):
                 output = self.rollout.generate_sequences(prompts=prompts)
 
             log_gpu_memory_usage("After rollout generation", logger=logger)
 
-            output = self.rollout_sharding_manager.postprocess_data(output)
-
         timing_generate.update(self.rollout_sharding_manager.timing)
         # We calculate the average timing across all ranks
         # to make sure meta_info["timing"] is the same
@@ -764,7 +777,7 @@ def generate_sequences(self, prompts: DataProto):
         get_torch_device().empty_cache()
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
     def compute_log_prob(self, data: DataProto):
         # when is_lora is True, we use the actor without lora applied to calculate the log_prob
@@ -786,14 +799,12 @@ def compute_log_prob(self, data: DataProto):
         data.meta_info["temperature"] = self.config.rollout.temperature
         # perform recompute log_prob
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             with adapter_ctx:
                 output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
             output = DataProto.from_dict(
                 tensors={"old_log_probs": output, "entropys": entropys},
                 meta_info={"temperature": self.config.rollout.temperature},
             )
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -808,7 +819,7 @@ def compute_log_prob(self, data: DataProto):
 
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
     def compute_ref_log_prob(self, data: DataProto):
         if self._is_lora:
@@ -830,10 +841,8 @@ def compute_ref_log_prob(self, data: DataProto):
         data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
         data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
             output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
             output = DataProto.from_dict(tensors={"ref_log_prob": output})
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
 
@@ -957,6 +966,15 @@ def __init__(self, config: FSDPCriticConfig):
                 device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
             )
 
+        # create training dispatch
+        if self.ulysses_device_mesh is not None:
+            is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+        else:
+            self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True)
+
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
 
         # set FSDP offload params
@@ -1221,7 +1239,7 @@ def init_model(self):
             checkpoint_config=self.config.checkpoint,
         )
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
     @DistProfiler.annotate(color="cyan")
     def compute_values(self, data: DataProto):
         # Support all hardwares
@@ -1235,17 +1253,15 @@ def compute_values(self, data: DataProto):
         data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
         # perform forward computation
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
             values = self.critic.compute_values(data=data)
             output = DataProto.from_dict(tensors={"values": values})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
 
         output = output.to("cpu")
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.critic_module)
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
     @DistProfiler.annotate(color="pink")
     def update_critic(self, data: DataProto):
         # Support all hardwares
@@ -1257,8 +1273,6 @@ def update_critic(self, data: DataProto):
 
         # perform forward computation
         with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
-
             with Timer(name="update_critic", logger=None) as timer:
                 metrics = self.critic.update_critic(data=data)
             delta_time = timer.last
@@ -1272,7 +1286,6 @@ def update_critic(self, data: DataProto):
             self.critic_lr_scheduler.step()
 
             output = DataProto(batch=None, meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
 
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.critic_module)
@@ -1355,6 +1368,15 @@ def __init__(self, config):
 
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
 
+        # create training dispatch
+        if self.ulysses_device_mesh is not None:
+            is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+        else:
+            self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True)
+
         self.use_remove_padding = self.config.model.get("use_remove_padding", False)
 
         # normalize config
@@ -1603,7 +1625,7 @@ def _switch_chat_template(self, data: DataProto):
 
         return DataProto.from_dict(rm_inputs)
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
     @DistProfiler.annotate(color="brown")
     def compute_rm_score(self, data: DataProto):
         import itertools
@@ -1630,9 +1652,6 @@ def compute_rm_score(self, data: DataProto):
 
         # perform forward computation
         with self.ulysses_sharding_manager:
-            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
-            data = self.ulysses_sharding_manager.preprocess_data(data=data)
-
             use_dynamic_bsz = self.config.use_dynamic_bsz
             if use_dynamic_bsz:
                 max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
@@ -1654,7 +1673,6 @@ def compute_rm_score(self, data: DataProto):
             token_level_scores = self._expand_to_token_level(data, scores)
             # Note that this is only the scores, may not be the final rewards used to train RL
             output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
 
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 50feaa8c11f..802434b67e6 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -195,14 +195,14 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
                 nccl_communicator_config_path=None,
             )
 
-        is_collect = (
-            mpu.get_tensor_model_parallel_rank() == 0
-            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
-            and mpu.get_context_parallel_rank() == 0
-        )
-        self._register_dispatch_collect_info(
-            mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
-        )
+        is_collect = (
+            mpu.get_tensor_model_parallel_rank() == 0
+            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
+            and mpu.get_context_parallel_rank() == 0
+        )
+        self._register_dispatch_collect_info(
+            mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
+        )
 
         set_random_seed(seed=self.config.actor.megatron.seed)
 
@@ -417,6 +417,11 @@ def _build_rollout(self, trust_remote_code=False):
             )
             log_gpu_memory_usage("After building sharding manager", logger=logger)
 
+            is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+
         elif self.config.rollout.name == "sglang":
             from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
 
@@ -438,6 +443,11 @@ def _build_rollout(self, trust_remote_code=False):
                 "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")
             )
 
+            is_collect = rollout_device_mesh["tp"].get_local_rank() == 0
+            self._register_dispatch_collect_info(
+                "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+            )
+
             local_path = copy_to_local(self.config.model.path)
             log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
             rollout = SGLangRollout(
@@ -581,7 +591,7 @@ def init_model(self):
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After init_model finish", logger=logger)
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="update_actor", logger=logger)
     @DistProfiler.annotate(color="red")
     def update_actor(self, data: DataProto):
@@ -625,7 +635,7 @@ def update_actor(self, data: DataProto):
         aggressive_empty_cache(force_sync=True)
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
     @GPUMemoryLogger(role="generate_sequences", logger=logger)
     @DistProfiler.annotate(color="red")
     def generate_sequences(self, prompts: DataProto):
@@ -646,10 +656,8 @@ def generate_sequences(self, prompts: DataProto):
         timing_generate = {}
         with self.sharding_manager:
             log_gpu_memory_usage("After entering sharding manager", logger=logger)
-            prompts = self.sharding_manager.preprocess_data(prompts)
             with simple_timer("generate_sequences", timing_generate):
                 output = self.rollout.generate_sequences(prompts=prompts)
-            output = self.sharding_manager.postprocess_data(output)
             log_gpu_memory_usage("After rollout generation", logger=logger)
 
         timing_generate.update(self.sharding_manager.timing)
@@ -662,7 +670,7 @@ def generate_sequences(self, prompts: DataProto):
         aggressive_empty_cache(force_sync=True)
         return output
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
     @DistProfiler.annotate(color="olive")
     def compute_ref_log_prob(self, data: DataProto):
@@ -685,7 +693,7 @@ def compute_ref_log_prob(self, data: DataProto):
         aggressive_empty_cache(force_sync=True)
         return output
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="compute_log_prob", logger=logger)
     @DistProfiler.annotate(color="blue")
     def compute_log_prob(self, data: DataProto):
@@ -832,14 +840,14 @@ def __init__(self, config: McoreCriticConfig):
                 nccl_communicator_config_path=None,
             )
 
-        is_collect = (
-            mpu.get_tensor_model_parallel_rank() == 0
-            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
-            and mpu.get_context_parallel_rank() == 0
-        )
-        self._register_dispatch_collect_info(
-            mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
-        )
+        is_collect = (
+            mpu.get_tensor_model_parallel_rank() == 0
+            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
+            and mpu.get_context_parallel_rank() == 0
+        )
+        self._register_dispatch_collect_info(
+            mesh_name="critic", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
+        )
 
         set_random_seed(seed=self.config.megatron.seed)
 
@@ -1011,7 +1019,7 @@ def init_model(self):
             use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,
         )
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
     @DistProfiler.annotate(color="cyan")
     def compute_values(self, data: DataProto):
         micro_batch_size = self.config.ppo_micro_batch_size_per_gpu
@@ -1028,7 +1036,7 @@ def compute_values(self, data: DataProto):
             offload_megatron_model_to_cpu(self.critic_module)
         return output
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
     @DistProfiler.annotate(color="pink")
     def update_critic(self, data: DataProto):
         data = data.to(get_device_id())
@@ -1121,14 +1129,14 @@ def __init__(self, config):
                 nccl_communicator_config_path=None,
             )
 
-        is_collect = (
-            mpu.get_tensor_model_parallel_rank() == 0
-            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
-            and mpu.get_context_parallel_rank() == 0
-        )
-        self._register_dispatch_collect_info(
-            mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
-        )
+        is_collect = (
+            mpu.get_tensor_model_parallel_rank() == 0
+            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
+            and mpu.get_context_parallel_rank() == 0
+        )
+        self._register_dispatch_collect_info(
+            mesh_name="reward", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
+        )
 
         set_random_seed(seed=self.config.megatron.seed)
 
@@ -1253,7 +1261,7 @@ def init_model(self):
 
     # TODO: reward model use itself tokenizer instead of sft tokenizer
    # the input_ids, responses, attention_mask and position_ids may be different!
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
     @DistProfiler.annotate(color="brown")
     def compute_rm_score(self, data: DataProto):
         data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu