 from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -172,6 +173,8 @@ class OpType(Enum):
     SGMV_EXPAND = auto()
     BGMV_EXPAND = auto()
     BGMV_EXPAND_SLICE = auto()
+    V1_SHRINK = auto()
+    V1_EXPAND = auto()

     @staticmethod
     def from_str(s: str) -> "OpType":
@@ -185,28 +188,43 @@ def from_str(s: str) -> "OpType":
             return OpType.BGMV_EXPAND
         if s.lower() == "bgmv_expand_slice":
             return OpType.BGMV_EXPAND_SLICE
+        if s.lower() == "v1_shrink":
+            return OpType.V1_SHRINK
+        if s.lower() == "v1_expand":
+            return OpType.V1_EXPAND
         raise ValueError(f"Unrecognized str {s} to convert to OpType")

     def is_shrink_fn(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
+        ]

     def is_expand_fn(self) -> bool:
-        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
+        return self in [
+            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
+        ]

     def is_prefill_op(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
+            OpType.V1_EXPAND
+        ]

     def is_decode_op(self) -> bool:
         return self in [
-            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
+            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
+            OpType.V1_SHRINK, OpType.V1_EXPAND
         ]

     def is_expand_slice_fn(self) -> bool:
         return self in [OpType.BGMV_EXPAND_SLICE]

     def num_slices(self) -> List[int]:
-        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
-            # SGMV kernels supports slices
+        if self in [
+                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
+                OpType.V1_EXPAND
+        ]:
+            # SGMV and V1 kernels support slices
             return [1, 2, 3]
         if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
             return [1]
@@ -251,11 +269,13 @@ def matmul_shapes(
         m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

         b_shape = (num_loras, n, k)  # col-major
-        if self == OpType.SGMV_SHRINK:
-            # SGMV shrink supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
+            # SGMV shrink and V1 shrink kernels support num_slices inherently
+            # in the kernel.
             return ((m, k), b_shape, (num_slices, m, n))
-        if self == OpType.SGMV_EXPAND:
-            # SGMV expand supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
+            # SGMV expand and V1 expand kernels support num_slices inherently
+            # in the kernel.
             return ((num_slices, m, k), b_shape, (m, n * num_slices))
         if self == OpType.BGMV_SHRINK:
             return ((m, k), b_shape, (m, n))
@@ -282,25 +302,30 @@ def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
             return bgmv_expand
         if self == OpType.BGMV_EXPAND_SLICE:
             return emulate_bgmv_expand_slice
+        if self == OpType.V1_SHRINK:
+            return v1_shrink
+        if self == OpType.V1_EXPAND:
+            return v1_expand
+
         raise ValueError(f"Unrecognized optype {self}")

     def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                            lora_weights: List[torch.Tensor],
                            **kwargs) -> Callable:
-        """Each benchmark operation expected the input, lora_weights and outputs
+        """Each benchmark operation expects the input, lora_weights and outputs
         in a slightly different format. Refer to self.matmul_shapes().
         run_ref_group_gemm accounts for those differences in executing a
         reference group gemm for correctness testing.
         """
         w_dtype = lora_weights[0].dtype
         num_slices = len(lora_weights)
-        if self == OpType.SGMV_SHRINK:
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
             for slice_idx in range(num_slices):
                 ref_group_gemm(ref_out=output[slice_idx, :],
                                input=input,
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        if self == OpType.SGMV_EXPAND:
+        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -309,19 +334,19 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        if self == OpType.BGMV_SHRINK:
+        elif self == OpType.BGMV_SHRINK:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input,
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND:
+        elif self == OpType.BGMV_EXPAND:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input.clone().to(dtype=w_dtype),
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND_SLICE:
+        elif self == OpType.BGMV_EXPAND_SLICE:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -330,7 +355,8 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        raise ValueError(f"Unrecognized optype {self}")
+        else:
+            raise ValueError(f"Unrecognized optype {self}")


 @dataclass
@@ -391,6 +417,8 @@ class BenchmarkTensors:
     seq_start_loc: torch.Tensor
     prompt_lora_mapping: torch.Tensor
     token_lora_mapping: torch.Tensor
+    # v1 kernel metadata
+    v1_kernel_meta: Optional[V1KernelMeta] = None

     def io_types(self) -> str:
         return (f"{dtype_to_str(self.input.dtype)}x"
@@ -433,10 +461,19 @@ def make(ctx: BenchmarkContext,
             total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
             seq_len_tensor, "cpu")

+        v1_kernel_meta = None
+        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
+            v1_kernel_meta = V1KernelMeta.make(
+                max_loras=ctx.num_loras,
+                max_num_tokens=token_lora_indices_tensor.size(0),
+                device="cpu")
+            v1_kernel_meta.prepare_tensors(
+                token_lora_mapping=token_lora_indices_tensor)
+
         return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                 seq_len_tensor, seq_start_loc_tensor,
                                 prompt_lora_indices_tensor,
-                                token_lora_indices_tensor)
+                                token_lora_indices_tensor, v1_kernel_meta)

     def sanity_check(self) -> None:
         """
@@ -469,6 +506,13 @@ def to_device(tensor: torch.Tensor):
         for i in range(len(self.lora_weights_lst)):
             self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

+        # v1 meta
+        if self.v1_kernel_meta:
+            for field_name in V1KernelMeta.__dataclass_fields__:
+                field = getattr(self.v1_kernel_meta, field_name)
+                assert isinstance(field, torch.Tensor)
+                setattr(self.v1_kernel_meta, field_name, to_device(field))
+
     def metadata(self) -> Tuple[int, int, int]:
         """
         Return num_seqs, num_tokens and max_seq_len
@@ -668,6 +712,78 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
             })
         return {'kwargs_list': kwargs_list}

+    def as_v1_shrink_kwargs(self) -> Dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape: [num_tokens, hidden_size]
+        assert len(i_shape) == 2
+        assert i_shape[0] == num_tokens
+        hidden_size = i_shape[1]
+        # Expected lora weight shape: [num_loras, lora_rank, hidden_size]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == hidden_size
+        lora_rank = lw_shape[1]
+        # Expected output shape: [num_slices, num_tokens, lora_rank]
+        assert len(o_shape) == 3
+        assert o_shape == (num_slices, num_tokens, lora_rank)
+
+        return {
+            'inputs': self.input,
+            'lora_a_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'scaling': 1.0,
+        }
+
+    def as_v1_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape: [num_slices, num_tokens, lora_rank]
+        assert len(i_shape) == 3
+        assert i_shape[0] == num_slices
+        assert i_shape[1] == num_tokens
+        lora_rank = i_shape[2]
+        # Expected lora weight shape: [num_loras, hidden_size, lora_rank]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == lora_rank
+        hidden_size = lw_shape[1]
+        # Expected output shape: [num_tokens, hidden_size * num_slices]
+        assert len(o_shape) == 2
+        assert o_shape == (num_tokens, hidden_size * num_slices)
+
+        return {
+            'inputs': self.input,
+            'lora_b_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'offset_start': 0,
+            'add_inputs': add_inputs,
+        }
+
     def bench_fn_kwargs(self,
                         op_type: OpType,
                         add_inputs: Optional[bool] = None) -> Dict[str, Any]:
@@ -686,6 +802,10 @@ def bench_fn_kwargs(self,
             return self.as_bgmv_expand_kwargs(add_inputs)
         if op_type == OpType.BGMV_EXPAND_SLICE:
             return self.as_bgmv_expand_slice_kwargs(add_inputs)
+        if op_type == OpType.V1_SHRINK:
+            return self.as_v1_shrink_kwargs()
+        if op_type == OpType.V1_EXPAND:
+            return self.as_v1_expand_kwargs(add_inputs)
         raise ValueError(f"Unrecognized optype {self}")

     def test_correctness(self, op_type: OpType,
@@ -873,12 +993,9 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
     timers = []
     for bench_ctx in bench_ctxs:
         for seq_len in args.seq_lengths:
-            bench_ops: List[OpType] = []
-            if seq_len == 1:
-                # bench all decode ops
-                bench_ops = [op for op in args.op_types if op.is_decode_op()]
-            else:
-                # bench all prefill ops
+            bench_ops: List[OpType] = args.op_types
+            if seq_len > 1:
+                # bench only prefill ops
                 bench_ops = [op for op in args.op_types if op.is_prefill_op()]

             seq_len_timers = []
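Editor's note: the sketch below (not part of the diff) illustrates how the new pieces fit together for a standalone V1_SHRINK call, wired the same way the benchmark does it: V1KernelMeta is built on the CPU, its tensors are moved to the GPU as BenchmarkTensors.to_device() does, and v1_shrink is called with the keyword arguments produced by as_v1_shrink_kwargs(). The concrete sizes, dtypes, and the CUDA device are illustrative assumptions, not part of this change; the expand path is analogous but passes lora_b_weights, offset_start, and add_inputs instead of scaling.

# Illustrative sketch only -- sizes, dtypes, and the CUDA device are assumptions.
import torch

from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_shrink

num_tokens, hidden_size, lora_rank = 128, 4096, 16
num_loras, num_slices = 4, 2

# Shapes follow the assertions in as_v1_shrink_kwargs().
inputs = torch.rand((num_tokens, hidden_size),
                    dtype=torch.float16, device="cuda")
lora_a_weights = [
    torch.rand((num_loras, lora_rank, hidden_size),
               dtype=torch.float16, device="cuda") for _ in range(num_slices)
]
output = torch.zeros((num_slices, num_tokens, lora_rank),
                     dtype=torch.float32, device="cuda")

# Per-token LoRA ids (dtype assumed); the metadata is prepared on the CPU and
# then moved to the GPU, mirroring BenchmarkTensors.make() / to_device() above.
token_lora_mapping = torch.randint(0, num_loras, (num_tokens,),
                                   dtype=torch.int32, device="cpu")
meta = V1KernelMeta.make(max_loras=num_loras,
                         max_num_tokens=num_tokens,
                         device="cpu")
meta.prepare_tensors(token_lora_mapping=token_lora_mapping)
for field_name in V1KernelMeta.__dataclass_fields__:
    setattr(meta, field_name, getattr(meta, field_name).to("cuda"))

# Keyword names match the dict returned by as_v1_shrink_kwargs().
v1_shrink(inputs=inputs,
          lora_a_weights=lora_a_weights,
          output_tensor=output,
          token_lora_mapping=meta.token_lora_mapping,
          token_indices_sorted_by_lora_ids=meta.token_indices_sorted_by_lora_ids,
          num_tokens_per_lora=meta.num_tokens_per_lora,
          lora_token_start_loc=meta.lora_token_start_loc,
          lora_ids=meta.active_lora_ids,
          scaling=1.0)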