diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml
index 05d91e37c48..f08702d06a3 100644
--- a/.github/workflows/model.yml
+++ b/.github/workflows/model.yml
@@ -200,7 +200,7 @@ jobs:
       - name: Install the current repository
        run: |
          pip3 install --no-deps -e .[test]
-          pip3 install --upgrade tensordict
+          pip3 install --upgrade tensordict transformers
          pip install --upgrade "huggingface_hub[cli]"
      - name: Download model config files
        run: |
diff --git a/tests/models/test_engine.py b/tests/models/test_engine.py
index a31b6efd4f4..0a7c4a926f7 100644
--- a/tests/models/test_engine.py
+++ b/tests/models/test_engine.py
@@ -22,7 +22,7 @@
 import pytest
 import ray
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForTokenClassification
 
 from verl import DataProto
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@@ -30,18 +30,19 @@
 from verl.utils.torch_functional import logprobs_from_logits_naive
 from verl.workers.config import (
     ActorConfig,
+    CriticConfig,
     FSDPEngineConfig,
     FSDPOptimizerConfig,
     HFModelConfig,
     McoreEngineConfig,
     McoreOptimizerConfig,
 )
-from verl.workers.roles import ActorWorker
+from verl.workers.roles import ActorWorker, CriticWorker
 from verl.workers.roles.utils.losses import ppo_loss, sft_loss
 
 
 @pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
-def test_mcore_engine(strategy):
+def test_actor_engine(strategy):
     ray.init()
 
     path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
@@ -72,7 +73,7 @@ def test_mcore_engine(strategy):
         ppo_mini_batch_size=4,
         optim=optimizer_config,
         use_dynamic_bsz=True,
-        n=1,
+        rollout_n=1,
     )
     ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
     resource_pool = RayResourcePool(process_on_nodes=[8])
@@ -151,3 +152,118 @@ def test_mcore_engine(strategy):
     print(ppo_metrics)
 
     ray.shutdown()
+
+
+def create_model():
+    from transformers import Qwen3Config
+
+    config = Qwen3Config(num_hidden_layers=2, num_labels=1)
+    model = AutoModelForTokenClassification.from_config(config)
+    assert model.config.num_labels == 1
+    path = os.path.expanduser("~/models/test_model")
+    model.save_pretrained(path)
+    config.save_pretrained(path)
+    return path
+
+
+@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
+def test_critic_engine(strategy):
+    ray.init()
+
+    path = create_model()
+    model_config = HFModelConfig(path=path, load_tokenizer=False)
+
+    if strategy == "megatron":
+        engine_config = McoreEngineConfig(
+            forward_only=False,
+            use_mbridge=False,
+            tensor_model_parallel_size=2,
+            pipeline_model_parallel_size=2,
+            context_parallel_size=2,
+        )
+        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
+    elif strategy in ["fsdp", "fsdp2"]:
+        engine_config = FSDPEngineConfig(
+            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
+        )
+        optimizer_config = FSDPOptimizerConfig()
+    else:
+        raise NotImplementedError(f"strategy {strategy} is not supported")
+
+    config = CriticConfig(
+        model_config=model_config,
+        engine=engine_config,
+        strategy=strategy,
+        ppo_micro_batch_size_per_gpu=256,
+        ppo_mini_batch_size=4,
+        optim=optimizer_config,
+        use_dynamic_bsz=True,
+        rollout_n=1,
+    )
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)
+    resource_pool = RayResourcePool(process_on_nodes=[8])
+    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
+    # init model
+    wg.init_model()
+
+    batch_size = 8
+    seqlen = 32
+
+    response_length = seqlen // 2
+
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
+    attention_mask = create_random_mask(
+        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
+    )
+    position_ids = compute_position_id_with_mask(attention_mask)
+
+    global_token_num = torch.sum(attention_mask, dim=-1).tolist()
+
+    print(input_ids.float().mean(), attention_mask.float().mean())
+
+    responses = input_ids[:, response_length:]
+    response_mask = attention_mask[:, response_length:]
+
+    assert torch.all(response_mask[:, 0] == 1)
+
+    data = DataProto.from_single_dict(
+        {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "responses": responses,
+            "response_mask": response_mask,
+        },
+        meta_info={"temperature": 1.0, "global_token_num": global_token_num},
+    )
+
+    # eval
+    output = wg.compute_values(data)
+
+    # load the HF model and compare the engine values against it
+    with torch.device("cuda"):
+        hf_model = AutoModelForTokenClassification.from_pretrained(
+            path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        )
+        hf_output = hf_model(input_ids.cuda(), attention_mask=attention_mask.cuda())
+        hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1).cpu()
+        hf_values_mean = torch.mean(hf_values * response_mask)
+
+    engine_values = torch.mean(output.batch["values"] * response_mask)
+
+    torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)
+
+    data = data.union(output)
+
+    # add ppo data
+    data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
+    data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)
+
+    # update the critic
+    ppo_metrics = wg.update_critic(data)
+    print(ppo_metrics)
+
+    ray.shutdown()
diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py
index 9d3809f93c2..5ab90122854 100644
--- a/verl/models/mcore/config_converter.py
+++ b/verl/models/mcore/config_converter.py
@@ -165,8 +165,8 @@ def hf_to_mcore_config_dense(
     hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
 ) -> TransformerConfig:
     # for LlamaForCausalLM or Qwen2ForCausalLM
-    qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
-    qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
+    qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False)
+    qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False
 
     args: dict = _get_base_transformer_config(
         hf_config=hf_config,
diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py
index 659b4baa243..577ffc5ecf4 100644
--- a/verl/models/mcore/loader.py
+++ b/verl/models/mcore/loader.py
@@ -474,6 +474,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
         elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
             _broadcast_tensor(lm_head_weight, "reward_head.weight")
             print_rank_0("load lm_head from value_head weight")
+        elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1:
+            _broadcast_tensor(lm_head_weight, "score.weight")
+            print_rank_0("load lm_head from score weight")
         else:
             _broadcast_tensor(None, "lm_head.weight")
             print_rank_0("fail to match lm_head in value_model")
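Note (illustrative, not part of the patch): both `test_critic_engine` above and the value-head engines later in this patch read values through the `[:, -response_length - 1 : -1]` slice, because the value for response token t is predicted from the position immediately before t. A minimal PyTorch sketch with made-up sizes showing which positions survive the slice:

```python
import torch

batch_size, seqlen, response_length = 2, 8, 4
# per-position value predictions, shape (batch, seqlen); values = position index
values = torch.arange(seqlen).repeat(batch_size, 1).float()

# value for response token t comes from the position *preceding* t,
# i.e. columns seqlen - response_length - 1 .. seqlen - 2 (inclusive)
sliced = values[:, -response_length - 1 : -1]
assert sliced.shape == (batch_size, response_length)
print(sliced[0])  # tensor([3., 4., 5., 6.]) for seqlen=8, response_length=4
```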
diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py
index 91503506309..2fafe542581 100644
--- a/verl/models/mcore/registry.py
+++ b/verl/models/mcore/registry.py
@@ -72,6 +72,8 @@ class SupportedModel(Enum):
     QWEN3_MOE = "Qwen3MoeForCausalLM"  # tested
     GLM4_MOE = "Glm4MoeForCausalLM"
 
+    QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
+
 
 # Registry for model configuration converters
 MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
@@ -85,6 +87,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: hf_to_mcore_config_dense,
     SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
     SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
 }
 
 # Registry for model initializers
@@ -99,6 +102,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: DenseModel,
     SupportedModel.QWEN3_MOE: Qwen3MoEModel,
     SupportedModel.QWEN2_5_VL: Qwen25VLModel,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel,
 }
 
 # Registry for model forward functions
@@ -115,6 +119,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
     SupportedModel.DEEPSEEK_V3: gptmodel_forward,
     SupportedModel.GLM4_MOE: gptmodel_forward,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
 }
 
 # Registry for model forward functions
@@ -143,6 +148,7 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3: McoreToHFWeightConverterDense,
     SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
     SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense,
 }
diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py
index 8aa3bc71f84..0904f14fad4 100644
--- a/verl/models/weight_loader_registry.py
+++ b/verl/models/weight_loader_registry.py
@@ -46,6 +46,7 @@ def get_weight_saver(arch: str):
         "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl,
         "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3,
         "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel,
+        "Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel,
         "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
     }
     if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:
diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
index 283bdb59d31..4b3d7a8e7ad 100644
--- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py
+++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -81,8 +81,7 @@ def __init__(
         checkpoint_config: DictConfig = None,
         **kwargs,
     ):
-        if processing_class is None:
-            assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
+        if processing_class is None and "tokenizer" in kwargs:
             warnings.warn(
                 "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2
             )
@@ -278,7 +277,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                 pass
 
             model_config.save_pretrained(hf_config_tokenizer_path)
-            self.processing_class.save_pretrained(hf_config_tokenizer_path)
+            if self.processing_class is not None:
+                self.processing_class.save_pretrained(hf_config_tokenizer_path)
             log_with_rank(
                 f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}",
                 rank=self.rank,
diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py
index de5317f5f02..b91ef6070b9 100644
--- a/verl/utils/checkpoint/megatron_checkpoint_manager.py
+++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -438,7 +438,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         if self.rank == 0:
             # Save tokenizer
             hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)
-            self.processing_class.save_pretrained(hf_config_tokenizer_path)
+            if self.processing_class is not None:
+                self.processing_class.save_pretrained(hf_config_tokenizer_path)
             # Save huggingface config
             self.hf_config.save_pretrained(hf_config_tokenizer_path)
             if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
diff --git a/verl/utils/model.py b/verl/utils/model.py
index 4fb1fc350d6..e0c275e8c43 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -26,7 +26,11 @@
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelForVision2Seq,
     GenerationConfig,
     MistralForSequenceClassification,
     PretrainedConfig,
@@ -402,6 +406,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
     architectures = getattr(model_config, "architectures", [])
     local_cache_path = os.path.expanduser(local_cache_path)
 
+    # get auto class
+    auto_cls = get_hf_auto_model_class(model_config)
+
     if config.model.path.startswith("hdfs:"):
         from verl.utils.fs import copy_to_local
 
@@ -434,7 +441,7 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
             ]  # workaround, 32001 -> 32000
             is_value_model = True
         else:
-            model = AutoModelForCausalLM.from_pretrained(
+            model = auto_cls.from_pretrained(
                 local_model_path,
                 torch_dtype="auto",
                 # device_map="auto",  # disable auto device_map, the HF weight is only loaded to CPU in src_rank
@@ -658,13 +665,15 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
     return model
 
 
-def get_hf_auto_model_class(hf_config):
-    from transformers import (
-        AutoModel,
-        AutoModelForCausalLM,
-        AutoModelForVision2Seq,
-    )
+_architecture_to_auto_class = {
+    "ForCausalLM": AutoModelForCausalLM,
+    "ForVision2Seq": AutoModelForVision2Seq,
+    "ForTokenClassification": AutoModelForTokenClassification,
+    "ForSequenceClassification": AutoModelForSequenceClassification,
+}
+
 
+def get_hf_auto_model_class(hf_config):
     has_remote_code = hasattr(hf_config, "auto_map") and any(
         hf_config.architectures[0] in val for val in hf_config.auto_map.values()
     )
@@ -678,12 +687,11 @@ def get_hf_auto_model_class(hf_config):
             case _:
                 actor_module_class = AutoModel
     else:
-        if type(hf_config) in AutoModelForVision2Seq._model_mapping.keys():
-            actor_module_class = AutoModelForVision2Seq
-        elif type(hf_config) in AutoModelForCausalLM._model_mapping.keys():
-            actor_module_class = AutoModelForCausalLM
-        else:
-            actor_module_class = AutoModel
+        actor_module_class = AutoModel
+        for key, cls in _architecture_to_auto_class.items():
+            if key in hf_config.architectures[0]:
+                actor_module_class = cls
+                break
 
     return actor_module_class
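Note (illustrative, not part of the patch): the new `_architecture_to_auto_class` table resolves the auto class by substring-matching the first entry of `hf_config.architectures`, falling back to `AutoModel`. A self-contained sketch of that lookup (the `resolve` helper is hypothetical; only the auto classes are real transformers imports):

```python
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForTokenClassification

_architecture_to_auto_class = {
    "ForCausalLM": AutoModelForCausalLM,
    "ForTokenClassification": AutoModelForTokenClassification,
}

def resolve(architecture: str):
    # first suffix that appears in the architecture string wins
    for key, cls in _architecture_to_auto_class.items():
        if key in architecture:
            return cls
    return AutoModel  # fallback, as in get_hf_auto_model_class

assert resolve("Qwen3ForCausalLM") is AutoModelForCausalLM
assert resolve("Qwen3ForTokenClassification") is AutoModelForTokenClassification
assert resolve("BertModel") is AutoModel
```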
diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py
index e0acdc70342..af6199732b7 100644
--- a/verl/workers/config/actor.py
+++ b/verl/workers/config/actor.py
@@ -119,13 +119,13 @@ class ActorConfig(BaseConfig):
     profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
     engine: BaseConfig = field(default_factory=BaseConfig)
     data_loader_seed = 1
-    n: int = 1  # must be override by sampling config
+    rollout_n: int = 1  # must be overridden by the sampling config
     model_config: HFModelConfig = field(default_factory=BaseConfig)
 
     def __post_init__(self):
         """Validate actor configuration parameters."""
         assert self.strategy != MISSING
-        assert self.n != MISSING
+        assert self.rollout_n != MISSING
         if not self.use_dynamic_bsz:
             if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:
                 raise ValueError(
diff --git a/verl/workers/config/critic.py b/verl/workers/config/critic.py
index 0b9891ca1a2..e65eac0cd17 100644
--- a/verl/workers/config/critic.py
+++ b/verl/workers/config/critic.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -22,6 +23,7 @@
 from verl.utils.profiler import ProfilerConfig
 
 from .engine import FSDPEngineConfig, McoreEngineConfig
+from .model import HFModelConfig
 from .optimizer import OptimizerConfig
 
 __all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "FSDPCriticModelCfg"]
@@ -57,6 +59,7 @@ class CriticConfig(BaseConfig):
         "ppo_micro_batch_size_per_gpu",
         "ppo_mini_batch_size",
         "ppo_micro_batch_size",
+        "model_config",
     }
 
     strategy: str = MISSING
@@ -66,20 +69,32 @@ class CriticConfig(BaseConfig):
     ppo_mini_batch_size: int = 1
     use_dynamic_bsz: bool = False
     ppo_max_token_len_per_gpu: int = 32768
+    # deprecated in favor of ppo_infer_max_token_len_per_gpu
     forward_max_token_len_per_gpu: int = 32768
+    ppo_infer_micro_batch_size_per_gpu: Optional[int] = None
+    ppo_infer_max_token_len_per_gpu: int = 32768
     ppo_epochs: int = 1
+    data_loader_seed: int = 1
     shuffle: bool = True
     cliprange_value: float = 0.5
     loss_agg_mode: str = "token-mean"
     ppo_micro_batch_size: Optional[int] = None
+    engine: BaseConfig = field(default_factory=BaseConfig)
     optim: OptimizerConfig = field(default_factory=OptimizerConfig)
+    # deprecate model in favor of model_config
     model: BaseModelConfig = field(default_factory=BaseModelConfig)
+    model_config: Optional[HFModelConfig] = None
     checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
     profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
 
     def __post_init__(self):
         """Validate critic configuration parameters."""
         assert self.strategy != MISSING
+
+        if self.model_config is None:
+            warnings.warn("`model` in CriticConfig is deprecated; please use `model_config` instead", stacklevel=2)
+            self.model_config = self.model
+
         if not self.use_dynamic_bsz:
             self._check_mutually_exclusive(self.ppo_micro_batch_size, self.ppo_micro_batch_size_per_gpu, "critic")
diff --git a/verl/workers/config/model.py b/verl/workers/config/model.py
index f95c477f3f3..59baf914458 100644
--- a/verl/workers/config/model.py
+++ b/verl/workers/config/model.py
@@ -50,6 +50,9 @@ class HFModelConfig(BaseConfig):
     tokenizer_path: Optional[str] = None
     local_tokenizer_path: Optional[str] = None
 
+    # whether to load the tokenizer; useful when only the model config is needed
+    load_tokenizer: bool = True
+
     hf_config: Any = None
     generation_config: Any = None
 
     tokenizer: Any = None
@@ -95,9 +98,10 @@ def __post_init__(self):
             self.local_path = copy_to_local(self.path, use_shm=self.use_shm)
 
         # construct tokenizer
-        self.local_tokenizer_path = copy_to_local(self.tokenizer_path, use_shm=self.use_shm)
-        self.tokenizer = hf_tokenizer(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)
-        self.processor = hf_processor(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)
+        if self.load_tokenizer:
+            self.local_tokenizer_path = copy_to_local(self.tokenizer_path, use_shm=self.use_shm)
+            self.tokenizer = hf_tokenizer(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)
+            self.processor = hf_processor(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)
 
         if self.custom_chat_template is not None:
             if self.processor is not None:
@@ -115,11 +119,17 @@ def __post_init__(self):
         self.hf_config = AutoConfig.from_pretrained(
             self.local_hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation
         )
-        override_config_kwargs = {
-            "bos_token_id": self.tokenizer.bos_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "pad_token_id": self.tokenizer.pad_token_id,
-        }
+
+        override_config_kwargs = {}
+
+        if self.tokenizer is not None:
+            override_config_kwargs.update(
+                {
+                    "bos_token_id": self.tokenizer.bos_token_id,
+                    "eos_token_id": self.tokenizer.eos_token_id,
+                    "pad_token_id": self.tokenizer.pad_token_id,
+                }
+            )
         override_config_kwargs.update(self.override_config)
         update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index 161a6a5216d..46430d26741 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -878,3 +878,45 @@ def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
         }
 
         return loss, output
+
+
+@EngineRegistry.register(model_type="value_model", backend=["fsdp", "fsdp2"])
+class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
+    """
+    The only difference between the critic and the actor is how the raw model output is processed.
+    """
+
+    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
+        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
+        response_length = micro_batch["responses"].size(-1)
+
+        if use_remove_padding:
+            input_ids = micro_batch["input_ids"]
+            batch_size, seqlen = input_ids.shape
+
+            if hasattr(self.module, "v_head"):
+                # For trl.AutoModelForCausalLMWithValueHead
+                values_rmpad = output[2].squeeze(0).unsqueeze(-1)
+            else:
+                values_rmpad = output.logits
+                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz, 1)
+
+            indices = output_args["indices"]
+
+            # gather output if sp > 1
+            if self.use_ulysses_sp:
+                pad_size = output_args["pad_size"]
+                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+
+            # pad it back
+            values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
+            values = values[:, -response_length - 1 : -1]
+        else:
+            if hasattr(self.module, "v_head"):
+                # For trl.AutoModelForCausalLMWithValueHead
+                values = output[2]
+            else:
+                values = output.logits
+            values = values[:, -response_length - 1 : -1].squeeze(-1)
+
+        return {"values": values}
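Note (illustrative, not part of the patch): a hypothetical use of the new `load_tokenizer` flag added to `HFModelConfig` above; the checkpoint path is a placeholder and must exist locally for the config load to succeed:

```python
from verl.workers.config import HFModelConfig

# Build a critic model config without fetching tokenizer/processor files,
# mirroring how test_critic_engine constructs its model config.
model_config = HFModelConfig(path="~/models/test_model", load_tokenizer=False)

# The HF config is still loaded, but tokenizer and processor stay None,
# so the bos/eos/pad overrides derived from the tokenizer are skipped.
assert model_config.tokenizer is None and model_config.processor is None
print(model_config.hf_config.architectures)
```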
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 1f1250207ff..64778c4a58e 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -138,7 +138,7 @@ def _build_megatron_module(self):
             override_model_config=self.engine_config.override_mcore_model_config,
             override_ddp_config=self.engine_config.override_ddp_config,
         )
-        print(f"actor_module: {len(module)}")
+        print(f"module: {len(module)}")
 
         if self.engine_config.use_dist_checkpointing:
             load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
@@ -473,15 +473,9 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 @EngineRegistry.register(model_type="language_model", backend="megatron")
 class MegatronEngineWithLMHead(MegatronEngine):
-    def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micro_batch_func):
-        batch: TensorDict = next(batch_iter)
+    def prepare_model_inputs(self, batch: TensorDict):
         batch = batch.to(get_device_id())
         batch = batch.contiguous()
-
-        use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
-        calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
-        temperature = batch["temperature"]
-
         input_ids = batch["input_ids"]
         attention_mask = batch["attention_mask"].to(bool)
         position_ids = batch["position_ids"]
@@ -508,6 +502,39 @@ def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micr
                 multi_modal_inputs[key] = torch.cat(
                     [mmi[idx].get(key).to(input_ids.device) for idx in idxs if mmi[idx].get(key) is not None], dim=0
                 )
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "multi_modal_inputs": multi_modal_inputs,
+        }
+
+    def prepare_model_outputs(self, output: dict, data: TensorDict):
+        calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)
+        responses = data["responses"]
+        response_length = responses.size(1)
+
+        log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
+        model_output = {"log_probs": log_prob}
+        if calculate_entropy:
+            entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
+            model_output["entropy"] = entropy
+
+        return model_output
+
+    def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micro_batch_func):
+        batch: TensorDict = next(batch_iter)
+        use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
+        calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
+        temperature = batch["temperature"]
+
+        model_inputs = self.prepare_model_inputs(batch)
+        input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs["attention_mask"]
+        position_ids = model_inputs["position_ids"]
+        multi_modal_inputs = model_inputs["multi_modal_inputs"]
+
         responses = batch["responses"]
         response_length = responses.size(1)
         label = position_ids.clone()
@@ -574,18 +601,9 @@ def logits_processor(logits, label, label_mask):
     def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):
         # For memory efficiency
         # We move calculation of entropy to compute_log_probs, forward_only == True
-        calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)
-
-        device = output["log_probs"].device
-
-        responses = data["responses"]
-        response_length = responses.size(1)
-
-        log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
-        model_output = {"log_probs": log_prob}
-        if calculate_entropy:
-            entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
-            model_output["entropy"] = entropy
+        device = data["input_ids"].device
+        local_micro_bsz = data["input_ids"].shape[0]
+        model_output = self.prepare_model_outputs(output, data)
 
         if loss_function is not None:
             loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
@@ -593,9 +611,7 @@ def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: b
             # by n_micro_batch and cp size inside pp schedule
             n_micro_batch = data["num_micro_batch"]
             loss = loss * n_micro_batch / mpu.get_context_parallel_world_size()
-
             global_bsz = data["global_batch_size"]
-            local_micro_bsz = responses.shape[0]
             loss_scale_factor = local_micro_bsz / (global_bsz / self.get_data_parallel_size())
             loss = loss * loss_scale_factor
         else:
@@ -613,6 +629,35 @@ def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: b
         return loss, output
 
 
-class MegatronEngineWithValueHead(MegatronEngine):
+@EngineRegistry.register(model_type="value_model", backend="megatron")
+class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
     # for value head
-    pass
+    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
+        batch: TensorDict = next(batch_iter)
+        model_inputs = self.prepare_model_inputs(batch)
+        input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs["attention_mask"]
+        position_ids = model_inputs["position_ids"]
+        multi_modal_inputs = model_inputs["multi_modal_inputs"]
+
+        from verl.models.mcore import get_mcore_forward_fn
+
+        forward_fn = get_mcore_forward_fn(self.model_config.hf_config)
+
+        output = forward_fn(
+            model,
+            input_ids,
+            attention_mask,
+            position_ids,
+            sequence_parallel=self.tf_config.sequence_parallel,
+            multi_modal_inputs=multi_modal_inputs,
+            value_model=True,
+        )
+
+        return output, partial(postprocess_micro_batch_func, data=batch)
+
+    def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict):
+        responses = data["responses"]
+        response_length = responses.size(1)
+        output = output[:, -response_length - 1 : -1].contiguous()
+        return {"values": output}
diff --git a/verl/workers/roles/actor.py b/verl/workers/roles/actor.py
index 8ca1b51475a..6f513474477 100644
--- a/verl/workers/roles/actor.py
+++ b/verl/workers/roles/actor.py
@@ -61,31 +61,21 @@ def __init__(self, config: ActorConfig):
         self.loss_fn = partial(ppo_loss, config=self.config)
 
     def _build_engine(self):
-        model_config = self.config.model_config
-        engine_config = self.config.engine
-        optimizer_config = self.config.optim
-        checkpoint_config = self.config.checkpoint
-
-        if self.config.strategy == "megatron":
-            from verl.workers.engine.megatron.transformer_impl import MegatronEngineWithLMHead
-
-            self.engine = MegatronEngineWithLMHead(
-                model_config=model_config,
-                engine_config=engine_config,
-                optimizer_config=optimizer_config,
-                checkpoint_config=checkpoint_config,
-            )
-        elif self.config.strategy in ["fsdp", "fsdp2"]:
-            from verl.workers.engine.fsdp.transformer_impl import FSDPEngineWithLMHead
-
-            self.engine = FSDPEngineWithLMHead(
-                model_config=model_config,
-                engine_config=engine_config,
-                optimizer_config=optimizer_config,
-                checkpoint_config=checkpoint_config,
-            )
-        else:
-            raise ValueError(f"Unknown strategy {self.config.strategy}")
+        self.model_config = self.config.model_config
+        self.engine_config = self.config.engine
+        self.optimizer_config = self.config.optim
+        self.checkpoint_config = self.config.checkpoint
+
+        from verl.workers.engine import BaseEngine, EngineRegistry
+
+        self.engine: BaseEngine = EngineRegistry.new(
+            model_type="language_model",
+            backend=self.config.strategy,
+            model_config=self.model_config,
+            engine_config=self.engine_config,
+            optimizer_config=self.optimizer_config,
+            checkpoint_config=self.checkpoint_config,
+        )
 
         # build dispatch info
         self._register_dispatch_collect_info(
@@ -95,14 +85,14 @@ def _build_engine(self):
         )
 
         # aggregate with BoN sampling
-        self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.n
+        self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n
         assert self.ppo_mini_batch_size % self.engine.get_data_parallel_size() == 0, (
             f"{self.ppo_mini_batch_size=} is not divisible by {self.engine.get_data_parallel_size()=}"
         )
         self.ppo_mini_batch_size_per_dp = self.ppo_mini_batch_size // self.engine.get_data_parallel_size()
 
         # setup flops counter
-        self.flops_counter = FlopsCounter(model_config.hf_config)
+        self.flops_counter = FlopsCounter(self.model_config.hf_config)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
@@ -128,9 +118,9 @@ def compute_log_prob(self, data: DataProto):
         # TODO: make worker API to accept TensorDict as well
         data = data.to_tensordict()
         output = self.engine.infer_batch(data)
-        output = output.get("model_output", {})
 
-        if "log_probs" in output and "entropy" in output:
+        if self.engine.is_mp_src_rank_with_outputs():
+            output = output["model_output"]
             # in megatron, only the last pp stage contains valid data and is returned to the single controller
             output = DataProto.from_dict(
                 tensors={"old_log_probs": output["log_probs"].float(), "entropy": output["entropy"].float()},
diff --git a/verl/workers/roles/critic.py b/verl/workers/roles/critic.py
index 0464c07c41a..ef9a57f0578 100644
--- a/verl/workers/roles/critic.py
+++ b/verl/workers/roles/critic.py
@@ -11,181 +11,227 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-The main entry point to run the PPO algorithm
-"""
+
 import logging
 import os
+import warnings
+from functools import partial
+from typing import Iterable
 
-import torch
+import psutil
 from codetiming import Timer
 
 from verl import DataProto
 from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.trainer.ppo import core_algos
-from verl.utils.config import omega_conf_to_dataclass
+from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils.device import (
     get_device_id,
-    get_nccl_backend,
+    get_device_name,
+    get_torch_device,
 )
-from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig
+from verl.utils.distributed import initialize_global_process_group_ray
+from verl.utils.flops_counter import FlopsCounter
+from verl.utils.profiler import DistProfiler, DistProfilerExtension
 from verl.utils.py_functional import append_to_dict
-from verl.utils.torch_functional import masked_mean
-from verl.workers.engine import EngineRegistry
+from verl.workers.config import CriticConfig
+from verl.workers.roles.utils.losses import value_loss
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
+device_name = get_device_name()
+
 
 class CriticWorker(Worker, DistProfilerExtension):
-    def __init__(self, config):
+    """
+    A standalone critic worker that computes values and updates the value function
+    on top of the training engine selected by config.strategy.
+    """
+
+    def __init__(self, config: CriticConfig):
+        self.config = config
         Worker.__init__(self)
-        omega_profiler_config = config.get("profiler", {})
-        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
-        if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
-            tool_config = omega_conf_to_dataclass(
-                omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
-            )
-        else:
-            tool_config = None
+        self.profiler_config = self.config.profiler
+        tool_config = self.profiler_config.tool_config
         DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
+            self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=tool_config)
         )
 
-        import torch.distributed
-
-        if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group(backend=get_nccl_backend())
-        self.config = config
-        self.engine = EngineRegistry.new(self.config.strategy, self.config)
+        initialize_global_process_group_ray(timeout_second=None)
+
+        self.loss_fn = partial(value_loss, config=self.config)
+
+    def _build_engine(self):
+        from copy import copy, deepcopy
+
+        self.model_config = copy(self.config.model_config)
+        self.model_config.hf_config = deepcopy(self.config.model_config.hf_config)
+        self.engine_config = self.config.engine
+        self.optimizer_config = self.config.optim
+        self.checkpoint_config = self.config.checkpoint
+
+        from verl.workers.engine import BaseEngine, EngineRegistry
+
+        # rewrite the configured architecture to its ForTokenClassification variant
+        hf_config = self.model_config.hf_config
+
+        arch = hf_config.architectures[0]
+        # This logic assumes the critic is a token classification model.
+        # If the provided model is a CausalLM, we adapt it.
+ if "ForCausalLM" in arch: + model_name = arch.split("ForCausalLM")[0] + new_arch = f"{model_name}ForTokenClassification" + warnings.warn(f"Implicitly changing critic architecture from '{arch}' to '{new_arch}'", stacklevel=2) + hf_config.architectures[0] = new_arch + elif "ForTokenClassification" not in arch and "ForSequenceClassification" not in arch: + raise ValueError( + f"Unsupported critic architecture: {arch}. " + f"Critic worker expects an architecture suitable for value function estimation, " + f"such as '...ForTokenClassification' or '...ForSequenceClassification'." + ) - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - self.engine.init_model() + # make sure output dropout is 0 + hf_config.classifier_dropout = 0 - def _post_fn_values(self, micro_batch, preds): - response_length = micro_batch["responses"].size(-1) - values = preds[:, -response_length - 1 : -1] + self.engine: BaseEngine = EngineRegistry.new( + model_type="value_model", + backend=self.config.strategy, + model_config=self.model_config, + engine_config=self.engine_config, + optimizer_config=self.optimizer_config, + checkpoint_config=self.checkpoint_config, + ) - use_remove_padding = self.config.model.get("use_remove_padding", False) - if not use_remove_padding: - values = values.squeeze(-1) + # build dispatch info + self._register_dispatch_collect_info( + mesh_name="critic", + dp_rank=self.engine.get_data_parallel_rank(), + is_collect=self.engine.is_mp_src_rank_with_outputs(), + ) + + # aggregate with bon sampling + self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n + assert self.ppo_mini_batch_size % self.engine.get_data_parallel_size() == 0, ( + f"{self.ppo_mini_batch_size=} is not divisible by {self.engine.get_data_parallel_size()=}" + ) + self.ppo_mini_batch_size_per_dp = self.ppo_mini_batch_size // self.engine.get_data_parallel_size() + + # setup flops counter + self.flops_counter = FlopsCounter(self.model_config.hf_config) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + self._build_engine() + self.engine.initialize() - return values, {"values": values.clone().detach()} + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_loss_fn(self, loss_fn): + self.loss_fn = loss_fn - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="cyan") + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="blue", role="critic_compute_values") def compute_values(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) - micro_batch_size = self.config.forward_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + if self.config.use_dynamic_bsz: + data.meta_info["max_token_len_per_gpu"] = self.config.ppo_infer_max_token_len_per_gpu + else: + data.meta_info["micro_batch_size_per_gpu"] = self.config.ppo_infer_micro_batch_size_per_gpu with self.engine.eval_mode(): - data = self.engine.shard_data(data=data) - output = self.engine.infer_batch(data, post_fn=self._post_fn_values) - response_mask = data.batch["response_mask"] - values = output["values"] * response_mask # Only action tokens have values - output = DataProto.from_dict(tensors={"values": values}) - - output = self.engine.unshard_data(data=output) - output = output.to("cpu") + # TODO: make worker API to accept 
 
-    def loss_fn(
-        self, batch: DataProto, vpreds: dict[str, torch.Tensor]
-    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
-        old_values = batch["values"]
-        returns = batch["returns"]
-        response_mask = batch["response_mask"]
-        micro_batch_metrics = {}
-
-        values, _ = self._post_fn_values(batch, vpreds)
-
-        vf_loss, vf_clipfrac = core_algos.compute_value_loss(
-            vpreds=values,
-            values=old_values,
-            returns=returns,
-            response_mask=response_mask,
-            cliprange_value=self.config.cliprange_value,
-            loss_agg_mode=self.config.loss_agg_mode,
-        )
-        if self.config.use_dynamic_bsz:
-            # relative to the dynamic bsz
-            loss = vf_loss * (len(batch) / self.config.ppo_mini_batch_size)
-        else:
-            gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
-            loss = vf_loss / gradient_accumulation
+    def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
+        """Make minibatch iterator for updating the critic
+
+        Args:
+            data (DataProto): a DataProto containing keys
+
+                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where
+                ``sequence_length = prompt_length + response_length``
+
+                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64
+
+                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64
 
-        micro_batch_metrics = {
-            "critic/vf_loss": vf_loss.detach().item(),
-            "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-            "critic/vpred_mean": masked_mean(values, response_mask).detach().item(),
-        }
+                ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
+                responses = input_ids[:, -response_length:]
 
-        return loss, micro_batch_metrics
+                ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
+                of responses.
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    @DistProfiler.annotate(color="pink")
+                ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
+                responses.
+                See PPO paper for details. https://arxiv.org/abs/1707.06347
+
+        Returns:
+
+        """
+        # Note that we do not select data here. It's the user's responsibility to select data outside trainer
+        # it's very important to set up the seed here. Otherwise, data in the model parallel region can disagree and cause hangs
+        return data.make_iterator(
+            mini_batch_size=self.ppo_mini_batch_size_per_dp,
+            epochs=self.config.ppo_epochs,
+            seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),
+            dataloader_kwargs={"shuffle": self.config.shuffle},
+        )
+
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
+    @DistProfiler.annotate(color="red", role="critic_update")
     def update_critic(self, data: DataProto):
+        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
+        if self.config.use_dynamic_bsz:
+            data.meta_info["max_token_len_per_gpu"] = self.config.ppo_max_token_len_per_gpu
+        else:
+            data.meta_info["micro_batch_size_per_gpu"] = self.config.ppo_micro_batch_size_per_gpu
+
         metrics = {}
         # Support all hardware
         data = data.to(get_device_id())
         # perform forward computation
         with self.engine.train_mode():
-            data = self.engine.shard_data(data=data)
-
-            with Timer(name="update_critic", logger=None) as timer:
-                select_keys = [
-                    "input_ids",
-                    "responses",
-                    "response_mask",
-                    "attention_mask",
-                    "position_ids",
-                    "values",
-                    "returns",
-                ]
-                batch = data.select(batch_keys=select_keys).batch
-                has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
-
-                # Split to make minibatch iterator for updating the actor
-                # See PPO paper for details. https://arxiv.org/abs/1707.06347
-                if has_multi_modal_inputs:
-                    num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
-                    non_tensor_select_keys = ["multi_modal_inputs"]
-                    dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
-                else:
-                    dataloader = batch.split(self.config.ppo_mini_batch_size)
-
-                for epoch in range(self.config.ppo_epochs):
-                    for batch_idx, mini_batch in enumerate(dataloader):
-                        self.engine.optimizer_zero_grad()
-                        mini_batch_metrics = self.engine.train_batch(mini_batch, self.loss_fn)
-                        grad_norm = self.engine.optimizer_step()
-                        mini_batch_metrics["critic/grad_norm"] = grad_norm.detach().item()
-                        append_to_dict(metrics, mini_batch_metrics)
-                self.engine.optimizer_zero_grad()
+            dataloader = self._make_minibatch_iterator(data)
+            with Timer(name="update_critic", logger=None) as timer:
+                for batch_idx, mini_batch in enumerate(dataloader):
+                    mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
+                    # TODO: make worker API to accept TensorDict as well
+                    mini_batch = mini_batch.to_tensordict()
+                    output = self.engine.train_batch(mini_batch, self.loss_fn)
+                    mini_batch_metrics = output.get("metrics", {})
+                    append_to_dict(metrics, mini_batch_metrics, prefix="critic/")
+
             delta_time = timer.last
-            # TODO: should not access engine's flops_counter
             global_num_tokens = data.meta_info["global_token_num"]
-            estimated_flops, promised_flops = self.engine.flops_counter.estimate_flops(global_num_tokens, delta_time)
+            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
             metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
+            metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
+            metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
+            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
+
+            lr = self.engine.lr_scheduler_step()
+            metrics["critic/lr"] = lr
 
-            metrics["critic/lr"] = self.engine.lr_scheduler_step()[0]
             output = DataProto(batch=None, meta_info={"metrics": metrics})
-            output = self.engine.unshard_data(data=output)
-        output = output.to("cpu")
         return output
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
-        self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)
+        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
-        self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
+    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
+        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
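Note (illustrative, not part of the patch): the mini-batch bookkeeping above, in isolation and with made-up numbers — every prompt contributes `rollout_n` responses, and the aggregated mini-batch must divide evenly across data-parallel ranks:

```python
ppo_mini_batch_size = 4   # prompts per PPO mini-batch (config)
rollout_n = 2             # responses sampled per prompt (BoN)
dp_size = 4               # engine data-parallel size

effective = ppo_mini_batch_size * rollout_n   # 8 samples per mini-batch
assert effective % dp_size == 0               # asserted by the worker
per_dp = effective // dp_size                 # 2 samples per dp rank
print(per_dp)
```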
diff --git a/verl/workers/roles/utils/losses.py b/verl/workers/roles/utils/losses.py
index 44c62be250f..e1f1f6a13c8 100644
--- a/verl/workers/roles/utils/losses.py
+++ b/verl/workers/roles/utils/losses.py
@@ -16,8 +16,9 @@
 import torch
 from tensordict import TensorDict
 
-from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
-from verl.workers.config import ActorConfig
+from verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty
+from verl.utils.torch_functional import masked_mean
+from verl.workers.config import ActorConfig, CriticConfig
 
 
 def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
@@ -80,3 +81,31 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
     metrics["kl_coef"] = config.kl_loss_coef
 
     return policy_loss, metrics
+
+
+def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=None):
+    vpreds = model_output["values"]
+    values = data["values"]
+    returns = data["returns"]
+    response_mask = data["response_mask"].to(bool)
+
+    vf_loss, vf_clipfrac = compute_value_loss(
+        vpreds=vpreds,
+        values=values,
+        returns=returns,
+        response_mask=response_mask,
+        cliprange_value=config.cliprange_value,
+        loss_agg_mode=config.loss_agg_mode,
+    )
+
+    metrics = {}
+
+    metrics.update(
+        {
+            "critic/vf_loss": vf_loss.detach().item(),
+            "critic/vf_clipfrac": vf_clipfrac.detach().item(),
+            "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
+        }
+    )
+
+    return vf_loss, metrics
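Note (illustrative, not part of the patch): `value_loss` delegates to `compute_value_loss`, which implements the standard PPO clipped value objective. A simplified sketch of that objective (token-mean aggregation only; not verl's actual implementation):

```python
import torch

def clipped_value_loss(vpreds, old_values, returns, mask, cliprange=0.5):
    # clip the new value prediction to stay within cliprange of the old one
    vpred_clipped = old_values + (vpreds - old_values).clamp(-cliprange, cliprange)
    # pessimistic (elementwise max) of clipped and unclipped squared errors
    losses = torch.max((vpreds - returns) ** 2, (vpred_clipped - returns) ** 2)
    return 0.5 * (losses * mask).sum() / mask.sum()

vpreds = torch.randn(2, 4)
loss = clipped_value_loss(vpreds, torch.zeros(2, 4), torch.zeros(2, 4), torch.ones(2, 4))
print(loss)
```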