2 changes: 1 addition & 1 deletion .github/workflows/model.yml
@@ -200,7 +200,7 @@ jobs:
- name: Install the current repository
run: |
pip3 install --no-deps -e .[test]
pip3 install --upgrade tensordict
pip3 install --upgrade tensordict transformers
pip install --upgrade "huggingface_hub[cli]"
- name: Download model config files
run: |
124 changes: 120 additions & 4 deletions tests/models/test_engine.py
@@ -22,26 +22,27 @@
import pytest
import ray
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import logprobs_from_logits_naive
from verl.workers.config import (
ActorConfig,
CriticConfig,
FSDPEngineConfig,
FSDPOptimizerConfig,
HFModelConfig,
McoreEngineConfig,
McoreOptimizerConfig,
)
from verl.workers.roles import ActorWorker
from verl.workers.roles import ActorWorker, CriticWorker
from verl.workers.roles.utils.losses import ppo_loss, sft_loss


@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_mcore_engine(strategy):
def test_actor_engine(strategy):
ray.init()

path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
Review comment (Contributor, severity: high):
Hardcoding the model path using os.path.expanduser makes this test dependent on a specific local file system setup. This will cause the test to fail for other developers or in CI environments where ~/models/Qwen/Qwen2.5-0.5B-Instruct does not exist. For large, pre-existing models, consider making the path configurable via an environment variable or a test configuration file, and provide instructions for setting it up. This will improve the test's portability and reproducibility.
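One possible shape for that, sketched below with a hypothetical VERL_TEST_MODEL_PATH environment variable (the variable name and the skip behaviour are illustrative assumptions, not part of this PR):

    import os

    import pytest

    # Illustrative only: resolve the test model directory from an environment
    # variable, falling back to the path the test currently hardcodes.
    def resolve_model_path(default: str = "~/models/Qwen/Qwen2.5-0.5B-Instruct") -> str:
        configured = os.environ.get("VERL_TEST_MODEL_PATH", default)  # assumed variable name
        return os.path.expanduser(configured)

    path = resolve_model_path()
    if not os.path.isdir(path):
        pytest.skip(f"model not found at {path}; set VERL_TEST_MODEL_PATH to a local copy")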

@@ -72,7 +73,7 @@ def test_mcore_engine(strategy):
ppo_mini_batch_size=4,
optim=optimizer_config,
use_dynamic_bsz=True,
n=1,
rollout_n=1,
)
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
resource_pool = RayResourcePool(process_on_nodes=[8])
@@ -151,3 +152,118 @@ def test_mcore_engine(strategy):
print(ppo_metrics)

ray.shutdown()


def create_model():
from transformers import Qwen3Config

config = Qwen3Config(num_hidden_layers=2, num_labels=1)
model = AutoModelForTokenClassification.from_config(config)
assert model.config.num_labels == 1
path = os.path.expanduser("~/models/test_model")
model.save_pretrained(path)
config.save_pretrained(path)
return path


@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_critic_engine(strategy):
ray.init()

path = create_model()
model_config = HFModelConfig(path=path, load_tokenizer=False)

if strategy == "megatron":
engine_config = McoreEngineConfig(
forward_only=False,
use_mbridge=False,
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
context_parallel_size=2,
)
optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
elif strategy in ["fsdp", "fsdp2"]:
engine_config = FSDPEngineConfig(
forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
)
optimizer_config = FSDPOptimizerConfig()
else:
raise NotImplementedError(f"strategy {strategy} is not supported")

config = CriticConfig(
model_config=model_config,
engine=engine_config,
strategy=strategy,
ppo_micro_batch_size_per_gpu=256,
ppo_mini_batch_size=4,
optim=optimizer_config,
use_dynamic_bsz=True,
rollout_n=1,
)
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)
resource_pool = RayResourcePool(process_on_nodes=[8])
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
# init model
wg.init_model()

batch_size = 8
seqlen = 32

response_length = seqlen // 2

torch.manual_seed(1)
np.random.seed(1)

input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
)
position_ids = compute_position_id_with_mask(attention_mask)

global_token_num = torch.sum(attention_mask, dim=-1).tolist()

print(input_ids.float().mean(), attention_mask.float().mean())

responses = input_ids[:, response_length:]
response_mask = attention_mask[:, response_length:]

assert torch.all(response_mask[:, 0] == 1)

data = DataProto.from_single_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"responses": responses,
"response_mask": response_mask,
},
meta_info={"temperature": 1.0, "global_token_num": global_token_num},
)

# eval
output = wg.compute_values(data)

# load hf model and compare results with hf model
with torch.device("cuda"):
hf_model = AutoModelForTokenClassification.from_pretrained(
path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
hf_output = hf_model(input_ids.cuda(), attention_mask=attention_mask.cuda())
hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1).cpu()
hf_values_mean = torch.mean(hf_values * response_mask)

engine_values = torch.mean(output.batch["values"] * response_mask)

torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)

data = data.union(output)

# add ppo data
data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)

# update again
ppo_metrics = wg.update_critic(data)
print(ppo_metrics)

ray.shutdown()
4 changes: 2 additions & 2 deletions verl/models/mcore/config_converter.py
@@ -165,8 +165,8 @@ def hf_to_mcore_config_dense(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
# for LlamaForCausalLM or Qwen2ForCausalLM
qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False

args: dict = _get_base_transformer_config(
hf_config=hf_config,
3 changes: 3 additions & 0 deletions verl/models/mcore/loader.py
@@ -474,6 +474,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0("load lm_head from value_head weight")
elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "score.weight")
print_rank_0("load lm_head from score weight")
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0("fail to match lm_head in value_model")
6 changes: 6 additions & 0 deletions verl/models/mcore/registry.py
@@ -72,6 +72,8 @@ class SupportedModel(Enum):
QWEN3_MOE = "Qwen3MoeForCausalLM" # tested
GLM4_MOE = "Glm4MoeForCausalLM"

QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"


# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
@@ -85,6 +87,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: hf_to_mcore_config_dense,
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
}

# Registry for model initializers
@@ -99,6 +102,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: DenseModel,
SupportedModel.QWEN3_MOE: Qwen3MoEModel,
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel,
}

# Registry for model forward functions
@@ -115,6 +119,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.GLM4_MOE: gptmodel_forward,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
}

# Registry for model forward functions
@@ -143,6 +148,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense,
}


1 change: 1 addition & 0 deletions verl/models/weight_loader_registry.py
@@ -46,6 +46,7 @@ def get_weight_saver(arch: str):
"Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl,
"DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3,
"Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel,
"Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel,
"Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
}
if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:
6 changes: 3 additions & 3 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -81,8 +81,7 @@ def __init__(
checkpoint_config: DictConfig = None,
**kwargs,
):
if processing_class is None:
assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
if processing_class is None and "tokenizer" in kwargs:
warnings.warn(
"`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2
)
@@ -278,7 +277,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
pass

model_config.save_pretrained(hf_config_tokenizer_path)
self.processing_class.save_pretrained(hf_config_tokenizer_path)
if self.processing_class is not None:
self.processing_class.save_pretrained(hf_config_tokenizer_path)
log_with_rank(
f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}",
rank=self.rank,
3 changes: 2 additions & 1 deletion verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -438,7 +438,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
if self.rank == 0:
# Save tokenizer
hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)
self.processing_class.save_pretrained(hf_config_tokenizer_path)
if self.processing_class is not None:
self.processing_class.save_pretrained(hf_config_tokenizer_path)
# Save huggingface config
self.hf_config.save_pretrained(hf_config_tokenizer_path)
if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
34 changes: 21 additions & 13 deletions verl/utils/model.py
@@ -26,7 +26,11 @@
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
GenerationConfig,
MistralForSequenceClassification,
PretrainedConfig,
@@ -402,6 +406,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
architectures = getattr(model_config, "architectures", [])
local_cache_path = os.path.expanduser(local_cache_path)

# get auto class
auto_cls = get_hf_auto_model_class(model_config)

if config.model.path.startswith("hdfs:"):
from verl.utils.fs import copy_to_local

@@ -434,7 +441,7 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
] # workaround, 32001 -> 32000
is_value_model = True
else:
model = AutoModelForCausalLM.from_pretrained(
model = auto_cls.from_pretrained(
local_model_path,
torch_dtype="auto",
# device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank
@@ -658,13 +665,15 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
return model


def get_hf_auto_model_class(hf_config):
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
)
_architecture_to_auto_class = {
"ForCausalLM": AutoModelForCausalLM,
"ForVision2Seq": AutoModelForVision2Seq,
"ForTokenClassification": AutoModelForTokenClassification,
"ForSequenceClassification": AutoModelForSequenceClassification,
}


def get_hf_auto_model_class(hf_config):
has_remote_code = hasattr(hf_config, "auto_map") and any(
hf_config.architectures[0] in val for val in hf_config.auto_map.values()
)
@@ -678,12 +687,11 @@ def get_hf_auto_model_class(hf_config):
case _:
actor_module_class = AutoModel
else:
if type(hf_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
elif type(hf_config) in AutoModelForCausalLM._model_mapping.keys():
actor_module_class = AutoModelForCausalLM
else:
actor_module_class = AutoModel
actor_module_class = AutoModel
for key, cls in _architecture_to_auto_class.items():
if key in hf_config.architectures[0]:
actor_module_class = cls
break

return actor_module_class

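For context on the auto-class resolution rewritten in verl/utils/model.py above, a rough usage sketch (it assumes a locally saved Qwen3ForTokenClassification config, e.g. the one written by create_model() in the new test; not part of the PR):

    import os

    from transformers import AutoConfig

    from verl.utils.model import get_hf_auto_model_class

    # The saved config advertises architectures=["Qwen3ForTokenClassification"], so the
    # "ForTokenClassification" key in _architecture_to_auto_class matches and the call
    # resolves to AutoModelForTokenClassification rather than AutoModelForCausalLM.
    hf_config = AutoConfig.from_pretrained(os.path.expanduser("~/models/test_model"))
    auto_cls = get_hf_auto_model_class(hf_config)
    model = auto_cls.from_config(hf_config)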
4 changes: 2 additions & 2 deletions verl/workers/config/actor.py
@@ -119,13 +119,13 @@ class ActorConfig(BaseConfig):
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
engine: BaseConfig = field(default_factory=BaseConfig)
data_loader_seed = 1
n: int = 1 # must be override by sampling config
rollout_n: int = 1 # must be override by sampling config
Review comment (Contributor):
Is this renaming complete in all references to the "n" parameter?
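One rough way to audit that (an illustrative sketch, not part of the PR; the patterns are heuristic and may need adjusting):

    import pathlib
    import re

    # Look for code that still reads or sets the old ActorConfig field name "n".
    pattern = re.compile(r"(?<![\w.])config\.n\b|\bself\.n\b")
    for py_file in pathlib.Path("verl").rglob("*.py"):
        for lineno, line in enumerate(py_file.read_text(encoding="utf-8").splitlines(), start=1):
            if pattern.search(line):
                print(f"{py_file}:{lineno}: {line.strip()}")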

model_config: HFModelConfig = field(default_factory=BaseConfig)

def __post_init__(self):
"""Validate actor configuration parameters."""
assert self.strategy != MISSING
assert self.n != MISSING
assert self.rollout_n != MISSING
if not self.use_dynamic_bsz:
if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:
raise ValueError(