-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[model] feat: add FSDP/Megatron critic worker with model engine #3439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
cede2a8
011f972
e52577b
fe2bd10
5a05aab
a18eaf8
bf34975
370f174
2524490
72429d3
a8ab634
8523e57
6484b07
e22a94b
581baf9
4693ab1
d4234d3
40f034c
5336ed0
4e7e4f6
8952548
d768b40
4e9537d
2c9f369
140e3e7
f13d88b
f049154
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,26 +22,27 @@ | |
| import pytest | ||
| import ray | ||
| import torch | ||
| from transformers import AutoModelForCausalLM | ||
| from transformers import AutoModelForCausalLM, AutoModelForTokenClassification | ||
|
|
||
| from verl import DataProto | ||
| from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup | ||
| from verl.utils.model import compute_position_id_with_mask, create_random_mask | ||
| from verl.utils.torch_functional import logprobs_from_logits_naive | ||
| from verl.workers.config import ( | ||
| ActorConfig, | ||
| CriticConfig, | ||
| FSDPEngineConfig, | ||
| FSDPOptimizerConfig, | ||
| HFModelConfig, | ||
| McoreEngineConfig, | ||
| McoreOptimizerConfig, | ||
| ) | ||
| from verl.workers.roles import ActorWorker | ||
| from verl.workers.roles import ActorWorker, CriticWorker | ||
| from verl.workers.roles.utils.losses import ppo_loss, sft_loss | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"]) | ||
| def test_mcore_engine(strategy): | ||
| def test_actor_engine(strategy): | ||
| ray.init() | ||
|
|
||
| path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") | ||
|
|
@@ -72,7 +73,7 @@ def test_mcore_engine(strategy): | |
| ppo_mini_batch_size=4, | ||
| optim=optimizer_config, | ||
| use_dynamic_bsz=True, | ||
| n=1, | ||
| rollout_n=1, | ||
| ) | ||
| ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config) | ||
| resource_pool = RayResourcePool(process_on_nodes=[8]) | ||
|
|
@@ -151,3 +152,105 @@ def test_mcore_engine(strategy): | |
| print(ppo_metrics) | ||
|
|
||
| ray.shutdown() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("strategy", ["fsdp", "fsdp2"]) | ||
| def test_critic_engine(strategy): | ||
| ray.init() | ||
|
|
||
| path = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Qwen3-0.6B") | ||
| model_config = HFModelConfig(path=path) | ||
|
|
||
| strategy = "fsdp" | ||
|
||
|
|
||
| if strategy == "megatron": | ||
| engine_config = McoreEngineConfig( | ||
| forward_only=False, | ||
| use_mbridge=False, | ||
| tensor_model_parallel_size=2, | ||
| pipeline_model_parallel_size=2, | ||
| context_parallel_size=2, | ||
| ) | ||
| optimizer_config = McoreOptimizerConfig(lr_decay_steps=10) | ||
| elif strategy in ["fsdp", "fsdp2"]: | ||
| engine_config = FSDPEngineConfig( | ||
| forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2 | ||
| ) | ||
| optimizer_config = FSDPOptimizerConfig() | ||
| else: | ||
| raise NotImplementedError(f"strategy {strategy} is not supported") | ||
|
|
||
| config = CriticConfig( | ||
| model_config=model_config, | ||
| engine=engine_config, | ||
| strategy=strategy, | ||
| ppo_micro_batch_size_per_gpu=256, | ||
| ppo_mini_batch_size=4, | ||
| optim=optimizer_config, | ||
| use_dynamic_bsz=True, | ||
| rollout_n=1, | ||
| ) | ||
| ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config) | ||
| resource_pool = RayResourcePool(process_on_nodes=[8]) | ||
| wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) | ||
| # init model | ||
| wg.init_model() | ||
|
|
||
| batch_size = 8 | ||
| seqlen = 32 | ||
|
|
||
| response_length = seqlen // 2 | ||
|
|
||
| torch.manual_seed(1) | ||
| np.random.seed(1) | ||
|
|
||
| input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen)) | ||
| attention_mask = create_random_mask( | ||
| input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6 | ||
| ) | ||
| position_ids = compute_position_id_with_mask(attention_mask) | ||
|
|
||
| global_token_num = torch.sum(attention_mask, dim=-1).tolist() | ||
|
|
||
| print(input_ids.float().mean(), attention_mask.float().mean()) | ||
|
|
||
| responses = input_ids[:, response_length:] | ||
| response_mask = attention_mask[:, response_length:] | ||
|
|
||
| assert torch.all(response_mask[:, 0] == 1) | ||
|
|
||
| data = DataProto.from_single_dict( | ||
| { | ||
| "input_ids": input_ids, | ||
| "attention_mask": attention_mask, | ||
| "position_ids": position_ids, | ||
| "responses": responses, | ||
| "response_mask": response_mask, | ||
| }, | ||
| meta_info={"temperature": 1.0, "global_token_num": global_token_num}, | ||
| ) | ||
|
|
||
| # eval | ||
| output = wg.compute_values(data) | ||
|
|
||
| # load hf model and compare results with hf model | ||
| hf_model = AutoModelForTokenClassification.from_pretrained(path, torch_dtype=torch.bfloat16) | ||
| hf_output = hf_model(input_ids, attention_mask=attention_mask) | ||
| hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1) | ||
| hf_values_mean = torch.mean(hf_values * response_mask) | ||
|
|
||
| engine_values = torch.mean(output.batch["values"] * response_mask) | ||
|
|
||
| torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-3, rtol=1e-2) | ||
|
|
||
| data = data.union(output) | ||
|
|
||
| # add ppo data | ||
| data.batch["values"] = torch.rand_like(responses, dtype=torch.float32) | ||
| data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32) | ||
|
|
||
| # update again | ||
| ppo_metrics = wg.update_critic(data) | ||
| print(ppo_metrics) | ||
|
|
||
| ray.shutdown() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,13 +118,13 @@ class ActorConfig(BaseConfig): | |
| profiler: ProfilerConfig = field(default_factory=ProfilerConfig) | ||
| engine: BaseConfig = field(default_factory=BaseConfig) | ||
| data_loader_seed = 1 | ||
| n: int = 1 # must be override by sampling config | ||
| rollout_n: int = 1 # must be override by sampling config | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this renaming complete in all references to the "n" parameter? |
||
| model_config: HFModelConfig = field(default_factory=BaseConfig) | ||
|
|
||
| def __post_init__(self): | ||
| """Validate actor configuration parameters.""" | ||
| assert self.strategy != MISSING | ||
| assert self.n != MISSING | ||
| assert self.rollout_n != MISSING | ||
| if not self.use_dynamic_bsz: | ||
| if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None: | ||
| raise ValueError( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hardcoding the model path using
`os.path.expanduser` makes this test dependent on a specific local file system setup. This will cause the test to fail for other developers or in CI environments where `~/models/Qwen/Qwen2.5-0.5B-Instruct` does not exist. For large, pre-existing models, consider making the path configurable via an environment variable or a test configuration file, and provide instructions for setting it up. This will improve the test's portability and reproducibility.