96 changes: 95 additions & 1 deletion tests/models/test_engine.py
@@ -22,10 +22,13 @@
import pytest
import ray
import torch
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.trainer.config import CheckpointConfig
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import logprobs_from_logits_naive
from verl.workers.config import (
@@ -267,3 +270,94 @@ def test_critic_engine(strategy):
print(ppo_metrics)

ray.shutdown()


def create_actor_model(tmp_path, config):
model = AutoModelForCausalLM.from_config(config)
path = os.path.join(tmp_path, "test_model")
model.save_pretrained(path)
config.save_pretrained(path)
return path


def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, model_path: str):
torch.cuda.set_device(rank)
dist.init_process_group(
backend="nccl",
init_method=f"file://{rendezvous_file}",
rank=rank,
world_size=world_size,
)

with torch.device("meta"):
ref_model = AutoModelForCausalLM.from_pretrained(model_path)

from verl.workers.engine import BaseEngine, EngineRegistry

# construct configs
model_config = HFModelConfig(path=model_path, load_tokenizer=False)

if strategy == "megatron":
engine_config = McoreEngineConfig(
forward_only=False,
use_mbridge=True,
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
context_parallel_size=1,
)
optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
elif strategy in ["fsdp", "fsdp2"]:
engine_config = FSDPEngineConfig(
forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
)
optimizer_config = FSDPOptimizerConfig()
else:
raise NotImplementedError(f"strategy {strategy} is not supported")

checkpoint_config = CheckpointConfig()

# build model engine
engine: BaseEngine = EngineRegistry.new(
model_type="language_model",
backend=engine_config.strategy,
model_config=model_config,
engine_config=engine_config,
optimizer_config=optimizer_config,
checkpoint_config=checkpoint_config,
)

engine.initialize()

# get per tensor parameter
per_tensor_params = engine.get_per_tensor_param()

ref_state_dict = ref_model.state_dict()

    # compare parameter names and shapes against the reference state dict
for key, value in per_tensor_params:
assert key in ref_state_dict, f"{key} not in ref_state_dict"
assert value.shape == ref_state_dict[key].shape, (
f"{key} shape not equal, {value.shape} != {ref_state_dict[key].shape}"
)
if rank == 0:
print(key, value.shape)

dist.barrier()
dist.destroy_process_group()


@pytest.mark.parametrize("world_size", [8])
@pytest.mark.parametrize("config", [Qwen3Config(num_hidden_layers=2), Qwen3MoeConfig(num_hidden_layers=2)])
@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_per_tensor_generator(world_size, tmp_path, config, strategy):
rendezvous_file = str(tmp_path / "rdzv_mask")
os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
# create a model
model_path = create_actor_model(tmp_path, config)
# spawn workers
mp.spawn(
fn=_worker,
args=(world_size, rendezvous_file, strategy, model_path),
nprocs=world_size,
join=True,
)
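
The test above spawns `world_size` workers, each of which builds an engine and streams its weights as `(name, tensor)` pairs via the new `get_per_tensor_param()` API, comparing names and shapes against a reference HF model. A minimal, single-process sketch of the same consumption pattern follows; it assumes only that `engine` exposes `get_per_tensor_param()` as added in this PR, and everything else is illustrative:

```python
# Sketch: compare an engine's streamed per-tensor parameters against a
# reference HF checkpoint. `engine` is assumed to expose get_per_tensor_param()
# as added in this PR; the helper itself is hypothetical.
from transformers import AutoModelForCausalLM


def check_against_reference(engine, model_path: str):
    ref_state_dict = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
    seen = set()
    for name, tensor in engine.get_per_tensor_param():
        assert name in ref_state_dict, f"{name} missing from reference"
        assert tensor.shape == ref_state_dict[name].shape, f"shape mismatch for {name}"
        seen.add(name)
    # every reference parameter should have been streamed
    missing = set(ref_state_dict) - seen
    assert not missing, f"parameters never yielded: {missing}"
```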
6 changes: 4 additions & 2 deletions verl/utils/kernel/linear_cross_entropy.py
@@ -34,8 +34,6 @@
import torch
import torch.distributed as dist

from . import kernels


class LinearCrossEntropy(torch.autograd.Function):
@staticmethod
@@ -66,6 +64,8 @@ def forward(
assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}"
assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}"
with torch.cuda.nvtx.range("LinearCrossEntropy-forward"):
from . import kernels

REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower())

original_hidden_shape = hidden.shape
@@ -88,6 +88,8 @@

@staticmethod
def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]:
from . import kernels

with torch.cuda.nvtx.range("LinearCrossEntropy-backward"):
(hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors
REDUCTION = ctx.REDUCTION
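
The only change in this file moves `from . import kernels` out of module scope and into the `forward`/`backward` bodies, so the CUDA kernel module is not imported until the op actually runs. A generic sketch of that lazy-import pattern (the module name below is hypothetical, not part of verl):

```python
# Lazy-import sketch: defer a heavy, GPU-only dependency until the op is
# actually executed, so importing this module never fails on CPU-only machines.
def fused_op(x):
    import heavy_gpu_kernels  # hypothetical; resolved on first call, not at module import time
    return heavy_gpu_kernels.run(x)
```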
3 changes: 3 additions & 0 deletions verl/workers/engine/base.py
@@ -124,6 +124,9 @@ def infer_batch(self, data: TensorDict, loss_function: Optional[Callable] = None
outputs = self.forward_backward_batch(data, loss_function, forward_only=True)
return outputs

def get_per_tensor_param(self):
raise NotImplementedError

def get_data_parallel_size(self):
raise NotImplementedError

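
`get_per_tensor_param` joins the other backend hooks on `BaseEngine` as a default-raise method; each backend is expected to return an iterator of `(name, tensor)` pairs. A sketch of what an override might look like, ignoring the other methods a real engine implements (class and attribute names are illustrative only, and it is assumed `BaseEngine` can be subclassed this way):

```python
# Sketch of a backend override: stream (name, tensor) pairs instead of handing
# back one big state dict. Names here are illustrative only.
from verl.workers.engine import BaseEngine


class ToyEngine(BaseEngine):
    def __init__(self, module):
        self.module = module  # assumes a plain nn.Module

    def get_per_tensor_param(self):
        return ((name, param) for name, param in self.module.state_dict().items())
```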
58 changes: 58 additions & 0 deletions verl/workers/engine/fsdp/transformer_impl.py
@@ -27,6 +27,8 @@
from peft import LoraConfig, TaskType, get_peft_model
from tensordict import TensorDict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.tensor import DTensor

import verl.utils.torch_functional as verl_F
from verl.models.transformers.monkey_patch import apply_monkey_patch
@@ -46,6 +48,7 @@
FSDPModule,
MixedPrecisionPolicy,
apply_fsdp2,
collect_lora_params,
fsdp2_clip_grad_norm_,
fsdp2_load_full_state_dict,
fsdp_version,
@@ -56,7 +59,9 @@
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
replace_lora_wrapper,
)
from verl.utils.model import convert_weight_keys
from verl.utils.py_functional import convert_to_regular_types
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
@@ -338,6 +343,20 @@ def _build_fsdp_module(self, module):
if self.model_config.enable_activation_offload:
enable_gradient_checkpointing = self.model_config.enable_gradient_checkpointing
enable_activation_offloading(module, self.engine_config.strategy, enable_gradient_checkpointing)

if torch.distributed.get_world_size() == 1 and fsdp_version(module) == 1:
FSDP.set_state_dict_type(
module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(),
)
elif fsdp_version(module) == 1:
FSDP.set_state_dict_type(
module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)

return module

def _build_optimizer(self, module):
@@ -581,6 +600,45 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.optimizer)

def get_per_tensor_param(self, layered_summon=False, base_sync_done=False):
log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)

if self._is_offload_param:
load_fsdp_model_to_gpu(self.module)

log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)

peft_config = None
peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
if hasattr(peft_model, "peft_config"): # LoRA
peft_config = peft_model.peft_config.get("default", None)
params = collect_lora_params(
module=self.module,
layered_summon=layered_summon,
base_sync_done=base_sync_done,
)
if not base_sync_done:
params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
else:
params = self.module.state_dict()

params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)

if peft_config is not None and base_sync_done:
per_tensor_param = params
else:
device = get_device_id()  # needed when FSDP2 uses a CPU offload policy
per_tensor_param = (
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
for name, param in params.items()
)
return per_tensor_param


class EngineEvalModeCtx:
def __init__(self, engine: FSDPEngine):
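
The FSDP implementation returns a generator: each parameter is moved to the current device and, if it is a `DTensor` shard, gathered into a dense tensor via `full_tensor()` only when the consumer iterates, so at most one full parameter is materialized at a time. A standalone sketch of that conversion step, assuming `params` comes from a `state_dict()` whose values may be `DTensor` shards:

```python
import torch
from torch.distributed.tensor import DTensor


def to_full_tensors(params: dict, device: torch.device):
    """Lazily yield (name, tensor) pairs, gathering DTensor shards on access."""
    for name, param in params.items():
        if isinstance(param, DTensor):
            # all-gather the local shards across the device mesh into one dense tensor
            param = param.to(device, non_blocking=True).full_tensor()
        yield name, param
```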
35 changes: 33 additions & 2 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -24,6 +24,7 @@
from omegaconf import OmegaConf
from tensordict import TensorDict

from verl.models.mcore import get_mcore_weight_converter
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
@@ -35,6 +36,7 @@
load_megatron_optimizer,
offload_megatron_model_to_cpu,
offload_megatron_optimizer,
per_tensor_generator,
)
from verl.utils.model import load_mcore_dist_weights, load_megatron_gptmodel_weights
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
@@ -72,6 +74,12 @@ def __init__(

self.mode = None

self.layer_name_mapping = {
"qkv_layer_name": "self_attention.linear_qkv.",
"gate_proj_layer_name": "linear_fc1.",
}
self.weight_converter = None

def _init_device_mesh(self):
mpu.initialize_model_parallel(
tensor_model_parallel_size=self.engine_config.tensor_model_parallel_size,
@@ -106,7 +114,11 @@ def _build_tf_config(self):
else:
self.bridge = None

print(f"TF config: {tf_config}")
if not self.bridge:
self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)

if torch.distributed.get_rank() == 0:
print(f"TF config: {tf_config}")
self.tf_config = tf_config

def _build_megatron_module(self):
Expand All @@ -119,6 +131,8 @@ def _build_megatron_module(self):
or "ForSequenceClassification" in self.model_config.architectures[0]
)

self.is_value_model = is_value_model

if self.engine_config.forward_only:
wrap_with_ddp = False
else:
@@ -200,12 +214,14 @@ def initialize(self):

tmp_config = OmegaConf.create({"model": {"path": self.model_config.local_path}})

role = "actor" if not self.is_value_model else "critic"

self.checkpoint_mananager = MegatronCheckpointManager(
config=tmp_config,
checkpoint_config=self.checkpoint_config,
model_config=self.model_config.hf_config,
transformer_config=self.tf_config,
role="actor",
role=role,
model=self.module,
arch=self.model_config.architectures[0],
hf_config=self.model_config.hf_config,
@@ -413,6 +429,21 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
else:
return {}

def get_per_tensor_param(self):
if self._is_offload_param:
load_megatron_model_to_gpu(self.module, load_grad=False)
if self.bridge is not None:
per_tensor_param = self.bridge.export_weights(self.module)
else:
per_tensor_param = per_tensor_generator(
self.module,
self.model_config.hf_config,
self.weight_converter,
self.tf_config,
self.layer_name_mapping,
)
return per_tensor_param

def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
raise NotImplementedError("forward_step must be implemented in subclass")

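
On the Megatron side the same API either delegates to mbridge's `export_weights` or to `per_tensor_generator`, which maps Megatron-sharded weights back to HF naming via the registered weight converter. Because the result is a generator, a caller can stream weights one tensor at a time without holding the full HF state dict; a hedged sketch of such consumption (`send_weight` is a hypothetical callback, not a verl API):

```python
# Sketch: stream per-tensor weights to some consumer one at a time, counting
# parameters without ever building the full HF state dict in memory.
def stream_weights(engine, send_weight):
    total_numel = 0
    for name, tensor in engine.get_per_tensor_param():
        send_weight(name, tensor)  # hypothetical consumer callback
        total_numel += tensor.numel()
        del tensor  # drop the reference before the next tensor is gathered
    return total_numel
```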
43 changes: 6 additions & 37 deletions verl/workers/roles/actor.py
@@ -15,7 +15,6 @@
import logging
import os
from functools import partial
from typing import Iterable

import psutil
from codetiming import Timer
@@ -129,41 +128,6 @@ def compute_log_prob(self, data: DataProto):

return output

def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
"""Make minibatch iterator for updating the actor

Args:
data (DataProto): a DataProto containing keys

``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where
``sequence_length = prompt_length + response_length``

``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64

``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64

``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
responses = input_ids[:, -response_length:]

``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
of responses.

``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
responses.
See PPO paper for details. https://arxiv.org/abs/1707.06347

Returns:

"""
# Note that we do not select data here. It's the user's responsibility to select data outside trainer
# it's very important to setup seed here. Otherwise, data in model parallel region can disagree and cause hangs
return data.make_iterator(
mini_batch_size=self.ppo_mini_batch_size_per_dp,
epochs=self.config.ppo_epochs,
seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),
dataloader_kwargs={"shuffle": self.config.shuffle},
)

@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
Expand All @@ -180,7 +144,12 @@ def update_actor(self, data: DataProto):
data = data.to(get_device_id())
# perform forward computation
with self.engine.train_mode():
dataloader = self._make_minibatch_iterator(data)
dataloader = data.make_iterator(
mini_batch_size=self.ppo_mini_batch_size_per_dp,
epochs=self.config.ppo_epochs,
seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),
dataloader_kwargs={"shuffle": self.config.shuffle},
)
with Timer(name="update_policy", logger=None) as timer:
for batch_idx, mini_batch in enumerate(dataloader):
mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
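
`update_actor` now builds the mini-batch iterator inline with `data.make_iterator`, replacing the removed `_make_minibatch_iterator` helper. The seed is `data_loader_seed + data_parallel_rank`, so ranks inside one data-parallel group shuffle identically (avoiding the divergence and hangs inside the model-parallel group that the removed docstring warned about), while different data-parallel groups still see different orders. A small sketch of that seeding behavior with a plain PyTorch sampler (function and parameter names are illustrative):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def make_rank_seeded_loader(dataset, base_seed: int, dp_rank: int, batch_size: int):
    # Ranks sharing a dp_rank get the same generator state -> identical shuffle order;
    # different data-parallel groups get different, but still reproducible, orders.
    gen = torch.Generator()
    gen.manual_seed(base_seed + dp_rank)
    sampler = RandomSampler(dataset, generator=gen)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# usage sketch
loader = make_rank_seeded_loader(TensorDataset(torch.arange(16).float()), base_seed=1, dp_rank=0, batch_size=4)
```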