Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ jobs:
- name: Install the current repository
run: |
pip3 install --no-deps -e .[test]
pip3 install --upgrade tensordict
pip3 install --upgrade tensordict transformers
pip install --upgrade "huggingface_hub[cli]"
- name: Download model config files
run: |
Expand Down
121 changes: 117 additions & 4 deletions tests/models/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,27 @@
import pytest
import ray
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import logprobs_from_logits_naive
from verl.workers.config import (
ActorConfig,
CriticConfig,
FSDPEngineConfig,
FSDPOptimizerConfig,
HFModelConfig,
McoreEngineConfig,
McoreOptimizerConfig,
)
from verl.workers.roles import ActorWorker
from verl.workers.roles import ActorWorker, CriticWorker
from verl.workers.roles.utils.losses import ppo_loss, sft_loss


@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_mcore_engine(strategy):
def test_actor_engine(strategy):
ray.init()

path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Hardcoding the model path using os.path.expanduser makes this test dependent on a specific local file system setup. This will cause the test to fail for other developers or in CI environments where ~/models/Qwen/Qwen2.5-0.5B-Instruct does not exist. For large, pre-existing models, consider making the path configurable via an environment variable or a test configuration file, and provide instructions for setting it up. This will improve the test's portability and reproducibility.

Expand Down Expand Up @@ -72,7 +73,7 @@ def test_mcore_engine(strategy):
ppo_mini_batch_size=4,
optim=optimizer_config,
use_dynamic_bsz=True,
n=1,
rollout_n=1,
)
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
resource_pool = RayResourcePool(process_on_nodes=[8])
Expand Down Expand Up @@ -151,3 +152,115 @@ def test_mcore_engine(strategy):
print(ppo_metrics)

ray.shutdown()


def create_model(path="~/models/test_model"):
    """Build a tiny, randomly initialised Qwen3 token-classification model and save it.

    The model is created from a fresh ``Qwen3Config`` (2 hidden layers, a single
    label) rather than downloaded, so the critic test does not depend on a
    pre-existing checkpoint.

    Args:
        path: Directory to write the model and config into. ``~`` is expanded.
            Defaults to the location the test originally hard-coded, so existing
            callers are unaffected; pass a temporary directory to avoid touching
            the home directory.

    Returns:
        The expanded directory path the model and config were saved to.
    """
    # Imported lazily, as in the original test helper.
    from transformers import Qwen3Config

    config = Qwen3Config(num_hidden_layers=2, num_labels=1)
    model = AutoModelForTokenClassification.from_config(config)
    # The critic test squeezes the last logits dimension, so exactly one label is required.
    assert model.config.num_labels == 1

    path = os.path.expanduser(path)
    model.save_pretrained(path)
    # Kept from the original: the config is saved explicitly after the model.
    config.save_pretrained(path)
    return path


@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_critic_engine(strategy):
    """Check that ``CriticWorker`` values match the HF reference model, then run one critic update.

    For each strategy the test:
      1. saves a small token-classification model via ``create_model``,
      2. spins up a ``CriticWorker`` group on a resource pool requested with
         ``process_on_nodes=[8]``,
      3. compares the masked mean of the engine's values against the masked mean
         of the HF model's logits on the same random batch,
      4. calls ``update_critic`` once with random ``values`` / ``returns``.

    Ray is always shut down in a ``finally`` block so that a failure in one
    parametrized case does not leave a live Ray runtime behind for the next
    case's ``ray.init()``.

    Args:
        strategy: One of ``"megatron"``, ``"fsdp"`` or ``"fsdp2"``.

    Raises:
        NotImplementedError: If ``strategy`` is not one of the supported values.
    """
    ray.init()
    try:
        path = create_model()
        model_config = HFModelConfig(path=path, load_tokenizer=False)

        if strategy == "megatron":
            # TP=2 x PP=2 x CP=2 -> 8-way parallelism, matching the 8 processes requested below.
            engine_config = McoreEngineConfig(
                forward_only=False,
                use_mbridge=False,
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=2,
                context_parallel_size=2,
            )
            optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
        elif strategy in ["fsdp", "fsdp2"]:
            engine_config = FSDPEngineConfig(
                forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
            )
            optimizer_config = FSDPOptimizerConfig()
        else:
            raise NotImplementedError(f"strategy {strategy} is not supported")

        config = CriticConfig(
            model_config=model_config,
            engine=engine_config,
            strategy=strategy,
            ppo_micro_batch_size_per_gpu=256,
            ppo_mini_batch_size=4,
            optim=optimizer_config,
            use_dynamic_bsz=True,
            rollout_n=1,
        )
        ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)
        resource_pool = RayResourcePool(process_on_nodes=[8])
        wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
        # init model
        wg.init_model()

        batch_size = 8
        seqlen = 32

        response_length = seqlen // 2

        # Fixed seeds so the random batch (and the mask) is reproducible across runs.
        torch.manual_seed(1)
        np.random.seed(1)

        input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
        attention_mask = create_random_mask(
            input_ids=input_ids,
            max_ratio_of_valid_token=0.8,
            max_ratio_of_left_padding=0.2,
            min_ratio_of_valid_token=0.6,
        )
        position_ids = compute_position_id_with_mask(attention_mask)

        # Per-sample count of valid tokens, passed to the workers through meta_info.
        global_token_num = torch.sum(attention_mask, dim=-1).tolist()

        print(input_ids.float().mean(), attention_mask.float().mean())

        # The second half of every sequence is treated as the response.
        responses = input_ids[:, response_length:]
        response_mask = attention_mask[:, response_length:]

        # Sanity check on the generated mask: every response must start with a valid token.
        assert torch.all(response_mask[:, 0] == 1)

        data = DataProto.from_single_dict(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "responses": responses,
                "response_mask": response_mask,
            },
            meta_info={"temperature": 1.0, "global_token_num": global_token_num},
        )

        # eval
        output = wg.compute_values(data)

        # load hf model and compare results with hf model
        hf_model = AutoModelForTokenClassification.from_pretrained(path, torch_dtype=torch.bfloat16)
        hf_output = hf_model(input_ids, attention_mask=attention_mask)
        # Take the window shifted one position to the left of the response tokens
        # (presumably the value for a token is read from the preceding position - confirm
        # against CriticWorker.compute_values).
        hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1)
        hf_values_mean = torch.mean(hf_values * response_mask)

        engine_values = torch.mean(output.batch["values"] * response_mask)

        # Only the masked means are compared, with loose tolerances (HF model runs in bfloat16).
        torch.testing.assert_close(hf_values_mean, engine_values, atol=5e-3, rtol=1e-2)

        data = data.union(output)

        # add ppo data
        # NOTE: this overwrites the "values" just merged in from `output` with random numbers.
        data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
        data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)

        # update again
        ppo_metrics = wg.update_critic(data)
        print(ppo_metrics)
    finally:
        # Always tear Ray down, even on failure, so the next parametrized case can call ray.init().
        ray.shutdown()
4 changes: 2 additions & 2 deletions verl/models/mcore/config_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def hf_to_mcore_config_dense(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
# for LlamaForCausalLM or Qwen2ForCausalLM
qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False

args: dict = _get_base_transformer_config(
hf_config=hf_config,
Expand Down
3 changes: 3 additions & 0 deletions verl/models/mcore/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0("load lm_head from value_head weight")
elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "score.weight")
print_rank_0("load lm_head from score weight")
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0("fail to match lm_head in value_model")
Expand Down
6 changes: 6 additions & 0 deletions verl/models/mcore/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class SupportedModel(Enum):
QWEN3_MOE = "Qwen3MoeForCausalLM" # tested
GLM4_MOE = "Glm4MoeForCausalLM"

QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"


# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
Expand All @@ -85,6 +87,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: hf_to_mcore_config_dense,
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
}

# Registry for model initializers
Expand All @@ -99,6 +102,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: DenseModel,
SupportedModel.QWEN3_MOE: Qwen3MoEModel,
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel,
}

# Registry for model forward functions
Expand All @@ -115,6 +119,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.GLM4_MOE: gptmodel_forward,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
}

# Registry for model forward functions
Expand Down Expand Up @@ -143,6 +148,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense,
}


Expand Down
1 change: 1 addition & 0 deletions verl/models/weight_loader_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_weight_saver(arch: str):
"Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl,
"DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3,
"Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel,
"Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel,
"Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
}
if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:
Expand Down
29 changes: 23 additions & 6 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ data:
critic:
optim:
_target_: verl.workers.config.McoreOptimizerConfig
lr: 1.0e-05
lr: 0.001
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
Expand Down Expand Up @@ -346,20 +346,37 @@ critic:
override_mcore_model_config: {}
use_mbridge: false
forward_only: false
model_config:
_target_: verl.workers.config.HFModelConfig
path: ~/models/deepseek-llm-7b-chat
hf_config_path: null
tokenizer_path: null
use_shm: false
trust_remote_code: false
custom_chat_template: null
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
exclude_modules: null
use_liger: false
use_fused_kernels: false
fused_kernel_options:
impl_backend: torch
_target_: verl.workers.config.McoreCriticConfig
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: megatron
enable: null
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
_target_: verl.trainer.config.BaseModelConfig
override_config:
model_config: {}
moe_config:
freeze_moe_router: false
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
_target_: verl.trainer.config.BaseModelConfig
ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
Expand Down
46 changes: 22 additions & 24 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ data:
critic:
optim:
_target_: verl.workers.config.FSDPOptimizerConfig
lr: 1.0e-05
lr: 0.001
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
Expand All @@ -311,29 +311,6 @@ critic:
num_cycles: 0.5
warmup_style: constant
model:
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
model_dtype: fp32
use_orig_params: false
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
use_torch_compile: true
entropy_checkpointing: false
forward_only: false
strategy: fsdp
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
override_config: {}
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
_target_: verl.workers.config.FSDPCriticModelCfg
use_shm: false
enable_gradient_checkpointing: true
Expand All @@ -342,6 +319,27 @@ critic:
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
model_config:
_target_: verl.workers.config.HFModelConfig
path: ~/models/deepseek-llm-7b-chat
hf_config_path: null
tokenizer_path: null
use_shm: false
trust_remote_code: false
custom_chat_template: null
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
exclude_modules: null
use_liger: false
use_fused_kernels: false
fused_kernel_options:
impl_backend: torch
_target_: verl.workers.config.FSDPCriticConfig
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: fsdp
Expand Down
Loading
Loading