diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py index f9dcf021da6..5d27d7f8487 100644 --- a/verl/models/transformers/npu_patch.py +++ b/verl/models/transformers/npu_patch.py @@ -23,6 +23,7 @@ import torch_npu from torch_npu import npu_rotary_mul as apply_rotary_emb from transformers.modeling_utils import PretrainedConfig, PreTrainedModel +from transformers.models.qwen2 import modeling_qwen2 from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl from transformers.models.qwen3 import modeling_qwen3 from transformers.models.qwen3_moe import modeling_qwen3_moe @@ -34,7 +35,7 @@ # This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in # subsequent versions # https://github.com/huggingface/transformers/pull/38491 -def apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu( +def apply_rotary_pos_emb_flashatt_npu( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() @@ -61,7 +62,7 @@ def silu_forward(self, hidden_state): return self.down_proj(torch_npu.npu_swiglu(gate_up, dim=-1)) -def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = torch_npu.npu_rotary_mul(q, cos, sin) @@ -194,12 +195,15 @@ def _check_and_enable_flash_attn_2( return config +modeling_qwen2.Qwen2RMSNorm.forward = rms_norm_forward +modeling_qwen2.Qwen2MLP.forward = silu_forward +modeling_qwen2.apply_rotary_pos_emb = apply_rotary_pos_emb_npu modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward modeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward -modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu +modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward -modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu +modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_npu modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward modeling_qwen3.Qwen3MLP.forward = silu_forward