tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh (2 additions, 2 deletions)
@@ -18,7 +18,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
@@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \
    critic.ppo_micro_batch_size_per_gpu=4 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
-    critic.model.fsdp_config.fsdp_size=-1 \
+    critic.model.fsdp_config.fsdp_size=4 \
    reward_model.enable=True \
    reward_model.ulysses_sequence_parallel_size=2 \
    reward_model.model.path=Qwen/Qwen2.5-0.5B \
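
Context for the config change above (not part of the diff): setting fsdp_size=4 switches the workers from a flat FSDP mesh to HSDP. Assuming the e2e test runs on 8 GPUs, create_device_mesh in the fsdp_workers.py diff below then builds a 2-D (ddp, fsdp) mesh. A minimal sketch:

# Sketch only: the mesh produced for fsdp_size=4 under an assumed 8-GPU world size.
from torch.distributed.device_mesh import init_device_mesh

world_size, fsdp_size = 8, 4   # 8 GPUs is an assumption; fsdp_size comes from the script
mesh = init_device_mesh('cuda',
                        mesh_shape=(world_size // fsdp_size, fsdp_size),  # (2, 4)
                        mesh_dim_names=['ddp', 'fsdp'])
# Parameters are sharded across the 4 ranks of each 'fsdp' group and replicated
# across the 2 'ddp' groups, instead of being sharded over all 8 ranks.
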
verl/trainer/fsdp_sft_trainer.py (2 additions, 1 deletion)
@@ -204,7 +204,8 @@ def _build_model_optimizer(self):
            apply_monkey_patch(config, verbose=True)

        # This may be very large
-        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)

        with init_context():
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
verl/utils/fsdp_utils.py (6 additions, 2 deletions)
@@ -19,6 +19,7 @@
import itertools
import os
from contextlib import contextmanager
+from torch.distributed import DeviceMesh
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _lazy_init
@@ -35,11 +36,14 @@ def init_fn(x: torch.nn.Module):
    return x


-def get_init_weight_context_manager(use_meta_tensor=True):
+def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):
    from accelerate import init_empty_weights
    cpu_init_weights = lambda: torch.device('cpu')
    if use_meta_tensor:
-        init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
+        if mesh is None:
+            init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
+        else:
+            init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights
    else:
        init_context = cpu_init_weights
    return init_context
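
A hedged usage sketch of the new mesh-aware path: get_coordinate() returns the caller's index along each mesh dimension, and the trailing dimension is 'fsdp', so exactly one rank per FSDP shard group (coordinate 0) materializes real weights on CPU while the remaining ranks build empty meta tensors. The mesh shape and model name below are placeholders:

# Sketch only: assumes a process group over 8 ranks (e.g. launched with torchrun).
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM
from verl.utils.fsdp_utils import get_init_weight_context_manager

mesh = init_device_mesh('cuda', mesh_shape=(2, 4), mesh_dim_names=['ddp', 'fsdp'])

# mesh.get_coordinate() is e.g. [1, 0] on global rank 4: ddp index 1, fsdp index 0.
# Ranks 0 and 4 (fsdp coordinate 0) load real weights; the rest stay on meta tensors.
init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=mesh)
with init_context():
    model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B')  # placeholder model
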
verl/workers/fsdp_workers.py (9 additions, 10 deletions)
@@ -49,9 +49,6 @@ def create_device_mesh(world_size, fsdp_size):
    if fsdp_size < 0 or fsdp_size >= world_size:
        device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
    else:
-        raise ValueError(
-            'HSDP is not supported yet because it produces incorrect results for now. Please set fsdp_size=-1')
-        assert world_size % fsdp_size == 0
        device_mesh = init_device_mesh('cuda',
                                       mesh_shape=(world_size // fsdp_size, fsdp_size),
                                       mesh_dim_names=['ddp', 'fsdp'])
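
Why the __init__ hunks below move from device_mesh.shape[0] to device_mesh.size(): with the previously mandatory 1-D mesh the two were identical, but for the HSDP mesh above shape[0] is only the replica ('ddp') dimension while size() is the total number of ranks, which is what the per-GPU batch-size normalization needs. A worked example under the assumed 8-GPU, sequence-parallel-2 test setup:

# Illustrative numbers only (8 GPUs, fsdp_size=4, ulysses_sequence_parallel_size=2).
world_size = 8           # device_mesh.size() for the (2, 4) mesh
ddp_dim = 2              # device_mesh.shape[0] for the same mesh
ulysses_sp = 2

dp_size = world_size // ulysses_sp     # 4 data-parallel groups -> correct divisor
stale_divisor = ddp_dim // ulysses_sp  # 1 -> would leave per-GPU batch sizes 4x too large
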
@@ -117,10 +114,10 @@ def __init__(self, config: DictConfig, role: str):
        # normalize config
        if self._is_actor:
            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
-            self.config.actor.ppo_mini_batch_size //= (self.device_mesh.shape[0] // self.ulysses_sequence_parallel_size)
+            self.config.actor.ppo_mini_batch_size //= (self.device_mesh.size() // self.ulysses_sequence_parallel_size)
            # micro bsz
            if self.config.actor.ppo_micro_batch_size is not None:
-                self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] //
+                self.config.actor.ppo_micro_batch_size //= (self.device_mesh.size() //
                                                            self.ulysses_sequence_parallel_size)
                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \
@@ -130,12 +127,12 @@ def __init__(self, config: DictConfig, role: str):

        # normalize rollout config
        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
-            self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
+            self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                                self.ulysses_sequence_parallel_size)
            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
        # normalize ref config
        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
-            self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
+            self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                            self.ulysses_sequence_parallel_size)
            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size

@@ -195,7 +192,8 @@ def _build_model_optimizer(self,
            print(f'Model config after override: {actor_model_config}')

        # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
-        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
@@ -687,7 +685,7 @@ def _build_critic_model_optimizer(self, config):
            from verl.models.transformers.monkey_patch import apply_monkey_patch
            apply_monkey_patch(critic_model_config, verbose=True)

-        init_context = get_init_weight_context_manager()
+        init_context = get_init_weight_context_manager(True, mesh=self.device_mesh)
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            setattr(critic_model_config, 'classifier_dropout', 0.)
@@ -946,7 +944,8 @@ def _build_model(self, config):
            apply_monkey_patch(model_config, verbose=True)

        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
-        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
+        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings,
+                                                       mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")