import torch
+import torch.nn.functional as F
import torch_npu
+import transformers
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
@@ -46,5 +48,112 @@ def rms_norm_forward(self, x):
    return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]

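+# NPU replacement for the stock apply_rotary_pos_emb: the rotate_half math is delegated to
+# the fused npu_rotary_mul kernel, and the results are cast back to the input dtypes.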
+def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
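+# Autograd wrapper around torch_npu.npu_grouped_matmul: forward runs a single grouped
+# matmul over all experts; backward recomputes the input gradient with another grouped
+# matmul and builds the per-expert weight gradient by splitting the tokens with the
+# precomputed split sizes.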
+class GmmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, group_list, split_size):
+        ctx.save_for_backward(x, weight)
+        ctx.group_list = group_list
+        ctx.split_size = split_size
+
+        outputs = torch_npu.npu_grouped_matmul([x], [weight], group_list=group_list, group_type=0, split_item=2)
+        return outputs[0]
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        x, weight = ctx.saved_tensors
+        group_list = ctx.group_list
+        wt = weight.permute(0, 2, 1)
+        xt = x.permute(1, 0)
+        dx = torch_npu.npu_grouped_matmul([grad_outputs], [wt], group_list=group_list, group_type=0, split_item=2)
+        dw = torch.zeros_like(weight)
+        split_size = ctx.split_size
+        xt_list = torch.split(xt, split_size, dim=1)
+        grad_outputs_list = torch.split(grad_outputs, split_size, dim=0)
+        with torch.npu.amp.autocast(enabled=False):
+            dw = torch.stack([torch.matmul(xt_list[i], grad_outputs_list[i]) for i in range(len(xt_list))])
+
+        return dx[0], dw, None, None
+
+
+def moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    """Qwen3-MoE sparse MoE forward rewritten around NPU grouped matmul (GMM) kernels."""
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    # router_logits: (batch * sequence_length, n_experts)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+    if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    # we cast back to the input dtype
+    routing_weights = routing_weights.to(hidden_states.dtype)
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    # One-hot encode the selected experts to create an expert mask; it is used below to
+    # count how many tokens each expert receives.
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+    # Instead of looping over the experts, stack all expert weights (transposed to
+    # (in_features, out_features)) so the projections can run as grouped matmuls.
+    input_dtype = hidden_states.dtype
+    up_weight_list = [e.up_proj.weight.t().to(input_dtype) for e in self.experts]
+    gate_weight_list = [e.gate_proj.weight.t().to(input_dtype) for e in self.experts]
+    down_weight_list = [e.down_proj.weight.t().to(input_dtype) for e in self.experts]
+    w1 = torch.stack(up_weight_list)
+    w2 = torch.stack(gate_weight_list)
+    w3 = torch.stack(down_weight_list)
+
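+    # Sort the flattened (token, expert) assignments by expert id so tokens routed to the
+    # same expert become contiguous; sorted_indices // top_k maps each slot back to its
+    # source token row.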
+    # Copied from mindspeed moe_utils.py:permute
+    routing_map = selected_experts
+    flatten_indices = routing_map.view(-1)
+    sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
+    permuted_tokens = hidden_states.index_select(0, sorted_indices // self.top_k)
+
+    tokens_per_experts = torch.sum(expert_mask, dim=(1, 2))
+    group_list = torch.cumsum(tokens_per_experts, dim=0)
+
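+    # group_list is the running (cumulative) token count per expert, passed straight to the
+    # grouped matmul; the CPU copy provides plain Python split sizes for torch.split in
+    # GmmFunction.backward.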
+    cpu_group_list = group_list.to("cpu", non_blocking=False)
+    cpu_group_list = [0] + cpu_group_list.tolist()
+    split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]
+
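+    # Up, gate and down projections each run as one grouped matmul over all experts; the
+    # activation is a fused SwiGLU over the concatenated gate/up outputs.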
+    up_res = GmmFunction.apply(permuted_tokens, w1, group_list, split_size)
+    gate_res = GmmFunction.apply(permuted_tokens, w2, group_list, split_size)
+    act_res = torch_npu.npu_swiglu(torch.cat([gate_res, up_res], dim=-1))
+    down_res = GmmFunction.apply(act_res, w3, group_list, split_size)
+
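+    # Undo the permutation: scatter expert outputs back to their original slots, weight each
+    # top-k copy by its routing probability, and sum over the k experts per token.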
+    probs = routing_weights
+    num_unpermuted_tokens = probs.numel()
+    topk = self.top_k
+    permuted_tokens = down_res
+
+    unpermuted_tokens = torch.zeros(
+        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+        dtype=permuted_tokens.dtype,
+        device=permuted_tokens.device,
+    )
+    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
+    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+    unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+    unpermuted_tokens = unpermuted_tokens.sum(dim=1).to(hidden_states.dtype)
+    # Restore the (batch, seq_len, hidden_dim) shape expected by the decoder layer.
+    final_hidden_states = unpermuted_tokens.reshape(batch_size, sequence_length, hidden_dim)
+
+    return final_hidden_states, router_logits
+
+
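+# Monkey-patch the Hugging Face modules in place: the Qwen2.5-VL and Qwen3-MoE RMSNorm,
+# rotary position embedding, and sparse MoE block are swapped for the NPU versions above.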
Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu
+transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward
+transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
+transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu