Commit d78cd57

Author: Varun Sundar Rabindranath (committed)

Add v1 kernels

Signed-off-by: Varun Sundar Rabindranath <[email protected]>

1 parent 992e5c3, commit d78cd57
18 files changed: +100434 -96 lines changed

benchmarks/kernels/benchmark_lora.py

Lines changed: 141 additions & 24 deletions
@@ -23,6 +23,7 @@
 from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -172,6 +173,8 @@ class OpType(Enum):
     SGMV_EXPAND = auto()
     BGMV_EXPAND = auto()
     BGMV_EXPAND_SLICE = auto()
+    V1_SHRINK = auto()
+    V1_EXPAND = auto()
 
     @staticmethod
     def from_str(s: str) -> "OpType":
@@ -185,28 +188,43 @@ def from_str(s: str) -> "OpType":
             return OpType.BGMV_EXPAND
         if s.lower() == "bgmv_expand_slice":
             return OpType.BGMV_EXPAND_SLICE
+        if s.lower() == "v1_shrink":
+            return OpType.V1_SHRINK
+        if s.lower() == "v1_expand":
+            return OpType.V1_EXPAND
         raise ValueError(f"Unrecognized str {s} to convert to OpType")
 
     def is_shrink_fn(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
+        ]
 
     def is_expand_fn(self) -> bool:
-        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
+        return self in [
+            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
+        ]
 
     def is_prefill_op(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
+            OpType.V1_EXPAND
+        ]
 
     def is_decode_op(self) -> bool:
         return self in [
-            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
+            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
+            OpType.V1_SHRINK, OpType.V1_EXPAND
         ]
 
     def is_expand_slice_fn(self) -> bool:
         return self in [OpType.BGMV_EXPAND_SLICE]
 
     def num_slices(self) -> List[int]:
-        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
-            # SGMV kernels supports slices
+        if self in [
+                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
+                OpType.V1_EXPAND
+        ]:
+            # SGMV kernels and v1 kernels support slices
             return [1, 2, 3]
         if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
             return [1]
@@ -251,11 +269,13 @@ def matmul_shapes(
         m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
 
         b_shape = (num_loras, n, k)  # col-major
-        if self == OpType.SGMV_SHRINK:
-            # SGMV shrink supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
+            # SGMV shrink and V1 shrink kernels support num_slices inherently
+            # in the kernel.
             return ((m, k), b_shape, (num_slices, m, n))
-        if self == OpType.SGMV_EXPAND:
-            # SGMV expand supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
+            # SGMV expand and V1 expand kernels support num_slices inherently
+            # in the kernel
             return ((num_slices, m, k), b_shape, (m, n * num_slices))
         if self == OpType.BGMV_SHRINK:
             return ((m, k), b_shape, (m, n))
@@ -282,25 +302,30 @@ def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
             return bgmv_expand
         if self == OpType.BGMV_EXPAND_SLICE:
             return emulate_bgmv_expand_slice
+        if self == OpType.V1_SHRINK:
+            return v1_shrink
+        if self == OpType.V1_EXPAND:
+            return v1_expand
+
         raise ValueError(f"Unrecognized optype {self}")
 
     def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                            lora_weights: List[torch.Tensor],
                            **kwargs) -> Callable:
-        """Each benchmark operation expected the input, lora_weights and outputs
+        """Each benchmark operation expects the input, lora_weights and outputs
         in a slightly different format. Refer to self.matmul_shapes().
         run_ref_group_gemm accounts for those differences in executing a
         reference group gemm for correctness testing.
         """
         w_dtype = lora_weights[0].dtype
         num_slices = len(lora_weights)
-        if self == OpType.SGMV_SHRINK:
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
             for slice_idx in range(num_slices):
                 ref_group_gemm(ref_out=output[slice_idx, :],
                                input=input,
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        if self == OpType.SGMV_EXPAND:
+        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -309,19 +334,19 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        if self == OpType.BGMV_SHRINK:
+        elif self == OpType.BGMV_SHRINK:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input,
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND:
+        elif self == OpType.BGMV_EXPAND:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input.clone().to(dtype=w_dtype),
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND_SLICE:
+        elif self == OpType.BGMV_EXPAND_SLICE:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -330,7 +355,8 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        raise ValueError(f"Unrecognized optype {self}")
+        else:
+            raise ValueError(f"Unrecognized optype {self}")
 
 
 @dataclass
@@ -391,6 +417,8 @@ class BenchmarkTensors:
     seq_start_loc: torch.Tensor
     prompt_lora_mapping: torch.Tensor
     token_lora_mapping: torch.Tensor
+    # v1 kernel metadata
+    v1_kernel_meta: Optional[V1KernelMeta] = None
 
     def io_types(self) -> str:
         return (f"{dtype_to_str(self.input.dtype)}x"
@@ -433,10 +461,19 @@ def make(ctx: BenchmarkContext,
             total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
             seq_len_tensor, "cpu")
 
+        v1_kernel_meta = None
+        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
+            v1_kernel_meta = V1KernelMeta.make(
+                max_loras=ctx.num_loras,
+                max_num_tokens=token_lora_indices_tensor.size(0),
+                device="cpu")
+            v1_kernel_meta.prepare_tensors(
+                token_lora_mapping=token_lora_indices_tensor)
+
         return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                 seq_len_tensor, seq_start_loc_tensor,
                                 prompt_lora_indices_tensor,
-                                token_lora_indices_tensor)
+                                token_lora_indices_tensor, v1_kernel_meta)
 
     def sanity_check(self) -> None:
         """
@@ -469,6 +506,13 @@ def to_device(tensor: torch.Tensor):
         for i in range(len(self.lora_weights_lst)):
             self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
 
+        # v1 meta
+        if self.v1_kernel_meta:
+            for field_name in V1KernelMeta.__dataclass_fields__:
+                field = getattr(self.v1_kernel_meta, field_name)
+                assert isinstance(field, torch.Tensor)
+                setattr(self.v1_kernel_meta, field_name, to_device(field))
+
     def metadata(self) -> Tuple[int, int, int]:
         """
         Return num_seqs, num_tokens and max_seq_len
@@ -668,6 +712,78 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
             })
         return {'kwargs_list': kwargs_list}
 
+    def as_v1_shrink_kwargs(self) -> Dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape [num_tokens, hidden_size]
+        assert len(i_shape) == 2
+        assert i_shape[0] == num_tokens
+        hidden_size = i_shape[1]
+        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == hidden_size
+        lora_rank = lw_shape[1]
+        # Expected output shape [num_slices, num_tokens, lora_rank]
+        assert len(o_shape) == 3
+        assert o_shape == (num_slices, num_tokens, lora_rank)
+
+        return {
+            'inputs': self.input,
+            'lora_a_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'scaling': 1.0,
+        }
+
+    def as_v1_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape : [num_slices, num_tokens, lora_rank]
+        assert len(i_shape) == 3
+        assert i_shape[0] == num_slices
+        assert i_shape[1] == num_tokens
+        lora_rank = i_shape[2]
+        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == lora_rank
+        hidden_size = lw_shape[1]
+        # Expected output shape : [num_tokens, hidden_size * num_slices]
+        assert len(o_shape) == 2
+        assert o_shape == (num_tokens, hidden_size * num_slices)
+
+        return {
+            'inputs': self.input,
+            'lora_b_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'offset_start': 0,
+            'add_inputs': add_inputs,
+        }
+
     def bench_fn_kwargs(self,
                         op_type: OpType,
                         add_inputs: Optional[bool] = None) -> Dict[str, Any]:
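Note: the dictionaries built by as_v1_shrink_kwargs() and as_v1_expand_kwargs() above are passed as keyword arguments to the callables returned by bench_fn(), so they appear to map one-to-one onto the new kernel entry points. A hedged invocation sketch for the shrink path, using only the keyword names visible in this hunk (shapes, dtypes and the CUDA device are illustrative assumptions; the benchmark prepares the metadata on CPU and copies it to the GPU, whereas this sketch builds it on the device directly for brevity):

import torch

from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_shrink

num_loras, num_tokens, hidden_size, lora_rank, num_slices = 4, 16, 128, 8, 1
device = "cuda"

inputs = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device=device)
# One lora_a weight tensor per slice: [num_loras, lora_rank, hidden_size].
lora_a_weights = [
    torch.randn(num_loras, lora_rank, hidden_size, dtype=torch.float16, device=device)
    for _ in range(num_slices)
]
# Shrink output: [num_slices, num_tokens, lora_rank]; float32 is an assumption here.
output_tensor = torch.zeros(num_slices, num_tokens, lora_rank, dtype=torch.float32, device=device)

# Build and fill the kernel metadata for this token batch (simplifying assumption:
# done directly on the device rather than on CPU as in BenchmarkTensors.make()).
meta = V1KernelMeta.make(max_loras=num_loras, max_num_tokens=num_tokens, device=device)
meta.prepare_tensors(token_lora_mapping=torch.randint(
    0, num_loras, (num_tokens, ), dtype=torch.int32, device=device))

# Keyword names follow as_v1_shrink_kwargs() above.
v1_shrink(inputs=inputs,
          lora_a_weights=lora_a_weights,
          output_tensor=output_tensor,
          token_lora_mapping=meta.token_lora_mapping,
          token_indices_sorted_by_lora_ids=meta.token_indices_sorted_by_lora_ids,
          num_tokens_per_lora=meta.num_tokens_per_lora,
          lora_token_start_loc=meta.lora_token_start_loc,
          lora_ids=meta.active_lora_ids,
          scaling=1.0)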
@@ -686,6 +802,10 @@ def bench_fn_kwargs(self,
             return self.as_bgmv_expand_kwargs(add_inputs)
         if op_type == OpType.BGMV_EXPAND_SLICE:
             return self.as_bgmv_expand_slice_kwargs(add_inputs)
+        if op_type == OpType.V1_SHRINK:
+            return self.as_v1_shrink_kwargs()
+        if op_type == OpType.V1_EXPAND:
+            return self.as_v1_expand_kwargs(add_inputs)
         raise ValueError(f"Unrecognized optype {self}")
 
     def test_correctness(self, op_type: OpType,
@@ -873,12 +993,9 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
     timers = []
     for bench_ctx in bench_ctxs:
         for seq_len in args.seq_lengths:
-            bench_ops: List[OpType] = []
-            if seq_len == 1:
-                # bench all decode ops
-                bench_ops = [op for op in args.op_types if op.is_decode_op()]
-            else:
-                # bench all prefill ops
+            bench_ops: List[OpType] = args.op_types
+            if seq_len > 1:
+                # bench only prefill ops
                 bench_ops = [op for op in args.op_types if op.is_prefill_op()]
 
             seq_len_timers = []
