Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions verl/models/transformers/npu_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
Expand All @@ -34,7 +35,7 @@
# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in
# subsequent versions
# https://github.com/huggingface/transformers/pull/38491
def apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu(
def apply_rotary_pos_emb_flashatt_npu(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
Expand All @@ -61,7 +62,7 @@ def silu_forward(self, hidden_state):
return self.down_proj(torch_npu.npu_swiglu(gate_up, dim=-1))


def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
Expand Down Expand Up @@ -194,12 +195,15 @@ def _check_and_enable_flash_attn_2(
return config


modeling_qwen2.Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2.Qwen2MLP.forward = silu_forward
modeling_qwen2.apply_rotary_pos_emb = apply_rotary_pos_emb_npu
modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu
modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward
modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_npu
modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward
modeling_qwen3.Qwen3MLP.forward = silu_forward

Expand Down