Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 148 additions & 127 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""

import concurrent.futures
import logging
import os
from enum import IntEnum, auto
Expand Down Expand Up @@ -2436,154 +2437,174 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
assert self.num_fused_shared_experts == 1
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

params_dict = dict(self.named_parameters())
weight_names = []
for name, loaded_weight in weights:
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
params_dict = dict(self.named_parameters())
weight_names = []
for name, loaded_weight in weights:
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)

weight_names.append(name)
weight_names.append(name)

if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
else:
if not name.startswith(nextn_layer_prefix):
continue
if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
else:
if not name.startswith(nextn_layer_prefix):
continue

# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue

is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")

if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")

if "rotary_emb.inv_freq" in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
futures.append(
executor.submit(weight_loader, param, loaded_weight, shard_id)
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
futures.append(
executor.submit(
weight_loader,
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
)

# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
cat_dim = 0
if self.quant_config is not None and (
self.quant_config.get_name() == "awq"
or self.quant_config.get_name() == "moe_wna16"
):
cat_dim = 1
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
param = params_dict[param_name]

# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
cat_dim = 0
if self.quant_config is not None and (
self.quant_config.get_name() == "awq"
or self.quant_config.get_name() == "moe_wna16"
):
cat_dim = 1
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
)
param_name = (
name.replace(
"q_a_proj", "fused_qkv_a_proj_with_mqa"
)
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa",
"fused_qkv_a_proj_with_mqa",
)
)
param = params_dict[param_name]

weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
futures.append(
executor.submit(weight_loader, param, fused_weight)
)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
for scale in ["k_scale", "v_scale"]:
if scale in name:
name = name.replace(
f"{scale[0]}_proj", "attn_mqa"
)
break
if name not in params_dict:
# modelopt ckpt contains not needed weights for MTP module:
# model.decoder.self_attn.attn_mqa.v_scale and
# model.decoder.self_attn.attn_mqa.k_scale
logger.warning(f"{name} not found in params_dict.")
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
for scale in ["k_scale", "v_scale"]:
if scale in name:
name = name.replace(f"{scale[0]}_proj", "attn_mqa")
break
if name not in params_dict:
# modelopt ckpt contains not needed weights for MTP module:
# model.decoder.self_attn.attn_mqa.v_scale and
# model.decoder.self_attn.attn_mqa.k_scale
logger.warning(f"{name} not found in params_dict.")
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
futures.append(
executor.submit(weight_loader, param, loaded_weight)
)

# Wait for all tasks to complete and raise any exceptions.
for future in concurrent.futures.as_completed(futures):
future.result()

self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

Expand Down
Loading