
Commit f7f0c2a

zyongye, heheda12345, youkaichao, LucasWilkinson, and robertgshaw2-redhat authored and committed
[New Model] DeepSeek-V3.2 (Rebased to Main) (vllm-project#25896)
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: Barry Kang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Co-authored-by: Siyuan Fu <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
Co-authored-by: Xiaozhu Meng <[email protected]>
Co-authored-by: Barry Kang <[email protected]>
1 parent c35ae07 commit f7f0c2a

File tree

4 files changed: +95 -147 lines changed

vllm/attention/ops/flashmla.py
vllm/model_executor/models/config.py
vllm/model_executor/models/deepseek_v2.py
vllm/v1/attention/backends/mla/indexer.py


vllm/attention/ops/flashmla.py

Lines changed: 1 addition & 1 deletion

@@ -136,7 +136,7 @@ def flash_mla_with_kvcache(
         descale_k is None
     ), "descale_q and descale_k should be both None or both not None"

-    if indices is None and q.element_size() == 1:
+    if (descale_q is not None) and (descale_k is not None):
         out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
             q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
             causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
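
The one-line change above switches the fp8 dispatch from sniffing the query dtype via q.element_size() == 1 to checking the descale tensors directly, which lines up with the assertion that descale_q and descale_k must be provided together. A minimal standalone sketch of that guard, assuming a hypothetical select_mla_path helper (the non-fp8 return value is a placeholder; only fwd_kvcache_mla_fp8 appears in the diff):

from typing import Optional

import torch


def select_mla_path(descale_q: Optional[torch.Tensor],
                    descale_k: Optional[torch.Tensor]) -> str:
    # Invariant asserted in flash_mla_with_kvcache: both descales or neither.
    assert (descale_q is None) == (descale_k is None), \
        "descale_q and descale_k should be both None or both not None"
    # New condition in this commit: take the fp8 kv-cache kernel only when the
    # descale factors are supplied, instead of checking the query element size.
    if (descale_q is not None) and (descale_k is not None):
        return "fwd_kvcache_mla_fp8"
    return "default_mla_path"  # placeholder for the non-fp8 kernel


scale = torch.ones(1, dtype=torch.float32)
print(select_mla_path(None, None))    # default_mla_path
print(select_mla_path(scale, scale))  # fwd_kvcache_mla_fp8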

vllm/model_executor/models/config.py

Lines changed: 14 additions & 14 deletions

@@ -400,7 +400,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 "exactly equal.", mamba_padding_pct)


-class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
+class DeepseekV3ForCausalLM(VerifyAndUpdateConfig):

     @classmethod
     def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
@@ -409,20 +409,20 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         """
         hf_config = vllm_config.model_config.hf_config

-        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
         is_v32 = hasattr(hf_config, "index_topk")
-        assert is_v32

-        # For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
-        # "auto")
-        cache_config = vllm_config.cache_config
-        if cache_config.cache_dtype == "auto" or \
-            cache_config.cache_dtype.startswith("fp8"):
-            cache_config.cache_dtype = "fp8_ds_mla"
-            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
-        if cache_config.cache_dtype == "bfloat16":
-            cache_config.cache_dtype = "auto"
-            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
+        if is_v32:
+            # For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
+            # "auto")
+            cache_config = vllm_config.cache_config
+            if cache_config.cache_dtype == "auto" or \
+                cache_config.cache_dtype.startswith("fp8"):
+                cache_config.cache_dtype = "fp8_ds_mla"
+                logger.info(
+                    "Using custom fp8 kv-cache format for DeepSeekV3.2")
+            if cache_config.cache_dtype == "bfloat16":
+                cache_config.cache_dtype = "auto"
+                logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
@@ -441,5 +441,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
     "FalconMambaForCausalLM": MambaModelConfig,
-    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
+    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
 }
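
For context, the kv-cache dtype handling above reduces to a small mapping: when the checkpoint is DeepSeek-V3.2 (detected via the index_topk attribute), both "auto" and any fp8 variant resolve to the custom "fp8_ds_mla" layout, while "bfloat16" is normalized back to "auto"; other models are untouched. A minimal sketch of that mapping, assuming a hypothetical resolve_kv_cache_dtype helper (the plain-string interface is illustrative, not vLLM's API):

def resolve_kv_cache_dtype(cache_dtype: str, is_v32: bool) -> str:
    # Illustrative restatement of the rules in DeepseekV3ForCausalLM above.
    if not is_v32:
        return cache_dtype  # non-V3.2 checkpoints keep whatever was requested
    if cache_dtype == "auto" or cache_dtype.startswith("fp8"):
        return "fp8_ds_mla"  # custom fp8 MLA layout is the V3.2 default
    if cache_dtype == "bfloat16":
        return "auto"  # bfloat16 kv-cache falls back to the generic path
    return cache_dtype


assert resolve_kv_cache_dtype("auto", is_v32=True) == "fp8_ds_mla"
assert resolve_kv_cache_dtype("fp8_e4m3", is_v32=True) == "fp8_ds_mla"
assert resolve_kv_cache_dtype("bfloat16", is_v32=True) == "auto"
assert resolve_kv_cache_dtype("auto", is_v32=False) == "auto"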

vllm/model_executor/models/deepseek_v2.py

Lines changed: 38 additions & 41 deletions

@@ -64,17 +64,13 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.utils import sequence_parallel_chunk
-<<<<<<< HEAD
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
 from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
                                                     DeepseekV32IndexerMetadata)
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
-=======
-from vllm.sequence import IntermediateTensors
->>>>>>> a5354b3ed ([Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982))

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -587,43 +583,44 @@ def sparse_attn_indexer(
     topk_indices_buffer[:hidden_states.shape[0]] = -1
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
-        for chunk in prefill_metadata.chunks:
-            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
-                                device=k.device,
-                                dtype=torch.float8_e4m3fn)
-            k_scale = torch.empty([chunk.total_seq_lens, 1],
-                                  device=k.device,
-                                  dtype=torch.float32)
-            cp_gather_indexer_k_quant_cache(
-                kv_cache,
-                k_fp8,
-                k_scale,
-                chunk.block_table,
-                chunk.cu_seq_lens,
-                chunk.num_reqs,
-            )
-            logits = fp8_mqa_logits(
-                q_fp8[chunk.token_start:chunk.token_end],
-                (k_fp8, k_scale),
-                weights[chunk.token_start:chunk.token_end],
-                chunk.cu_seqlen_ks,
-                chunk.cu_seqlen_ke,
-            )
-            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
-                                       dim=-1)[1]
-            topk_indices -= chunk.cu_seqlen_ks[:, None]
-            mask_lo = topk_indices >= 0
-            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
-                                      chunk.cu_seqlen_ks)[:, None] < 0
-            mask = torch.full_like(topk_indices,
-                                   False,
-                                   dtype=torch.bool,
-                                   device=topk_indices.device)
-            mask = mask_lo & mask_hi
-            topk_indices = topk_indices.masked_fill(~mask, -1)
-            topk_indices_buffer[
-                chunk.token_start:chunk.token_end, :topk_indices.
-                shape[-1]] = topk_indices.to(dtype=torch.int32)
+        num_prefills = attn_metadata.num_prefills
+        k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
+                            device=k.device,
+                            dtype=torch.float8_e4m3fn)
+        k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
+                              device=k.device,
+                              dtype=torch.float32)
+        cp_gather_indexer_k_quant_cache(
+            kv_cache,
+            k_fp8,
+            k_scale,
+            prefill_metadata.block_table,
+            prefill_metadata.cu_seq_lens,
+            num_prefills,
+        )
+        cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
+        cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
+        num_tokens = attn_metadata.num_actual_tokens
+        logits = fp8_mqa_logits(
+            q_fp8[num_decode_tokens:num_tokens],
+            (k_fp8, k_scale),
+            weights[num_decode_tokens:num_tokens],
+            cu_seqlen_ks,
+            cu_seqlen_ke,
+        )
+        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+                                   dim=-1)[1]
+        topk_indices -= cu_seqlen_ks[:, None]
+        mask_lo = topk_indices >= 0
+        mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
+        mask = torch.full_like(topk_indices,
+                               False,
+                               dtype=torch.bool,
+                               device=topk_indices.device)
+        mask = mask_lo & mask_hi
+        topk_indices = topk_indices.masked_fill(~mask, -1)
+        topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
+                            shape[-1]] = topk_indices.to(dtype=torch.int32)

     if has_decode:
         decode_metadata = attn_metadata.decode
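
The rewritten prefill path runs the indexer once over all prefill tokens instead of looping over chunks, so the top-k indices come back in the coordinates of the flattened KV buffer and must be shifted into per-token local coordinates, with out-of-span entries masked to -1. A standalone sketch of that masking step with illustrative values (the tensor contents are made up; the operations mirror the diff):

import torch

# Two prefill query tokens, global top-k of 4. cu_seqlen_ks/ke hold each
# token's [start, end) span inside the flattened KV buffer.
topk_indices = torch.tensor([[0, 2, 5, 6],
                             [4, 5, 9, 11]])  # global KV positions
cu_seqlen_ks = torch.tensor([0, 4])           # span start per token
cu_seqlen_ke = torch.tensor([3, 10])          # span end per token (exclusive)

# Shift into local coordinates, then drop anything outside [0, ke - ks).
topk_indices = topk_indices - cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)

print(topk_indices)
# tensor([[ 0,  2, -1, -1],
#         [ 0,  1,  5, -1]])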

vllm/v1/attention/backends/mla/indexer.py

Lines changed: 42 additions & 91 deletions

@@ -49,20 +49,14 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:


 @dataclass
-class DeepseekV32IndexerPrefillChunkMetadata:
+class DeepseekV32IndexerPrefillMetadata:
     block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
-    token_start: int
-    token_end: int
-    num_reqs: int
-
-
-@dataclass
-class DeepseekV32IndexerPrefillMetadata:
-    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]


 @dataclass
@@ -104,8 +98,8 @@ class DeepseekV32IndexerMetadata:

 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
-        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor,
+        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -128,14 +122,15 @@ def kv_spans_from_batches(
       are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long)
+    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
+    device = L.device

     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -145,7 +140,8 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
+                                       counts)  # [N]

     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -155,56 +151,27 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, dtype=torch.long) -
+    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)

     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end

-    return start_tensor.int().to(device), end_location.int().to(device)
+    return start_tensor.int(), end_location.int()


 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
-    # May be tuned later.
-    return max_model_len * 2
-
-
-def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
-                         max_prefill_buffer_size: int,
-                         reqs_start: int) -> list[tuple[int, int]]:
-    """
-    Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
-    such that the total sequence length of each chunk is less than the
-    maximum prefill buffer size.
-
-    Args:
-        seq_lens_cpu: The sequence lengths of the prefill requests.
-        max_prefill_buffer_size: The maximum prefill buffer size.
-        reqs_start: The start index of the prefill requests.
-
-    Returns:
-        A list of tuples of (reqs_start, reqs_end).
-    """
-    chunk_seq_ids = []
-    total_seq_lens = 0
-    for i in range(reqs_start, len(seq_lens_cpu)):
-        cur_seq_len = seq_lens_cpu[i].item()
-        assert cur_seq_len <= max_prefill_buffer_size
-        total_seq_lens += cur_seq_len
-        if total_seq_lens > max_prefill_buffer_size:
-            chunk_seq_ids.append((reqs_start, i))
-            reqs_start = i
-            total_seq_lens = cur_seq_len
-    if total_seq_lens > 0:
-        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
-    return chunk_seq_ids
+    # max_num_batched_tokens = \
+    #     vllm_config.scheduler_config.max_num_batched_tokens
+    max_num_seq = vllm_config.scheduler_config.max_num_seqs
+    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
+    return max_model_len * max_num_seq


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+        AttentionCGSupport.UNIFORM_BATCH

     reorder_batch_threshold: int = 1

@@ -234,33 +201,6 @@ def __init__(self, *args, **kwargs):
                                          dtype=torch.int32,
                                          device=self.device)

-    def build_one_prefill_chunk(self, reqs_start, reqs_end,
-                                query_start_loc_cpu, seq_lens_cpu,
-                                block_table):
-        prefill_query_start_loc = query_start_loc_cpu[
-            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
-        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
-            self.device)
-        token_start = query_start_loc_cpu[reqs_start].item()
-        token_end = query_start_loc_cpu[reqs_end].item()
-        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
-        assert total_seq_lens <= self.max_prefill_buffer_size
-        cu_seq_lens = torch.cat([
-            torch.zeros(1, dtype=torch.int32),
-            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
-        ]).to(torch.int32).to(self.device)
-        return DeepseekV32IndexerPrefillChunkMetadata(
-            cu_seqlen_ks=cu_seqlen_ks,
-            cu_seqlen_ke=cu_seqlen_ke,
-            cu_seq_lens=cu_seq_lens,
-            total_seq_lens=total_seq_lens,
-            block_table=block_table[reqs_start:reqs_end],
-            token_start=token_start,
-            token_end=token_end,
-            num_reqs=reqs_end - reqs_start,
-        )
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -269,7 +209,11 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens

-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        device = self.device
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
+        query_start_loc = common_attn_metadata.query_start_loc
+
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -280,20 +224,27 @@ def build(self,

         prefill_metadata = None
         if num_prefills > 0:
-            chunk_seq_ids = split_prefill_chunks(
-                common_attn_metadata.seq_lens_cpu,
-                self.max_prefill_buffer_size,
-                num_decodes,
-            )
-            chunks = [
-                self.build_one_prefill_chunk(
-                    reqs_start, reqs_end, query_start_loc_cpu,
-                    common_attn_metadata.seq_lens_cpu,
-                    common_attn_metadata.block_table_tensor)
-                for reqs_start, reqs_end in chunk_seq_ids
-            ]
+            reqs_start = num_decodes
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
+            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+                prefill_query_start_loc,
+                common_attn_metadata.seq_lens[reqs_start:])
+            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
+            assert total_seq_lens < self.max_prefill_buffer_size
+            cu_seq_lens = torch.cat([
+                torch.zeros(1, dtype=torch.int32, device=device),
+                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
+            ]).to(torch.int32).cuda()
             prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                chunks=chunks, )
+                block_table=block_table_tensor[reqs_start:, ...],
+                query_start_loc=prefill_query_start_loc,
+                max_query_len=common_attn_metadata.max_query_len,
+                cu_seqlen_ks=cu_seqlen_ks,
+                cu_seqlen_ke=cu_seqlen_ke,
+                cu_seq_lens=cu_seq_lens,
+                total_seq_lens=total_seq_lens,
+            )

         decode_metadata = None
         if num_decodes > 0:
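
With chunking gone, kv_spans_from_batches is called once over the whole prefill slice; it maps every selected query token to a [start, end) span inside the flattened KV buffer. The following is a standalone re-derivation of that span arithmetic with a small worked example (the helper name kv_spans and the input values are illustrative; the math follows the function body shown above):

import torch


def kv_spans(start_seq_loc: torch.Tensor,
             seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    q = start_seq_loc.to(torch.long)
    L = seq_len_per_batch.to(torch.long)
    counts = q[1:] - q[:-1]                 # selected tokens per batch [B]
    N = int(q[-1])                          # total selected tokens
    B = L.numel()
    kv_starts = torch.cumsum(L, dim=0) - L  # KV start offset of each batch
    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
    start = kv_starts[batch_id]
    # The i-th selected token of a batch is the i-th of its *last* counts[b]
    # positions, so its exclusive end is start + (L - counts) + i (1-based i).
    pos_within = torch.arange(N) - torch.repeat_interleave(q[:-1], counts) + 1
    local_pos = torch.repeat_interleave(L - counts, counts) + pos_within
    return start.int(), (start + local_pos).int()


# Two prefill requests with 2 and 3 query tokens and KV lengths 4 and 7.
ks, ke = kv_spans(torch.tensor([0, 2, 5]), torch.tensor([4, 7]))
print(ks.tolist())  # [0, 0, 4, 4, 4]
print(ke.tolist())  # [3, 4, 9, 10, 11]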

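One more detail from the build() rewrite: decode requests are reordered to the front of the batch, so the prefill-only query_start_loc is obtained by slicing from num_decodes and re-basing to zero. A tiny sketch with illustrative numbers:

import torch

# Layout after reordering: 2 decode requests (1 token each) followed by
# 2 prefill requests with 3 and 5 query tokens.
query_start_loc = torch.tensor([0, 1, 2, 5, 10])
num_decodes = 2

reqs_start = num_decodes
prefill_query_start_loc = (query_start_loc[reqs_start:] -
                           query_start_loc[reqs_start])
print(prefill_query_start_loc.tolist())  # [0, 3, 8]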