Commit 6e6fafd

[model] feat: add FSDP/Megatron critic worker with model engine (#3439)
### What does this PR do?

- As title
- Add a test to compare the output of FSDP/Megatron engine with huggingface model

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3c9b884 commit 6e6fafd
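For orientation, below is a minimal usage sketch of the new `CriticWorker`, distilled from the `test_critic_engine` test added in this commit (FSDP branch only; the model path, pool size, and random batch are illustrative placeholders, not part of the PR itself):

```python
# Sketch based on tests/models/test_engine.py::test_critic_engine (FSDP branch).
# The model path, pool size, and random tensors below are illustrative only.
import os

import ray
import torch

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.model import compute_position_id_with_mask
from verl.workers.config import CriticConfig, FSDPEngineConfig, FSDPOptimizerConfig, HFModelConfig
from verl.workers.roles import CriticWorker

ray.init()

model_config = HFModelConfig(path=os.path.expanduser("~/models/test_model"), load_tokenizer=False)
config = CriticConfig(
    model_config=model_config,
    engine=FSDPEngineConfig(forward_only=False, fsdp_size=4, strategy="fsdp2", ulysses_sequence_parallel_size=2),
    strategy="fsdp2",
    ppo_micro_batch_size_per_gpu=256,
    ppo_mini_batch_size=4,
    optim=FSDPOptimizerConfig(),
    use_dynamic_bsz=True,
    rollout_n=1,
)

# Spawn critic workers on a Ray resource pool and build the model engine.
wg = RayWorkerGroup(
    resource_pool=RayResourcePool(process_on_nodes=[8]),
    ray_cls_with_init=RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config),
)
wg.init_model()

# Build a toy batch: the second half of each sequence is treated as the response.
batch_size, seqlen = 8, 32
input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
attention_mask = torch.ones_like(input_ids)
data = DataProto.from_single_dict(
    {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": compute_position_id_with_mask(attention_mask),
        "responses": input_ids[:, seqlen // 2 :],
        "response_mask": attention_mask[:, seqlen // 2 :],
    },
    meta_info={"temperature": 1.0, "global_token_num": torch.sum(attention_mask, dim=-1).tolist()},
)

value_output = wg.compute_values(data)  # critic forward pass -> output.batch["values"]
data = data.union(value_output)
data.batch["returns"] = torch.rand_like(data.batch["responses"], dtype=torch.float32)
metrics = wg.update_critic(data)        # one PPO critic update over the mini-batch
```

The Megatron path differs only in the engine and optimizer configs (`McoreEngineConfig` / `McoreOptimizerConfig`), as shown in the test below.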

17 files changed (+526, -212 lines)

.github/workflows/model.yml

Lines changed: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ jobs:
       - name: Install the current repository
         run: |
           pip3 install --no-deps -e .[test]
-          pip3 install --upgrade tensordict
+          pip3 install --upgrade tensordict transformers
           pip install --upgrade "huggingface_hub[cli]"
       - name: Download model config files
         run: |

tests/models/test_engine.py

Lines changed: 120 additions & 4 deletions

@@ -22,26 +22,27 @@
 import pytest
 import ray
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForTokenClassification

 from verl import DataProto
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
 from verl.utils.model import compute_position_id_with_mask, create_random_mask
 from verl.utils.torch_functional import logprobs_from_logits_naive
 from verl.workers.config import (
     ActorConfig,
+    CriticConfig,
     FSDPEngineConfig,
     FSDPOptimizerConfig,
     HFModelConfig,
     McoreEngineConfig,
     McoreOptimizerConfig,
 )
-from verl.workers.roles import ActorWorker
+from verl.workers.roles import ActorWorker, CriticWorker
 from verl.workers.roles.utils.losses import ppo_loss, sft_loss


 @pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
-def test_mcore_engine(strategy):
+def test_actor_engine(strategy):
     ray.init()

     path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")

@@ -72,7 +73,7 @@ def test_mcore_engine(strategy):
         ppo_mini_batch_size=4,
         optim=optimizer_config,
         use_dynamic_bsz=True,
-        n=1,
+        rollout_n=1,
     )
     ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
     resource_pool = RayResourcePool(process_on_nodes=[8])

@@ -151,3 +152,118 @@ def test_mcore_engine(strategy):
     print(ppo_metrics)

     ray.shutdown()
+
+
+def create_model():
+    from transformers import Qwen3Config
+
+    config = Qwen3Config(num_hidden_layers=2, num_labels=1)
+    model = AutoModelForTokenClassification.from_config(config)
+    assert model.config.num_labels == 1
+    path = os.path.expanduser("~/models/test_model")
+    model.save_pretrained(path)
+    config.save_pretrained(path)
+    return path
+
+
+@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
+def test_critic_engine(strategy):
+    ray.init()
+
+    path = create_model()
+    model_config = HFModelConfig(path=path, load_tokenizer=False)
+
+    if strategy == "megatron":
+        engine_config = McoreEngineConfig(
+            forward_only=False,
+            use_mbridge=False,
+            tensor_model_parallel_size=2,
+            pipeline_model_parallel_size=2,
+            context_parallel_size=2,
+        )
+        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
+    elif strategy in ["fsdp", "fsdp2"]:
+        engine_config = FSDPEngineConfig(
+            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
+        )
+        optimizer_config = FSDPOptimizerConfig()
+    else:
+        raise NotImplementedError(f"strategy {strategy} is not supported")
+
+    config = CriticConfig(
+        model_config=model_config,
+        engine=engine_config,
+        strategy=strategy,
+        ppo_micro_batch_size_per_gpu=256,
+        ppo_mini_batch_size=4,
+        optim=optimizer_config,
+        use_dynamic_bsz=True,
+        rollout_n=1,
+    )
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)
+    resource_pool = RayResourcePool(process_on_nodes=[8])
+    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
+    # init model
+    wg.init_model()
+
+    batch_size = 8
+    seqlen = 32
+
+    response_length = seqlen // 2
+
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
+    attention_mask = create_random_mask(
+        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
+    )
+    position_ids = compute_position_id_with_mask(attention_mask)
+
+    global_token_num = torch.sum(attention_mask, dim=-1).tolist()
+
+    print(input_ids.float().mean(), attention_mask.float().mean())
+
+    responses = input_ids[:, response_length:]
+    response_mask = attention_mask[:, response_length:]
+
+    assert torch.all(response_mask[:, 0] == 1)
+
+    data = DataProto.from_single_dict(
+        {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "responses": responses,
+            "response_mask": response_mask,
+        },
+        meta_info={"temperature": 1.0, "global_token_num": global_token_num},
+    )
+
+    # eval
+    output = wg.compute_values(data)
+
+    # load hf model and compare results with hf model
+    with torch.device("cuda"):
+        hf_model = AutoModelForTokenClassification.from_pretrained(
+            path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        )
+        hf_output = hf_model(input_ids.cuda(), attention_mask=attention_mask.cuda())
+        hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1).cpu()
+        hf_values_mean = torch.mean(hf_values * response_mask)
+
+    engine_values = torch.mean(output.batch["values"] * response_mask)
+
+    torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)
+
+    data = data.union(output)
+
+    # add ppo data
+    data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
+    data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)
+
+    # update again
+    ppo_metrics = wg.update_critic(data)
+    print(ppo_metrics)
+
+    ray.shutdown()

verl/models/mcore/config_converter.py

Lines changed: 2 additions & 2 deletions

@@ -165,8 +165,8 @@ def hf_to_mcore_config_dense(
     hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
 ) -> TransformerConfig:
     # for LlamaForCausalLM or Qwen2ForCausalLM
-    qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
-    qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
+    qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False)
+    qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False

     args: dict = _get_base_transformer_config(
         hf_config=hf_config,

verl/models/mcore/loader.py

Lines changed: 3 additions & 0 deletions

@@ -474,6 +474,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
         elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
             _broadcast_tensor(lm_head_weight, "reward_head.weight")
             print_rank_0("load lm_head from value_head weight")
+        elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1:
+            _broadcast_tensor(lm_head_weight, "score.weight")
+            print_rank_0("load lm_head from score weight")
         else:
             _broadcast_tensor(None, "lm_head.weight")
             print_rank_0("fail to match lm_head in value_model")

verl/models/mcore/registry.py

Lines changed: 6 additions & 0 deletions

@@ -72,6 +72,8 @@ class SupportedModel(Enum):
     QWEN3_MOE = "Qwen3MoeForCausalLM"  # tested
     GLM4_MOE = "Glm4MoeForCausalLM"

+    QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
+

 # Registry for model configuration converters
 MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {

@@ -85,6 +87,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: hf_to_mcore_config_dense,
     SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
     SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
 }

 # Registry for model initializers

@@ -99,6 +102,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: DenseModel,
     SupportedModel.QWEN3_MOE: Qwen3MoEModel,
     SupportedModel.QWEN2_5_VL: Qwen25VLModel,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel,
 }

 # Registry for model forward functions

@@ -115,6 +119,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
     SupportedModel.DEEPSEEK_V3: gptmodel_forward,
     SupportedModel.GLM4_MOE: gptmodel_forward,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
 }

 # Registry for model forward functions

@@ -143,6 +148,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: McoreToHFWeightConverterDense,
     SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
     SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense,
 }
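For illustration, a sketch of the lookup these registry entries enable, using only the names visible in this diff (`SupportedModel` and `MODEL_CONFIG_CONVERTER_REGISTRY`):

```python
# Illustrative lookup mirroring the registry entries added above.
from verl.models.mcore.registry import MODEL_CONFIG_CONVERTER_REGISTRY, SupportedModel

arch = "Qwen3ForTokenClassification"                 # hf_config.architectures[0]
model = SupportedModel(arch)                         # -> SupportedModel.QWEN3_TOKEN_CLASSIFICATION
converter = MODEL_CONFIG_CONVERTER_REGISTRY[model]   # -> hf_to_mcore_config_dense
```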

verl/models/weight_loader_registry.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ def get_weight_saver(arch: str):
         "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl,
         "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3,
         "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel,
+        "Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel,
         "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
     }
     if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:

verl/utils/checkpoint/fsdp_checkpoint_manager.py

Lines changed: 3 additions & 3 deletions

@@ -81,8 +81,7 @@ def __init__(
         checkpoint_config: DictConfig = None,
         **kwargs,
     ):
-        if processing_class is None:
-            assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
+        if processing_class is None and "tokenizer" in kwargs:
             warnings.warn(
                 "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2
             )

@@ -278,7 +277,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                 pass

             model_config.save_pretrained(hf_config_tokenizer_path)
-            self.processing_class.save_pretrained(hf_config_tokenizer_path)
+            if self.processing_class is not None:
+                self.processing_class.save_pretrained(hf_config_tokenizer_path)
             log_with_rank(
                 f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}",
                 rank=self.rank,

verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 2 additions & 1 deletion

@@ -438,7 +438,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         if self.rank == 0:
             # Save tokenizer
             hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)
-            self.processing_class.save_pretrained(hf_config_tokenizer_path)
+            if self.processing_class is not None:
+                self.processing_class.save_pretrained(hf_config_tokenizer_path)
             # Save huggingface config
             self.hf_config.save_pretrained(hf_config_tokenizer_path)
             if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:

verl/utils/model.py

Lines changed: 21 additions & 13 deletions

@@ -26,7 +26,11 @@
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelForVision2Seq,
     GenerationConfig,
     MistralForSequenceClassification,
     PretrainedConfig,

@@ -402,6 +406,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
     architectures = getattr(model_config, "architectures", [])
     local_cache_path = os.path.expanduser(local_cache_path)

+    # get auto class
+    auto_cls = get_hf_auto_model_class(model_config)
+
     if config.model.path.startswith("hdfs:"):
         from verl.utils.fs import copy_to_local

@@ -434,7 +441,7 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
        ]  # workaround, 32001 -> 32000
        is_value_model = True
    else:
-        model = AutoModelForCausalLM.from_pretrained(
+        model = auto_cls.from_pretrained(
            local_model_path,
            torch_dtype="auto",
            # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank

@@ -658,13 +665,15 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
     return model


-def get_hf_auto_model_class(hf_config):
-    from transformers import (
-        AutoModel,
-        AutoModelForCausalLM,
-        AutoModelForVision2Seq,
-    )
+_architecture_to_auto_class = {
+    "ForCausalLM": AutoModelForCausalLM,
+    "ForVision2Seq": AutoModelForVision2Seq,
+    "ForTokenClassification": AutoModelForTokenClassification,
+    "ForSequenceClassification": AutoModelForSequenceClassification,
+}

+
+def get_hf_auto_model_class(hf_config):
     has_remote_code = hasattr(hf_config, "auto_map") and any(
         hf_config.architectures[0] in val for val in hf_config.auto_map.values()
     )

@@ -678,12 +687,11 @@ def get_hf_auto_model_class(hf_config):
             case _:
                 actor_module_class = AutoModel
     else:
-        if type(hf_config) in AutoModelForVision2Seq._model_mapping.keys():
-            actor_module_class = AutoModelForVision2Seq
-        elif type(hf_config) in AutoModelForCausalLM._model_mapping.keys():
-            actor_module_class = AutoModelForCausalLM
-        else:
-            actor_module_class = AutoModel
+        actor_module_class = AutoModel
+        for key, cls in _architecture_to_auto_class.items():
+            if key in hf_config.architectures[0]:
+                actor_module_class = cls
+                break

     return actor_module_class
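For illustration, a small sketch of what the rewritten `get_hf_auto_model_class` resolves for the new token-classification architecture (the config construction is hypothetical; it simply mirrors the test model added in this commit):

```python
# Illustrative only: exercises the new substring-based _architecture_to_auto_class lookup.
from transformers import AutoModelForTokenClassification, Qwen3Config

from verl.utils.model import get_hf_auto_model_class

# Hypothetical config carrying the architecture string of the new test model.
hf_config = Qwen3Config(num_hidden_layers=2, num_labels=1)
hf_config.architectures = ["Qwen3ForTokenClassification"]

auto_cls = get_hf_auto_model_class(hf_config)
assert auto_cls is AutoModelForTokenClassification  # matched on the "ForTokenClassification" substring
```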

verl/workers/config/actor.py

Lines changed: 2 additions & 2 deletions

@@ -119,13 +119,13 @@ class ActorConfig(BaseConfig):
     profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
     engine: BaseConfig = field(default_factory=BaseConfig)
     data_loader_seed = 1
-    n: int = 1  # must be override by sampling config
+    rollout_n: int = 1  # must be override by sampling config
     model_config: HFModelConfig = field(default_factory=BaseConfig)

     def __post_init__(self):
         """Validate actor configuration parameters."""
         assert self.strategy != MISSING
-        assert self.n != MISSING
+        assert self.rollout_n != MISSING
         if not self.use_dynamic_bsz:
             if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:
                 raise ValueError(
