 
 def compute_meta(
     token_lora_tensor: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
     """
     Get the information required for the sgmv kernel. With the features:
     1. If consecutive requests in the batch use the same LoRA, this function
@@ -43,7 +43,7 @@ def compute_meta(
     b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
     b_seq_start_tensor[1:].copy_(cum_result[:-1])
     max_length = seq_length_tensor.max().item()
-
+    token_nums = seq_length_tensor.sum().item()
     batch_size = lora_indices_tensor.size(0)
     no_lora = False
     # -1 means no lora should be applied. Use `no_lora` to determine whether
@@ -52,7 +52,7 @@ def compute_meta(
     if batch_size == 1 and lora_indices_tensor == -1:
         no_lora = True
     return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-            batch_size, max_length, no_lora)
+            batch_size, max_length, token_nums, no_lora)
 
 
 # TODO see if this can be vectorized
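For context, here is a minimal standalone sketch of what the new `token_nums` value represents. It assumes the grouping step in `compute_meta` is done with `torch.unique_consecutive`, which matches the clustering behavior the docstring describes; the toy `token_lora_tensor` is an invented example, not taken from this commit:

```python
import torch

# Toy per-token LoRA indices: three tokens routed to LoRA 0, two to LoRA 1
# (invented example input; compute_meta receives one index per token).
token_lora_tensor = torch.tensor([0, 0, 0, 1, 1])

# Collapse runs of identical indices into one "request" each, mirroring
# the clustering described in compute_meta's docstring (assumed helper).
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
    token_lora_tensor, return_counts=True)

# Exclusive prefix sum gives each clustered request's start offset.
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])

max_length = seq_length_tensor.max().item()  # 3: longest clustered request
token_nums = seq_length_tensor.sum().item()  # 5: the new value, total tokens
batch_size = lora_indices_tensor.size(0)     # 2: requests after clustering

print(b_seq_start_tensor.tolist())         # [0, 3]
print(seq_length_tensor.tolist())           # [3, 2]
print(max_length, token_nums, batch_size)   # 3 5 2
```

Carrying the total token count in the metadata lets callers size kernel work from a scalar instead of re-reducing `seq_length_tensor` at each call site.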
@@ -178,7 +178,7 @@ def convert_mapping(
 class PunicaWrapper:
     """
     PunicaWrapper is designed to manage and provide metadata for the punica
-    kernel. The main function is to maintain the state information for 
+    kernel. The main function is to maintain the state information for
     Multi-LoRA, and to provide the interface for the punica kernel.
     """
 
@@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                                                  dtype=torch.long,
                                                  device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,13 +277,13 @@ def _update_base_metadata(
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
-
         self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
 
         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)
 
         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)
@@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora
 
     @property
     def prefill_metadata(
-            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-        1. seq_start_locs: Tensor of sequence start positions
-        2. seq_lengths: Tensor of sequence lengths
+        1. seq_start_locs: Tensor of sequence start positions.
+        2. seq_lengths: Tensor of sequence lengths.
         3. lora_indices_per_batch: Tensor of lora indices, and an index of
            -1 means no lora should be applied.
-        4. batch_size: batch size after clustering identical lora indices
-        5. max_length: The maximum sequence length in the batch
+        4. batch_size: Batch size after clustering identical lora indices.
+        5. max_length: The maximum sequence length in the batch.
+        6. token_nums: The token numbers in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)
 
     @property
     def token_lora_indices(self) -> torch.Tensor:
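Downstream callers now unpack six values from `prefill_metadata` instead of five. A hedged sketch of such a consumer follows; `describe_prefill_metadata` is a hypothetical helper name for illustration, and only the tuple layout comes from the diff above:

```python
from typing import Tuple

import torch


def describe_prefill_metadata(
    meta: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]
) -> None:
    """Unpack the six-element tuple returned by PunicaWrapper.prefill_metadata."""
    (seq_start_locs, seq_lengths, lora_indices_per_batch,
     batch_size, max_length, token_nums) = meta
    print(f"batch_size={batch_size}, max_length={max_length}, "
          f"token_nums={token_nums}")
    print("seq starts:", seq_start_locs.tolist())
    print("seq lengths:", seq_lengths.tolist())
    print("lora ids:", lora_indices_per_batch.tolist())


# Usage, assuming `wrapper` is an initialized PunicaWrapper in prefill mode:
# describe_prefill_metadata(wrapper.prefill_metadata)
```

Since the tuple arity changed, any call site that still unpacks five values will raise a `ValueError` and needs updating alongside this change.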
@@ -324,15 +328,15 @@ def token_lora_indices(self) -> torch.Tensor:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
 
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
         """
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor:
     def long_lora_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
         """
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
@@ -524,7 +528,7 @@ def add_lora(self,
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
-            y_slice_size (Optional[int], optional): Size of the y column slice..
+            y_slice_size (Optional[int], optional): Size of the y column slice.
             buffer (Optional[torch.Tensor], optional): Defaults to None.
         """
         y_org = y