Commit 4e3baf4

mgoin authored and pathorn committed
[Misc] Remove duplicated DeepSeek V2/V3 model definition (vllm-project#12793)
1 parent 4078052 commit 4e3baf4

File tree

3 files changed (+36, -667 lines)

vllm/model_executor/models/deepseek_v2.py

Lines changed: 35 additions & 13 deletions
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only DeepseekV2 model."""
+"""Inference-only DeepseekV2/DeepseekV3 model."""
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -113,23 +113,32 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
-        self.experts = FusedMoE(num_experts=config.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                use_grouped_topk=True,
-                                num_expert_group=config.n_group,
-                                topk_group=config.topk_group,
-                                prefix=f"{prefix}.experts")
-
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
+        if config.topk_method == "noaux_tc":
+            self.gate.e_score_correction_bias = nn.Parameter(
+                torch.empty(config.n_routed_experts))
+        else:
+            self.gate.e_score_correction_bias = None
+
+        self.experts = FusedMoE(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            e_score_correction_bias=self.gate.e_score_correction_bias)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -579,6 +588,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
+            # TODO(simon): support nextn predict layers
+            if hasattr(self.config, "num_nextn_predict_layers"
+                       ) and self.config.num_nextn_predict_layers > 0:
+                assert self.config.num_nextn_predict_layers == 1
+                layer_idx = self.config.num_hidden_layers
+                if name.startswith(f"model.layers.{layer_idx}"):
+                    continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
@@ -640,3 +658,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
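
Editor's note: the functional piece this consolidation keeps from the old DeepSeek-V3 file is the "noaux_tc" gate, a learned e_score_correction_bias on the router that is handed to FusedMoE together with scoring_func. For readers unfamiliar with that routing scheme, below is a minimal, self-contained sketch of the idea, roughly following the DeepSeek-V3 recipe (sigmoid scoring, bias used only for expert selection, group-limited top-k). It is not vllm's FusedMoE implementation; the function name noaux_tc_route and all sizes are made up for illustration.

import torch

def noaux_tc_route(router_logits: torch.Tensor,
                   e_score_correction_bias: torch.Tensor,
                   top_k: int = 8,
                   n_group: int = 8,
                   topk_group: int = 4):
    # router_logits: [num_tokens, n_routed_experts], the output of the gate.
    num_tokens, n_experts = router_logits.shape
    # DeepSeek-V3 scores experts with a sigmoid rather than a softmax.
    scores = router_logits.sigmoid()
    # The correction bias steers *which* experts are picked; the final mixing
    # weights still come from the unbiased scores.
    biased = scores + e_score_correction_bias
    # Rank expert groups by the sum of their two best biased scores and keep
    # only topk_group groups ("group-limited" routing).
    group_scores = biased.view(num_tokens, n_group, -1).topk(2, dim=-1)[0].sum(-1)
    keep = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    expert_mask = group_mask.repeat_interleave(n_experts // n_group, dim=1)
    # Select top_k experts from the surviving groups using the biased scores.
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(top_k, dim=-1).indices
    # Weight the chosen experts with their unbiased scores, renormalized.
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)
    return topk_weights, topk_ids

# Smoke test with made-up sizes: 4 tokens routed over 64 experts.
weights, ids = noaux_tc_route(torch.randn(4, 64), torch.zeros(64))

In vllm the equivalent selection happens inside FusedMoE, which receives use_grouped_topk, scoring_func, and the bias in the diff above; the model file only allocates the parameter and passes it through, which is why DeepseekV3ForCausalLM can reduce to a bare subclass of DeepseekV2ForCausalLM.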

0 commit comments
