Add DeepSeek V3/R1 shared experts fusion #4918
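This PR turns the DeepSeek V3/R1 shared expert into ordinary fused-MoE experts: setting `SHARE_EXPERTS_FUSION_REPLICA=N` appends N copies of each layer's shared expert to the routed experts, raises `top_k` by one, and pins every token's extra slot to a randomly chosen replica. A hedged usage sketch (only the environment variable comes from this change; the launch command and flag values are ordinary SGLang usage, shown for illustration):

```python
# Hedged usage sketch: only SHARE_EXPERTS_FUSION_REPLICA comes from this PR;
# the launch command itself is standard SGLang usage and the flag values
# (model path, TP size, replica count) are illustrative.
import os
import subprocess

env = dict(os.environ, SHARE_EXPERTS_FUSION_REPLICA="8")
subprocess.run(
    ["python", "-m", "sglang.launch_server",
     "--model-path", "deepseek-ai/DeepSeek-V3",
     "--tp", "8", "--trust-remote-code"],
    env=env,
    check=True,
)
```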
In the MoE tuning benchmark (adapted from vLLM's `benchmark_moe.py`), the tuned expert count `E` is widened by the number of shared-expert replicas requested through the `SHARE_EXPERTS_FUSION_REPLICA` environment variable:

```diff
@@ -1,4 +1,5 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
+import os
 import argparse
 import json
 import time
```

```diff
@@ -400,6 +401,9 @@ def main(args: argparse.Namespace):
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
+        n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
+        if n_share_fusion_experts > 0:
+            E = E + n_share_fusion_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
```
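As a quick sanity check, a minimal standalone sketch of the same expert-count logic with illustrative DeepSeek-V3 numbers (256 routed experts; the environment variable is the one introduced above):

```python
import os

# Minimal sketch of the benchmark's expert-count logic.
# Illustrative value: DeepSeek-V3 has 256 routed experts.
n_routed_experts = 256
n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))

E = n_routed_experts
if n_share_fusion_experts > 0:
    # Replicas occupy expert ids n_routed_experts .. E - 1.
    E += n_share_fusion_experts

print(E)  # 264 with SHARE_EXPERTS_FUSION_REPLICA=8, otherwise 256
```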
A new Triton fused-MoE tuning configuration is added as part of the change (a new 146-line JSON file, keyed by batch size):

```json
{
  "1": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 4
  },
  "2": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 32,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 1,
    "num_warps": 4,
    "num_stages": 3
  },
  "4": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 1,
    "num_warps": 4,
    "num_stages": 4
  },
  "8": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "16": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "24": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "32": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "48": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "64": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 64,
    "num_warps": 4,
    "num_stages": 3
  },
  "96": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 64,
    "num_warps": 4,
    "num_stages": 3
  },
  "128": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "256": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "512": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "1024": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "1536": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "2048": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 16,
    "num_warps": 4,
    "num_stages": 3
  },
  "3072": {
    "BLOCK_SIZE_M": 128,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3
  },
  "4096": {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 64,
    "num_warps": 4,
    "num_stages": 3
  }
}
```
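These tuning files are keyed by token count; at run time the kernel picks the entry whose key is nearest the actual batch. A hedged sketch of that lookup (the helper name and file name are hypothetical; the nearest-key rule follows the Triton fused-MoE convention inherited from vLLM):

```python
import json

def load_nearest_config(path: str, num_tokens: int) -> dict:
    """Pick the tuned Triton config whose batch-size key is closest to num_tokens."""
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]

# A batch of 40 tokens would map to the "48" entry above, for example:
# cfg = load_nearest_config("E=264,N=....json", 40)
```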
In the top-k routing layer, `grouped_topk` gains a `share_fusion` argument. When it is non-zero, the last of the `topk` slots is overwritten with a randomly chosen shared-expert replica, that slot's weight is set to the sum of the routed weights divided by 2.5, and renormalization runs over the routed slots only:

```diff
@@ -102,11 +102,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    share_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,9 +124,20 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if share_fusion:
+        topk_ids[:, -1] = torch.randint(low=num_experts,
+                                        high=num_experts + share_fusion,
+                                        size=(topk_ids.size(0), ),
+                                        dtype=topk_ids.dtype,
+                                        device=topk_ids.device)
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) * 1.0 / 2.5

     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = topk_weights.sum(
+            dim=-1,
+            keepdim=True) if share_fusion == 0 else topk_weights[:, :-1].sum(
+                dim=-1, keepdim=True)
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
```

`select_experts` forwards the replica count read from the environment (passed as an `int`, since `grouped_topk` uses it as the upper bound for the replica ids), and its comment now covers the whole model family:

```diff
@@ -210,7 +223,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -222,6 +235,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            share_fusion=int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0")),
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
```
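The constant 2.5 is DeepSeek-V3's `routed_scaling_factor`: the fused MoE output is later multiplied by that factor, while in this model's unfused path the shared-expert output is added unscaled, so a pre-renormalization weight of `routed_sum / 2.5` gives the shared expert an effective weight of 1. A toy reproduction of the routing step (sizes are illustrative, not the model's):

```python
import torch

# Toy reproduction of the fused routing step (illustrative sizes).
num_tokens, num_experts, topk, share_fusion = 4, 8, 3, 2
routed_scaling_factor = 2.5  # DeepSeek-V3 config value

scores = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)

# Overwrite the last slot with a random shared-expert replica id; replicas
# live at indices num_experts .. num_experts + share_fusion - 1.
topk_ids[:, -1] = torch.randint(num_experts, num_experts + share_fusion,
                                (num_tokens,), dtype=topk_ids.dtype)
# Weight the slot so that, after the final *2.5 routed scaling, the shared
# expert contributes with an effective weight of 1.
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

# Renormalize over the routed slots only; the shared slot keeps ratio 1/2.5.
topk_weights = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)
```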
In the DeepseekV2 model implementation, `tqdm` is imported for the weight-cloning progress bar:

```diff
@@ -17,6 +17,7 @@
 """Inference-only DeepseekV2 model."""

 import os
+from tqdm import tqdm
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
```
`DeepseekV2MoE` reads the replica count once and sizes the fused MoE accordingly: the expert pool grows by the replica count, and `top_k` grows by one so every token reserves a slot for a shared-expert replica:

```diff
@@ -166,6 +167,7 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -187,8 +189,8 @@ def __init__(
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_fusion_experts,
+            top_k=config.num_experts_per_tok + min(self.n_share_fusion_experts, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
```
`forward_normal` skips the separate shared-expert pass when fusion is active:

```diff
@@ -256,8 +258,10 @@ def forward(
             return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_fusion_experts == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
```
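With fusion active, `shared_output` stays `None` because the shared expert's contribution now arrives through the fused kernel itself. A simplified, self-contained sketch of the two modes (stub callables stand in for the real gate and experts; this is not the model's exact code):

```python
import torch

ROUTED_SCALING_FACTOR = 2.5  # DeepSeek-V3 config value

def moe_block(hidden_states, gate, fused_experts, shared_experts, fused: bool):
    """Sketch of DeepseekV2MoE.forward_normal's unfused vs. fused modes."""
    router_logits = gate(hidden_states)
    # When fused, shared-expert replicas are inside fused_experts(...) and the
    # router guarantees each token picks one, so no separate pass is needed.
    out = fused_experts(hidden_states, router_logits) * ROUTED_SCALING_FACTOR
    if not fused:
        # Unfused path: run the shared expert separately; it is added unscaled.
        out = out + shared_experts(hidden_states)
    return out

# Stub usage with identity experts, just to show the call shape.
x = torch.randn(2, 16)
y = moe_block(x, gate=lambda t: t, fused_experts=lambda t, logits: t,
              shared_experts=lambda t: t, fused=False)
```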
The top-level `DeepseekV2ForCausalLM` also records the replica count for weight loading:

```diff
@@ -1309,6 +1313,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
```
During `load_weights`, the shared expert of every MoE layer (layers past `first_k_dense_replace`) is cloned into the extra expert slots; both the plain weights and the FP8 `weight_scale_inv` tensors are duplicated:

```diff
@@ -1342,7 +1347,33 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+
+        if self.n_share_fusion_experts != 0:
+            weights_list = list(weights)
+            weights_dict = {k: v for (k, v) in weights_list}
+            suffix_list = [
+                'down_proj.weight', 'down_proj.weight_scale_inv',
+                'gate_proj.weight', 'gate_proj.weight_scale_inv',
+                'up_proj.weight', 'up_proj.weight_scale_inv'
+            ]
+            current_device = torch.cuda.current_device()
+            is_master = (current_device == 0)
+            for moe_layer in tqdm(range(self.config.num_hidden_layers),
+                                  desc=f"Cloning {self.n_share_fusion_experts} "
+                                  "replicas of shared expert into MoE",
+                                  disable=not is_master):
+                if moe_layer < self.config.first_k_dense_replace:
+                    continue
+                for num_repeat in range(self.n_share_fusion_experts):
+                    for suffix in suffix_list:
+                        weights_list.append((
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + num_repeat}"
+                            f".{suffix}", weights_dict[
+                                f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"]
+                            .clone()))
+            weights = weights_list

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         MoEImpl = (
```
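For illustration, the checkpoint names the loop above synthesizes, using DeepSeek-V3-like numbers (256 routed experts, dense layers before `first_k_dense_replace` = 3 skipped; the values here are for exposition only):

```python
# Illustration of the checkpoint names the cloning loop synthesizes.
n_routed_experts, n_replicas, moe_layer = 256, 2, 3
suffixes = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

for r in range(n_replicas):
    for s in suffixes:
        src = f"model.layers.{moe_layer}.mlp.shared_experts.{s}"
        dst = f"model.layers.{moe_layer}.mlp.experts.{n_routed_experts + r}.{s}"
        print(src, "->", dst)
# e.g. ...mlp.shared_experts.down_proj.weight -> ...mlp.experts.256.down_proj.weight
```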
Finally, the expert-weight mapping is built for the enlarged pool:

```diff
@@ -1354,7 +1385,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_fusion_experts,
         )

         params_dict = dict(self.named_parameters())
```