diff --git a/docs/advance/one_step_off.md b/docs/advance/one_step_off.md
index af7f88c2dc1..af435d3f2cb 100644
--- a/docs/advance/one_step_off.md
+++ b/docs/advance/one_step_off.md
@@ -192,6 +192,7 @@ def sync_rollout_weights(self):
     inference_model = (
         self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
     )
+    from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
     patch_vllm_moe_model_weight_loader(inference_model)
     # Model parameters are broadcast tensor-by-tensor from actor to rollout
     for key, shape, dtype in self._weights_info:
diff --git a/recipe/one_step_off_policy/README.md b/recipe/one_step_off_policy/README.md
index fa63407f4f7..e29ef8d57d6 100644
--- a/recipe/one_step_off_policy/README.md
+++ b/recipe/one_step_off_policy/README.md
@@ -192,6 +192,7 @@ def sync_rollout_weights(self):
     inference_model = (
         self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
     )
+    from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
     patch_vllm_moe_model_weight_loader(inference_model)
     # Model parameters are broadcast tensor-by-tensor from actor to rollout
     for key, shape, dtype in self._weights_info:
diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py
index 0aa21991708..72036d6057c 100644
--- a/recipe/one_step_off_policy/fsdp_workers.py
+++ b/recipe/one_step_off_policy/fsdp_workers.py
@@ -38,7 +38,6 @@
 )
 from verl.utils.import_utils import import_external_libs
 from verl.utils.model import get_generation_config, update_model_config
-from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
 from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
 from verl.workers.fsdp_workers import CriticWorker

@@ -71,6 +70,8 @@ def sync_rollout_weights(self):
         inference_model = (
             self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
         )
+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+
         patch_vllm_moe_model_weight_loader(inference_model)
         for key, shape, dtype in self._weights_info:
             tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py
index f7b58405b4f..6d32d4bb27c 100644
--- a/recipe/one_step_off_policy/megatron_workers.py
+++ b/recipe/one_step_off_policy/megatron_workers.py
@@ -26,7 +26,6 @@
 )
 from verl.utils.device import get_device_name, get_torch_device
 from verl.utils.fs import copy_to_local
-from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
 from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker
 from verl.workers.megatron_workers import CriticWorker, RewardModelWorker

@@ -74,6 +73,8 @@ def sync_rollout_weights(self):
         inference_model = (
             self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
         )
+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+
         patch_vllm_moe_model_weight_loader(inference_model)
         for key, shape, dtype in self._weights_info:
             if self._is_actor:
diff --git a/verl/utils/vllm/__init__.py b/verl/utils/vllm/__init__.py
new file mode 100644
index 00000000000..00aa7bdb642
--- /dev/null
+++ b/verl/utils/vllm/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import TensorLoRARequest, VLLMHijack, is_version_ge
+
+# Do not re-export the contents of vllm/patch.py here: patch.py must only be
+# imported after the vLLM LLM instance has been created, so import it lazily
+# at the call site (e.g. inside sync_rollout_weights) right before the patch
+# is applied.
+
+__all__ = [
+    "TensorLoRARequest",
+    "VLLMHijack",
+    "is_version_ge",
+]
diff --git a/verl/utils/vllm/patch.py b/verl/utils/vllm/patch.py
new file mode 100644
index 00000000000..23c45dfb6d8
--- /dev/null
+++ b/verl/utils/vllm/patch.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To support different vLLM versions, each MoE model class is appended to SUPPORTED_MOE_MODELS
+# individually, so a class that is missing from an older vLLM does not break the import.
+SUPPORTED_MOE_MODELS = []
+
+try:
+    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
+
+    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)
+    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)
+except ImportError:
+    pass
+
+try:
+    from vllm.model_executor.models.mixtral import MixtralForCausalLM
+
+    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)
+except ImportError:
+    pass
+
+try:
+    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
+
+    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)
+except ImportError:
+    pass
+
+try:
+    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
+
+    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)
+except ImportError:
+    pass
+
+try:
+    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration
+
+    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)
+except ImportError:
+    pass
+
+
+def patch_vllm_moe_model_weight_loader(model):
+    # This is a workaround for loading the weights of vLLM fused MoE models.
+    # It stems from a bug in vLLM 0.8.2: every weight is supposed to carry a
+    # weight_loader, but the fused MoE expert weights do not, so we patch one in.
+    # Observed parameters (has weight_loader, parameter name):
+    # (True, 'model.embed_tokens.weight')
+    # (True, 'model.layers.0.self_attn.qkv_proj.weight')
+    # (True, 'model.layers.0.self_attn.qkv_proj.bias')
+    # (True, 'model.layers.0.self_attn.o_proj.weight')
+    # (True, 'model.layers.0.mlp.gate.weight')
+    # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')
+    # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')
+    # (False, 'model.layers.0.mlp.shared_expert_gate.weight') use default
+    # (False, 'model.layers.0.input_layernorm.weight') use default
+    # (False, 'model.layers.0.post_attention_layernorm.weight') use default
+    # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader
+    # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader
+
+    # Define MLP attribute mapping for different model types
+    MLP_ATTR_MAPPING = {
+        MixtralForCausalLM: "block_sparse_moe",
+    }
+    DEFAULT_MLP_ATTR = "mlp"
+
+    if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)):
+        return
+
+    model = getattr(model, "model", None) or getattr(model, "language_model", None)
+    if model is None:
+        raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
+
+    for layer in model.layers:
+        mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR)
+        mlp = getattr(layer, mlp_attr)
+
+        param_dict = dict(mlp.named_parameters())
+        for name, param in param_dict.items():
+            if "w13_weight" in name or "w2_weight" in name:
+                param.weight_loader = mlp.experts.weight_loader
diff --git a/verl/utils/vllm_utils.py b/verl/utils/vllm/utils.py
similarity index 65%
rename from verl/utils/vllm_utils.py
rename to verl/utils/vllm/utils.py
index 25ee6656dbe..acf24398077 100644
--- a/verl/utils/vllm_utils.py
+++ b/verl/utils/vllm/utils.py
@@ -22,87 +22,6 @@

 from verl.third_party.vllm import get_version

-# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering
-# unsupported issues.
-SUPPORTED_MOE_MODELS = []
-
-try:
-    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
-
-    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)
-    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)
-except ImportError:
-    pass
-
-try:
-    from vllm.model_executor.models.mixtral import MixtralForCausalLM
-
-    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)
-except ImportError:
-    pass
-
-try:
-    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
-
-    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)
-except ImportError:
-    pass
-
-try:
-    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
-
-    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)
-except ImportError:
-    pass
-
-try:
-    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration
-
-    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)
-except ImportError:
-    pass
-
-
-def patch_vllm_moe_model_weight_loader(model):
-    # this is a work around to load the weight of vllm fused moe model
-    # it is from a bug from vllm 0.8.2
-    # all the weights are supposed to have a weight_loader, but the moe weights
-    # do not have a weight_loader, so we need to patch it
-    # (True, 'model.embed_tokens.weight')
-    # (True, 'model.layers.0.self_attn.qkv_proj.weight')
-    # (True, 'model.layers.0.self_attn.qkv_proj.bias')
-    # (True, 'model.layers.0.self_attn.o_proj.weight')
-    # (True, 'model.layers.0.mlp.gate.weight')
-    # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')
-    # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')
-    # (False, 'model.layers.0.mlp.shared_expert_gate.weight') use default
-    # (False, 'model.layers.0.input_layernorm.weight') use default
-    # (False, 'model.layers.0.post_attention_layernorm.weight') use default
-    # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader
-    # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader
-
-    # Define MLP attribute mapping for different model types
-    MLP_ATTR_MAPPING = {
-        MixtralForCausalLM: "block_sparse_moe",
-    }
-    DEFAULT_MLP_ATTR = "mlp"
-
-    if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)):
-        return
-
-    model = getattr(model, "model", None) or getattr(model, "language_model", None)
-    if model is None:
-        raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
-
-    for layer in model.layers:
-        mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR)
-        mlp = getattr(layer, mlp_attr)
-
-        param_dict = dict(mlp.named_parameters())
-        for name, param in param_dict.items():
-            if "w13_weight" in name or "w2_weight" in name:
-                param.weight_loader = mlp.experts.weight_loader
-

 class TensorLoRARequest(LoRARequest):
     peft_config: dict = field(default=None)
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index 1a9677df531..c9b163a0692 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -44,7 +44,7 @@
 from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 from verl.utils.torch_functional import check_device_is_available
-from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader
+from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge

 from .base import BaseShardingManager

@@ -329,6 +329,8 @@ def replace_lora_wrapper(k):

         updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}

+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+
         patch_vllm_moe_model_weight_loader(model)
         device = get_device_id()  # used when fsdp2 set cpu_offload_policy
         loaded_params = model.load_weights(
diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py
index d0b09085519..2c8cf4b4c1d 100644
--- a/verl/workers/sharding_manager/megatron_vllm.py
+++ b/verl/workers/sharding_manager/megatron_vllm.py
@@ -36,7 +36,6 @@
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage
 from verl.utils.profiler.performance import simple_timer
 from verl.utils.torch_functional import check_device_is_available
-from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

 from .base import BaseShardingManager

@@ -166,6 +165,8 @@ def __enter__(self):
             self.layer_name_mapping,
         )
         model = self.model_runner.model
+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+
         patch_vllm_moe_model_weight_loader(model)
         loaded_params = model.load_weights(per_tensor_param)
         info = f"vLLM load weights, loaded_params: {len(loaded_params)}"
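For context, the lazy-import pattern this diff introduces boils down to the call order below. This is a minimal sketch, not part of the patch: it assumes a vLLM 0.8.x `LLM` engine and an MoE checkpoint (the model name is illustrative), and reuses the attribute chain from `sync_rollout_weights` above.

```python
# Minimal sketch of the intended call order, assuming vLLM 0.8.x and an MoE
# checkpoint; the model name is illustrative only.
from vllm import LLM

llm = LLM(model="Qwen/Qwen1.5-MoE-A2.7B")

# Import the patch only now, after the LLM instance (and the vLLM model
# classes) exist, mirroring the lazy imports added in this diff.
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

# Same attribute chain as sync_rollout_weights above.
inference_model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model

# After patching, the fused expert tensors (w13_weight / w2_weight) are routed
# through mlp.experts.weight_loader when model.load_weights(...) runs.
patch_vllm_moe_model_weight_loader(inference_model)
```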