
Commit be0b3af

zwd003 and pcmoritz authored

Support Deepseek-V2 (#4650)

Co-authored-by: Philipp Moritz <[email protected]>

1 parent 2cd402e commit be0b3af

File tree

6 files changed: +700 -1 lines changed

vllm/config.py

Lines changed: 6 additions & 0 deletions
@@ -297,6 +297,12 @@ def get_hidden_size(self) -> int:
         return self.hf_text_config.hidden_size
 
     def get_head_size(self) -> int:
+        # TODO remove hard code
+        if hasattr(self.hf_text_config, "model_type"
+                   ) and self.hf_text_config.model_type == 'deepseek_v2':
+            # FlashAttention supports only head_size 32, 64, 128, 256,
+            # we need to pad head_size 192 to 256
+            return 256
         if hasattr(self.hf_text_config, "head_dim"):
             return self.hf_text_config.head_dim
         # FIXME(woosuk): This may not be true for all models.
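
For context, a minimal sketch of the padding rule the comment describes (the helper and constant below are hypothetical, not part of the diff): DeepSeek-V2's query/key head concatenates a 128-dim non-rotary part with a 64-dim rotary part, giving head_size 192, which FlashAttention does not support directly.

# Hypothetical illustration; the diff itself simply hard-codes 256 for
# model_type == 'deepseek_v2'.
FLASH_ATTN_HEAD_SIZES = (32, 64, 128, 256)

def pad_head_size(head_size: int) -> int:
    """Round a head size up to the smallest size FlashAttention supports."""
    for supported in FLASH_ATTN_HEAD_SIZES:
        if head_size <= supported:
            return supported
    raise ValueError(f"head_size {head_size} is too large for FlashAttention")

assert pad_head_size(128 + 64) == 256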
vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,9 +1,10 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name)
+    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
 
 __all__ = [
     "fused_moe",
     "fused_topk",
     "fused_experts",
     "get_config_file_name",
+    "grouped_topk",
 ]
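
With the re-export in place, callers can pull the new helper from the package root:

from vllm.model_executor.layers.fused_moe import grouped_topk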

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 31 additions & 0 deletions
@@ -367,6 +367,37 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = scores.view(num_token, num_expert_group,
+                               -1).max(dim=-1).values  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
+                           sorted=False)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = group_mask.unsqueeze(-1).expand(
+        num_token, num_expert_group,
+        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores,
+                                        k=topk,
+                                        dim=-1,
+                                        sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
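
A hedged usage sketch of grouped_topk (shapes and values invented; assumes the import shown earlier): experts are partitioned into num_expert_group groups, only the topk_group best-scoring groups survive, and the final top-k experts are drawn from those groups.

import torch

from vllm.model_executor.layers.fused_moe import grouped_topk

# Invented example: 4 tokens, 16 experts split into 4 groups of 4.
hidden_states = torch.randn(4, 32)   # passed in but unused by the scoring
gating_output = torch.randn(4, 16)

topk_weights, topk_ids = grouped_topk(hidden_states,
                                      gating_output,
                                      topk=2,
                                      renormalize=True,
                                      num_expert_group=4,
                                      topk_group=2)

print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) for both
# Each selected expert id lies inside one of the 2 surviving groups, and
# renormalize=True makes every token's weights sum to 1.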

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 126 additions & 0 deletions
@@ -610,6 +610,119 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
 class GemmaRotaryEmbedding(RotaryEmbedding):
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
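
A small worked example (values invented, not DeepSeek-V2's shipped settings) of the effective magnitude scale computed in __init__ above:

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Hypothetical 40x context extension with mscale = mscale_all_dim = 0.707.
scale, mscale, mscale_all_dim, attn_factor = 40.0, 0.707, 0.707, 1.0

effective_mscale = (yarn_get_mscale(scale, mscale) /
                    yarn_get_mscale(scale, mscale_all_dim) * attn_factor)

print(round(yarn_get_mscale(scale, mscale), 3))  # 1.261
print(effective_mscale)  # 1.0 -- the ratio cancels when the two match,
                         # so the cos/sin cache is scaled by attn_factor only.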
@@ -679,6 +792,19 @@ def get_rope(
                                                     base, is_neox_style,
                                                     scaling_factor, dtype,
                                                     **extra_kwargs)
+        elif scaling_type == "deepseek_yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            # assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow", "mscale", "mscale_all_dim")
+            }
+            rotary_emb = DeepseekScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, dtype, **extra_kwargs)
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
         elif scaling_type == "su" or scaling_type == "longrope":

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
+    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),

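With the architecture registered, a hedged end-to-end check (the checkpoint name and flags are common usage, not taken from this diff; assumes a GPU with enough memory for the chosen variant):

from vllm import LLM, SamplingParams

# "deepseek-ai/DeepSeek-V2-Lite" is the smaller published checkpoint; swap in
# the full model if the hardware allows. trust_remote_code is needed for the
# custom DeepSeek-V2 code on the Hub.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)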