diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh b/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh
index 626f6349749..b4a18ac0206 100644
--- a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh
+++ b/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh
@@ -18,7 +18,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
     actor_rollout_ref.actor.fsdp_config.param_offload=False \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
@@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \
     critic.ppo_micro_batch_size_per_gpu=4 \
     critic.model.fsdp_config.param_offload=False \
     critic.model.fsdp_config.optimizer_offload=False \
-    critic.model.fsdp_config.fsdp_size=-1 \
+    critic.model.fsdp_config.fsdp_size=4 \
     reward_model.enable=True \
     reward_model.ulysses_sequence_parallel_size=2 \
     reward_model.model.path=Qwen/Qwen2.5-0.5B\
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index 975061475a2..7fe299dd4c4 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -204,7 +204,8 @@ def _build_model_optimizer(self):
             apply_monkey_patch(config, verbose=True)
 
         # This may be very large
-        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)
 
         with init_context():
             self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 26b7dbd521a..b3f5b73534e 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -19,6 +19,7 @@
 import itertools
 import os
 from contextlib import contextmanager
+from torch.distributed import DeviceMesh
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._runtime_utils import _lazy_init
@@ -35,11 +36,14 @@ def init_fn(x: torch.nn.Module):
     return x
 
 
-def get_init_weight_context_manager(use_meta_tensor=True):
+def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):
     from accelerate import init_empty_weights
     cpu_init_weights = lambda: torch.device('cpu')
     if use_meta_tensor:
-        init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
+        if mesh is None:
+            init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
+        else:
+            init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights
     else:
         init_context = cpu_init_weights
     return init_context
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index c6b7f5a95e2..f2c0719aa69 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -49,9 +49,6 @@ def create_device_mesh(world_size, fsdp_size):
     if fsdp_size < 0 or fsdp_size >= world_size:
         device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
     else:
-        raise ValueError(
-            'HSDP is not supported yet because it produces incorrect results for now. Please set fsdp_size=-1')
-        assert world_size % fsdp_size == 0
         device_mesh = init_device_mesh('cuda',
                                        mesh_shape=(world_size // fsdp_size, fsdp_size),
                                        mesh_dim_names=['ddp', 'fsdp'])
@@ -117,10 +114,10 @@ def __init__(self, config: DictConfig, role: str):
         # normalize config
         if self._is_actor:
             self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
-            self.config.actor.ppo_mini_batch_size //= (self.device_mesh.shape[0] // self.ulysses_sequence_parallel_size)
+            self.config.actor.ppo_mini_batch_size //= (self.device_mesh.size() // self.ulysses_sequence_parallel_size)
             # micro bsz
             if self.config.actor.ppo_micro_batch_size is not None:
-                self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] //
+                self.config.actor.ppo_micro_batch_size //= (self.device_mesh.size() //
                                                             self.ulysses_sequence_parallel_size)
                 self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
                 assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \
@@ -130,12 +127,12 @@ def __init__(self, config: DictConfig, role: str):
 
         # normalize rollout config
         if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
-            self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
+            self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                                self.ulysses_sequence_parallel_size)
             self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
 
         # normalize ref config
         if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
-            self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
+            self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                            self.ulysses_sequence_parallel_size)
             self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
@@ -195,7 +192,8 @@ def _build_model_optimizer(self,
             print(f'Model config after override: {actor_model_config}')
 
         # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
-        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)
 
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
@@ -687,7 +685,7 @@ def _build_critic_model_optimizer(self, config):
             from verl.models.transformers.monkey_patch import apply_monkey_patch
             apply_monkey_patch(critic_model_config, verbose=True)
 
-        init_context = get_init_weight_context_manager()
+        init_context = get_init_weight_context_manager(True, mesh=self.device_mesh)
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
             setattr(critic_model_config, 'classifier_dropout', 0.)
@@ -946,7 +944,8 @@ def _build_model(self, config):
             apply_monkey_patch(model_config, verbose=True)
 
         # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
-        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)
 
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
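
Reviewer note, not part of the patch: the sketch below (the `demo` helper, the 8-GPU world size, and `fsdp_size=4` are illustrative assumptions) shows how the mesh-aware `get_init_weight_context_manager` is expected to behave once `create_device_mesh` builds a 2-D `['ddp', 'fsdp']` mesh. The rank whose `fsdp` coordinate is 0 in each DDP replica materializes real weights on CPU, while every other rank allocates meta tensors and is expected to receive the shards when FSDP later syncs module states.

# Illustrative sketch only; run with e.g. `torchrun --nproc_per_node=8 demo.py`.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from verl.utils.fsdp_utils import get_init_weight_context_manager


def demo(world_size: int = 8, fsdp_size: int = 4):
    # Mirrors create_device_mesh for 0 < fsdp_size < world_size:
    # a (ddp=2, fsdp=4) mesh, so global ranks 0 and 4 have fsdp coordinate 0.
    device_mesh = init_device_mesh('cuda',
                                   mesh_shape=(world_size // fsdp_size, fsdp_size),
                                   mesh_dim_names=['ddp', 'fsdp'])

    # With `mesh` passed in, only fsdp-rank-0 of each DDP replica initializes
    # real weights on CPU; every other rank builds meta tensors.
    init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=device_mesh)
    with init_context():
        model = torch.nn.Linear(1024, 1024)

    fsdp_rank = device_mesh.get_coordinate()[-1]
    print(f'rank={dist.get_rank()} fsdp_rank={fsdp_rank} '
          f'params_on_meta={next(model.parameters()).is_meta}')


if __name__ == '__main__':
    demo()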