verl/models/transformers/npu_patch.py (116 changes: 112 additions & 4 deletions)
@@ -16,16 +16,17 @@


import torch
import torch.nn.functional as F
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
from transformers.models.qwen3_moe import modeling_qwen3_moe


# This patch takes effect when apply_rotary_pos_emb_flashatt is used on qwen2_5_vl and will be removed in
# subsequent versions once the upstream change below is available:
# https://github.com/huggingface/transformers/pull/38491
def apply_rotary_pos_emb_flashatt_npu(
def apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
@@ -46,5 +47,112 @@ def rms_norm_forward(self, x):
return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]


Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu
def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed.to(q.dtype), k_embed.to(k.dtype)


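# GmmFunction wraps torch_npu.npu_grouped_matmul in a custom autograd.Function: the forward pass runs a
# single grouped matmul over the per-expert token groups described by group_list, while the backward pass
# computes the input gradient with a grouped matmul against the transposed expert weights and the weight
# gradient by splitting tokens per expert (split_size) and stacking one matmul per expert.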
class GmmFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, group_list, split_size):
ctx.save_for_backward(x, weight)
ctx.group_list = group_list
ctx.split_size = split_size

outputs = torch_npu.npu_grouped_matmul([x], [weight], group_list=group_list, group_type=0, split_item=2)
return outputs[0]

@staticmethod
def backward(ctx, grad_outputs):
x, weight = ctx.saved_tensors
group_list = ctx.group_list
wt = weight.permute(0, 2, 1)
xt = x.permute(1, 0)
dx = torch_npu.npu_grouped_matmul([grad_outputs], [wt], group_list=group_list, group_type=0, split_item=2)
dw = torch.zeros_like(weight)
split_size = ctx.split_size
xt_list = torch.split(xt, split_size, dim=1)
grad_outputs_list = torch.split(grad_outputs, split_size, dim=0)
with torch.npu.amp.autocast(enabled=False):
dw = torch.stack([torch.matmul(xt_list[i], grad_outputs_list[i]) for i in range(len(xt_list))])

return dx[0], dw, None, None


def moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Qwen3 MoE sparse block forward for NPU: dispatches tokens to experts with fused grouped matmuls instead of a per-expert loop."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

    # One-hot encode the selected experts to create an expert mask;
    # this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    # Instead of looping over each expert, stack every expert's up/gate/down projection weights
    # so the grouped matmuls below can process all experts in a single pass
input_dtype = hidden_states.dtype
up_weight_list = [e.up_proj.weight.t().to(input_dtype) for e in self.experts]
gate_weight_list = [e.gate_proj.weight.t().to(input_dtype) for e in self.experts]
down_weight_list = [e.down_proj.weight.t().to(input_dtype) for e in self.experts]
w1 = torch.stack(up_weight_list)
w2 = torch.stack(gate_weight_list)
w3 = torch.stack(down_weight_list)

# Copied from mindspeed moe_utils.py:permute
routing_map = selected_experts
flatten_indices = routing_map.view(-1)
sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
permuted_tokens = hidden_states.index_select(0, sorted_indices // self.top_k)

tokens_per_experts = torch.sum(expert_mask, dim=(1, 2))
group_list = torch.cumsum(tokens_per_experts, dim=0)

cpu_group_list = group_list.to("cpu", non_blocking=False)
cpu_group_list = [0] + cpu_group_list.tolist()
split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]

up_res = GmmFunction.apply(permuted_tokens, w1, group_list, split_size)
gate_res = GmmFunction.apply(permuted_tokens, w2, group_list, split_size)
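    # npu_swiglu splits its input in half along the last dim and computes silu(first_half) * second_half,
    # so concatenating [gate_res, up_res] reproduces act_fn(gate_proj(x)) * up_proj(x) from the stock MLP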
act_res = torch_npu.npu_swiglu(torch.cat([gate_res, up_res], dim=-1))
down_res = GmmFunction.apply(act_res, w3, group_list, split_size)

probs = routing_weights
num_unpermuted_tokens = probs.numel()
topk = self.top_k
permuted_tokens = down_res

unpermuted_tokens = torch.zeros(
[num_unpermuted_tokens, permuted_tokens.shape[-1]],
dtype=permuted_tokens.dtype,
device=permuted_tokens.device,
)
unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1).to(hidden_states.dtype)
final_hidden_states = unpermuted_tokens

return final_hidden_states, router_logits
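
# A small worked example of the dispatch above (toy numbers, not taken from this file): with 3 tokens,
# top_k=2 and 4 experts, selected_experts could be [[0, 2], [2, 3], [0, 1]]; then tokens_per_experts is
# [2, 1, 2, 1], group_list = [2, 3, 5, 6] and split_size = [2, 1, 2, 1]. permuted_tokens orders the 6 token
# copies expert by expert so one grouped matmul covers all experts, and index_copy_ with sorted_indices
# scatters the expert outputs back to token order before the top-k routing weights are applied.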


modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu
modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward
modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
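
For context, a minimal sketch of how a patch module like this is typically activated: it is imported purely for its side effects before the transformers model is built, so the assignments above are already in place. The import site, checkpoint name, and device handling below are illustrative assumptions, not part of this change.

import verl.models.transformers.npu_patch  # noqa: F401  (assumed call site; applies the monkey patches above)
from transformers import AutoModelForCausalLM

# Hypothetical Qwen3-MoE checkpoint; any model whose classes are patched above would behave the same way.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")
model = model.to("npu")  # requires an Ascend NPU runtime with torch_npu installed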