@@ -22,26 +22,27 @@
 import pytest
 import ray
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForTokenClassification
 
 from verl import DataProto
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
 from verl.utils.model import compute_position_id_with_mask, create_random_mask
 from verl.utils.torch_functional import logprobs_from_logits_naive
 from verl.workers.config import (
     ActorConfig,
+    CriticConfig,
     FSDPEngineConfig,
     FSDPOptimizerConfig,
     HFModelConfig,
     McoreEngineConfig,
     McoreOptimizerConfig,
 )
-from verl.workers.roles import ActorWorker
+from verl.workers.roles import ActorWorker, CriticWorker
 from verl.workers.roles.utils.losses import ppo_loss, sft_loss
 
 
 @pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
-def test_mcore_engine(strategy):
+def test_actor_engine(strategy):
     ray.init()
 
     path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
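
Note on the new AutoModelForTokenClassification import: a token-classification head with num_labels=1 emits one scalar per token, which is what lets it double as a critic value head in the new test added below. A minimal standalone sketch (not part of the diff) illustrating the shape contract:

import torch
from transformers import AutoModelForTokenClassification, Qwen3Config

config = Qwen3Config(num_hidden_layers=2, num_labels=1)
model = AutoModelForTokenClassification.from_config(config)
input_ids = torch.randint(0, config.vocab_size, (2, 8))
logits = model(input_ids).logits  # shape (2, 8, 1): one logit per token
values = logits.squeeze(-1)       # shape (2, 8): per-token scalar values
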
@@ -72,7 +73,7 @@ def test_mcore_engine(strategy):
         ppo_mini_batch_size=4,
         optim=optimizer_config,
         use_dynamic_bsz=True,
-        n=1,
+        rollout_n=1,
     )
     ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
     resource_pool = RayResourcePool(process_on_nodes=[8])
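
For reference when reading the new test below: compute_position_id_with_mask (imported above) derives position ids from a left-padded attention mask. A common implementation, and presumably what verl does, is the cumsum trick; a minimal sketch:

import torch

def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Count valid tokens seen so far; clamp so left-padding positions stay at 0.
    return torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)

mask = torch.tensor([[0, 0, 1, 1, 1]])
print(position_ids_from_mask(mask))  # tensor([[0, 0, 0, 1, 2]])
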
@@ -151,3 +152,118 @@ def test_mcore_engine(strategy):
     print(ppo_metrics)
 
     ray.shutdown()
+
+
+def create_model():
+    from transformers import Qwen3Config
+
+    config = Qwen3Config(num_hidden_layers=2, num_labels=1)
+    model = AutoModelForTokenClassification.from_config(config)
+    assert model.config.num_labels == 1
+    path = os.path.expanduser("~/models/test_model")
+    model.save_pretrained(path)
+    config.save_pretrained(path)
+    return path
+
+
+@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
+def test_critic_engine(strategy):
+    ray.init()
+
+    path = create_model()
+    model_config = HFModelConfig(path=path, load_tokenizer=False)
+
+    if strategy == "megatron":
+        engine_config = McoreEngineConfig(
+            forward_only=False,
+            use_mbridge=False,
+            tensor_model_parallel_size=2,
+            pipeline_model_parallel_size=2,
+            context_parallel_size=2,
+        )
+        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
+    elif strategy in ["fsdp", "fsdp2"]:
+        engine_config = FSDPEngineConfig(
+            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
+        )
+        optimizer_config = FSDPOptimizerConfig()
+    else:
+        raise NotImplementedError(f"strategy {strategy} is not supported")
+
+    config = CriticConfig(
+        model_config=model_config,
+        engine=engine_config,
+        strategy=strategy,
+        ppo_micro_batch_size_per_gpu=256,
+        ppo_mini_batch_size=4,
+        optim=optimizer_config,
+        use_dynamic_bsz=True,
+        rollout_n=1,
+    )
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)
+    resource_pool = RayResourcePool(process_on_nodes=[8])
+    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
+    # init model
+    wg.init_model()
+
+    batch_size = 8
+    seqlen = 32
+
+    response_length = seqlen // 2
+
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
+    attention_mask = create_random_mask(
+        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
+    )
+    position_ids = compute_position_id_with_mask(attention_mask)
+
+    global_token_num = torch.sum(attention_mask, dim=-1).tolist()
+
+    print(input_ids.float().mean(), attention_mask.float().mean())
+
+    responses = input_ids[:, response_length:]
+    response_mask = attention_mask[:, response_length:]
+
+    assert torch.all(response_mask[:, 0] == 1)
+
+    data = DataProto.from_single_dict(
+        {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "responses": responses,
+            "response_mask": response_mask,
+        },
+        meta_info={"temperature": 1.0, "global_token_num": global_token_num},
+    )
+
+    # eval
+    output = wg.compute_values(data)
+
+    # load hf model and compare results with hf model
+    with torch.device("cuda"):
+        hf_model = AutoModelForTokenClassification.from_pretrained(
+            path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        )
+        hf_output = hf_model(input_ids.cuda(), attention_mask=attention_mask.cuda())
+        hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1).cpu()
+        hf_values_mean = torch.mean(hf_values * response_mask)
+
+    engine_values = torch.mean(output.batch["values"] * response_mask)
+
+    torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)
+
+    data = data.union(output)
+
+    # add ppo data
+    data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
+    data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)
+
+    # update again
+    ppo_metrics = wg.update_critic(data)
+    print(ppo_metrics)
+
+    ray.shutdown()
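
A note on the HF comparison in the new test: hf_output.logits[:, -response_length - 1 : -1, :] reads each response token's value one position earlier than the token itself, presumably matching the engine's convention that V(s_t) is predicted from the prefix out of which token t is produced. A small self-check of the index arithmetic (standalone snippet, mirroring the test's seqlen=32 and response_length=16):

L, R = 32, 16  # seqlen and response_length from the test above
# Response token i (0-based) sits at position L - R + i; its value is read
# at the preceding position L - R + i - 1.
positions = [L - R + i - 1 for i in range(R)]
assert positions == list(range(L - R - 1, L - 1))  # == the slice [-R - 1 : -1]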