Commit 1fad7ea

Add fused kernels for qwen3_moe models in transformers
1 parent 3a394c9 commit 1fad7ea

File tree: 1 file changed, 109 additions, 0 deletions

verl/models/transformers/npu_patch.py
@@ -16,7 +16,9 @@
 
 
 import torch
+import torch.nn.functional as F
 import torch_npu
+import transformers
 from torch_npu import npu_rotary_mul as apply_rotary_emb
 from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
@@ -46,5 +48,112 @@ def rms_norm_forward(self, x):
     return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]
 
 
+def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
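For reference, here is a minimal sketch (not part of this commit, and only an assumption about what the fused call replaces) of the eager rotary-embedding computation that apply_rotary_pos_emb_qwen3_npu offloads to torch_npu.npu_rotary_mul:

    # Reference-only sketch: plain-PyTorch rotary embedding, assumed equivalent
    # to the fused npu_rotary_mul path used in the patch above.
    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_reference(q, k, cos, sin, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed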
+class GmmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, group_list, split_size):
+        ctx.save_for_backward(x, weight)
+        ctx.group_list = group_list
+        ctx.split_size = split_size
+
+        outputs = torch_npu.npu_grouped_matmul([x], [weight], group_list=group_list, group_type=0, split_item=2)
+        return outputs[0]
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        x, weight = ctx.saved_tensors
+        group_list = ctx.group_list
+        wt = weight.permute(0, 2, 1)
+        xt = x.permute(1, 0)
+        dx = torch_npu.npu_grouped_matmul([grad_outputs], [wt], group_list=group_list, group_type=0, split_item=2)
+        dw = torch.zeros_like(weight)
+        split_size = ctx.split_size
+        xt_list = torch.split(xt, split_size, dim=1)
+        grad_outputs_list = torch.split(grad_outputs, split_size, dim=0)
+        with torch.npu.amp.autocast(enabled=False):
+            dw = torch.stack([torch.matmul(xt_list[i], grad_outputs_list[i]) for i in range(len(xt_list))])
+
+        return dx[0], dw, None, None
+
+
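To make the custom autograd function easier to follow: the forward pass runs one fused grouped matmul over all experts (group_list marks the cumulative token boundary of each expert), and the backward pass computes dx with another grouped matmul against the transposed weights, then assembles dw expert by expert from the token chunks. A plain-PyTorch sketch of that dw computation (illustration only, with made-up sizes; not part of the commit):

    # Per-expert weight gradient, assuming tokens are already grouped by expert.
    import torch

    tokens_per_expert = [3, 1, 2]                 # plays the role of split_size
    x = torch.randn(sum(tokens_per_expert), 4)    # permuted tokens, shape (tokens, hidden)
    grad_out = torch.randn(sum(tokens_per_expert), 5)

    dw = torch.stack([
        xi.t() @ gi                               # (hidden, n_i) @ (n_i, out) -> (hidden, out)
        for xi, gi in zip(torch.split(x, tokens_per_expert, dim=0),
                          torch.split(grad_out, tokens_per_expert, dim=0))
    ])                                            # (num_experts, hidden, out), same layout as the stacked weights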
+def moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    """ """
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    # router_logits: (batch * sequence_length, n_experts)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+    if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    # we cast back to the input dtype
+    routing_weights = routing_weights.to(hidden_states.dtype)
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    # One hot encode the selected experts to create an expert mask
+    # this will be used to easily index which expert is going to be sollicitated
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+    # Loop over all available experts in the model and perform the computation on each expert
+    # Concat all weights
+    input_dtype = hidden_states.dtype
+    up_weight_list = [e.up_proj.weight.t().to(input_dtype) for e in self.experts]
+    gate_weight_list = [e.gate_proj.weight.t().to(input_dtype) for e in self.experts]
+    down_weight_list = [e.down_proj.weight.t().to(input_dtype) for e in self.experts]
+    w1 = torch.stack(up_weight_list)
+    w2 = torch.stack(gate_weight_list)
+    w3 = torch.stack(down_weight_list)
+
+    # Copied from mindspeed moe_utils.py:permute
+    routing_map = selected_experts
+    flatten_indices = routing_map.view(-1)
+    sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
+    permuted_tokens = hidden_states.index_select(0, sorted_indices // self.top_k)
+
+    tokens_per_experts = torch.sum(expert_mask, dim=(1, 2))
+    group_list = torch.cumsum(tokens_per_experts, dim=0)
+
+    cpu_group_list = group_list.to("cpu", non_blocking=False)
+    cpu_group_list = [0] + cpu_group_list.tolist()
+    split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]
+
+    up_res = GmmFunction.apply(permuted_tokens, w1, group_list, split_size)
+    gate_res = GmmFunction.apply(permuted_tokens, w2, group_list, split_size)
+    act_res = torch_npu.npu_swiglu(torch.cat([gate_res, up_res], dim=-1))
+    down_res = GmmFunction.apply(act_res, w3, group_list, split_size)
+
+    probs = routing_weights
+    num_unpermuted_tokens = probs.numel()
+    topk = self.top_k
+    permuted_tokens = down_res
+
+    unpermuted_tokens = torch.zeros(
+        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+        dtype=permuted_tokens.dtype,
+        device=permuted_tokens.device,
+    )
+    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
+    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+    unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+    unpermuted_tokens = unpermuted_tokens.sum(dim=1).to(hidden_states.dtype)
+    final_hidden_states = unpermuted_tokens
+
+    return final_hidden_states, router_logits
+
+
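The permute / unpermute bookkeeping above is what replaces the per-expert Python loop of the stock transformers implementation: tokens are sorted so that all tokens routed to the same expert become contiguous, group_list gives the cumulative per-expert boundaries consumed by npu_grouped_matmul, and the grouped results are scattered back and mixed with the routing weights. A small CPU-only sketch of that bookkeeping (illustration with made-up sizes, substituting a plain matmul loop for the fused grouped matmul; not part of the commit):

    import torch

    num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2
    hidden_states = torch.randn(num_tokens, hidden)
    selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))

    # Sort flattened (token, expert) assignments by expert id so tokens for the
    # same expert become contiguous, exactly as moe_block_forward does.
    flatten_indices = selected_experts.view(-1)
    sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
    permuted_tokens = hidden_states.index_select(0, sorted_indices // top_k)

    # Per-expert token counts and their cumulative boundaries (the group_list).
    tokens_per_expert = torch.bincount(flatten_indices, minlength=num_experts)
    group_list = torch.cumsum(tokens_per_expert, dim=0)
    split_size = tokens_per_expert.tolist()

    # Stand-in for the fused grouped matmul: one matmul per expert chunk.
    expert_weights = torch.randn(num_experts, hidden, hidden)
    chunks = torch.split(permuted_tokens, split_size, dim=0)
    grouped_out = torch.cat([c @ expert_weights[i] for i, c in enumerate(chunks)], dim=0)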
 Qwen2RMSNorm.forward = rms_norm_forward
 modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu
+transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward
+transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
+transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
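As with the existing qwen2_5_vl patches in this file, the qwen3_moe kernels take effect through module-level monkey patching: importing the patch module rebinds the transformers symbols, and Qwen3-MoE layers then dispatch to the fused NPU paths. A hedged usage sketch (the checkpoint name is only a placeholder; the import path comes from the file touched by this commit):

    import verl.models.transformers.npu_patch  # noqa: F401  (applies the patches on import)
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("some-qwen3-moe-checkpoint")
    # Qwen3MoeSparseMoeBlock.forward now runs moe_block_forward, and the
    # attention layers use the fused NPU rotary embedding and RMSNorm kernels.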
