@@ -49,20 +49,14 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
 
 
 @dataclass
-class DeepseekV32IndexerPrefillChunkMetadata:
+class DeepseekV32IndexerPrefillMetadata:
     block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
-    token_start: int
-    token_end: int
-    num_reqs: int
-
-
-@dataclass
-class DeepseekV32IndexerPrefillMetadata:
-    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
 
 
 @dataclass
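
The chunk wrapper disappears here: prefill metadata is now a single flat dataclass covering every prefill request at once, since the resized prefill buffer (see get_max_prefill_buffer_size below) makes chunking unnecessary. A hedged sketch of an instance a consumer might receive; the values are fabricated for illustration and only the field names come from this diff:

    import torch
    # Fabricated example: 2 prefill requests with full KV lengths [4, 7]
    # contributing [2, 3] new query tokens respectively.
    meta = DeepseekV32IndexerPrefillMetadata(
        block_table=torch.zeros(2, 8, dtype=torch.int32),  # [num_prefills, max_blocks]
        query_start_loc=torch.tensor([0, 2, 5]),           # [num_prefills + 1]
        max_query_len=3,
        cu_seqlen_ks=torch.tensor([0, 0, 4, 4, 4]),        # per-token KV window start
        cu_seqlen_ke=torch.tensor([3, 4, 9, 10, 11]),      # per-token KV window end (exclusive)
        cu_seq_lens=torch.tensor([0, 4, 11]),              # cumulative KV lengths
        total_seq_lens=11,
    )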
@@ -104,8 +98,8 @@ class DeepseekV32IndexerMetadata:
 
 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
-        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor,
+        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -128,14 +122,15 @@ def kv_spans_from_batches(
         are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long)
+    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
 
     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
+    device = L.device
 
     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
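
With the device parameter gone from the signature, spans are now computed on whatever device the query tensor lives on: L is moved onto q.device up front, and device = L.device anchors every allocation below. The call site (shown in the build() hunk further down) accordingly shrinks to:

    # New call shape; names match the build() hunk below.
    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
        prefill_query_start_loc,                     # [num_prefills + 1]
        common_attn_metadata.seq_lens[reqs_start:])  # [num_prefills]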
@@ -145,7 +140,8 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]
 
     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
+                                       counts)  # [N]
 
     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -155,56 +151,27 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, dtype=torch.long) -
+    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)
 
     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end
 
-    return start_tensor.int().to(device), end_location.int().to(device)
+    return start_tensor.int(), end_location.int()
 
 
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
-    # May be tuned later.
-    return max_model_len * 2
-
-
-def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
-                         max_prefill_buffer_size: int,
-                         reqs_start: int) -> list[tuple[int, int]]:
-    """
-    Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
-    such that the total sequence length of each chunk is less than the
-    maximum prefill buffer size.
-
-    Args:
-        seq_lens_cpu: The sequence lengths of the prefill requests.
-        max_prefill_buffer_size: The maximum prefill buffer size.
-        reqs_start: The start index of the prefill requests.
-
-    Returns:
-        A list of tuples of (reqs_start, reqs_end).
-    """
-    chunk_seq_ids = []
-    total_seq_lens = 0
-    for i in range(reqs_start, len(seq_lens_cpu)):
-        cur_seq_len = seq_lens_cpu[i].item()
-        assert cur_seq_len <= max_prefill_buffer_size
-        total_seq_lens += cur_seq_len
-        if total_seq_lens > max_prefill_buffer_size:
-            chunk_seq_ids.append((reqs_start, i))
-            reqs_start = i
-            total_seq_lens = cur_seq_len
-    if total_seq_lens > 0:
-        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
-    return chunk_seq_ids
+    # max_num_batched_tokens = \
+    #     vllm_config.scheduler_config.max_num_batched_tokens
+    max_num_seq = vllm_config.scheduler_config.max_num_seqs
+    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
+    return max_model_len * max_num_seq
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+        AttentionCGSupport.UNIFORM_BATCH
 
     reorder_batch_threshold: int = 1
 
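
A worked example of kv_spans_from_batches as it now stands (fabricated numbers; the function is the one above). Two batches with full KV lengths [4, 7] contribute 2 and 3 selected query tokens; each token's window ends at its own causal position, with the selected tokens occupying the last counts[i] positions of each sequence:

    import torch
    q = torch.tensor([0, 2, 5])  # start_seq_loc -> counts = [2, 3], N = 5
    L = torch.tensor([4, 7])     # full KV length per batch
    ks, ke = kv_spans_from_batches(q, L)
    # Batch 0's KV occupies flattened positions [0, 4); batch 1's [4, 11).
    # ks = [0, 0, 4, 4, 4]
    # ke = [3, 4, 9, 10, 11]   (exclusive ends)

Note also that the buffer bound changes from the magic 2 * max_model_len to max_model_len * max_num_seqs, a worst-case estimate of the flattened KV size; this is what lets build() below drop chunking entirely.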
@@ -234,33 +201,6 @@ def __init__(self, *args, **kwargs):
                                               dtype=torch.int32,
                                               device=self.device)
 
-    def build_one_prefill_chunk(self, reqs_start, reqs_end,
-                                query_start_loc_cpu, seq_lens_cpu,
-                                block_table):
-        prefill_query_start_loc = query_start_loc_cpu[
-            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
-        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
-            self.device)
-        token_start = query_start_loc_cpu[reqs_start].item()
-        token_end = query_start_loc_cpu[reqs_end].item()
-        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
-        assert total_seq_lens <= self.max_prefill_buffer_size
-        cu_seq_lens = torch.cat([
-            torch.zeros(1, dtype=torch.int32),
-            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
-        ]).to(torch.int32).to(self.device)
-        return DeepseekV32IndexerPrefillChunkMetadata(
-            cu_seqlen_ks=cu_seqlen_ks,
-            cu_seqlen_ke=cu_seqlen_ke,
-            cu_seq_lens=cu_seq_lens,
-            total_seq_lens=total_seq_lens,
-            block_table=block_table[reqs_start:reqs_end],
-            token_start=token_start,
-            token_end=token_end,
-            num_reqs=reqs_end - reqs_start,
-        )
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -269,7 +209,11 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
 
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        device = self.device
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
+        query_start_loc = common_attn_metadata.query_start_loc
+
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
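
For orientation (values fabricated, semantics as used here): with the batch reordered decode-first and reorder_batch_threshold = 1, split_decodes_and_prefills partitions requests by query length, so everything from index num_decodes onward is a prefill:

    # per-request query lens [1, 1, 4, 5] (decodes first)
    # -> num_decodes=2, num_prefills=2,
    #    num_decode_tokens=2, num_prefill_tokens=9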
@@ -280,20 +224,27 @@ def build(self,
 
         prefill_metadata = None
         if num_prefills > 0:
-            chunk_seq_ids = split_prefill_chunks(
-                common_attn_metadata.seq_lens_cpu,
-                self.max_prefill_buffer_size,
-                num_decodes,
-            )
-            chunks = [
-                self.build_one_prefill_chunk(
-                    reqs_start, reqs_end, query_start_loc_cpu,
-                    common_attn_metadata.seq_lens_cpu,
-                    common_attn_metadata.block_table_tensor)
-                for reqs_start, reqs_end in chunk_seq_ids
-            ]
+            reqs_start = num_decodes
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
+            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+                prefill_query_start_loc,
+                common_attn_metadata.seq_lens[reqs_start:])
+            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
+            assert total_seq_lens < self.max_prefill_buffer_size
+            cu_seq_lens = torch.cat([
+                torch.zeros(1, dtype=torch.int32, device=device),
+                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
+            ]).to(torch.int32).cuda()
             prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                chunks=chunks, )
+                block_table=block_table_tensor[reqs_start:, ...],
+                query_start_loc=prefill_query_start_loc,
+                max_query_len=common_attn_metadata.max_query_len,
+                cu_seqlen_ks=cu_seqlen_ks,
+                cu_seqlen_ke=cu_seqlen_ke,
+                cu_seq_lens=cu_seq_lens,
+                total_seq_lens=total_seq_lens,
+            )
 
         decode_metadata = None
         if num_decodes > 0:
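
A short walk-through of the new single-pass prefill branch (fabricated numbers, continuing the example above): two decodes of one token each followed by prefills of 4 and 5 query tokens.

    import torch
    query_start_loc = torch.tensor([0, 1, 2, 6, 11])
    reqs_start = 2  # num_decodes
    prefill_query_start_loc = (query_start_loc[reqs_start:] -
                               query_start_loc[reqs_start])
    # -> tensor([0, 4, 9]): per-prefill token counts [4, 5].
    # kv_spans_from_batches then yields one (ks, ke) KV window per prefill
    # token, and total_seq_lens is asserted to fit the buffer sized as
    # max_model_len * max_num_seqs.

With chunking gone, block_table and query_start_loc in the metadata now describe the entire prefill tail in one piece instead of per-chunk slices.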