149 changes: 149 additions & 0 deletions benchmarks/kernels/benchmark_sgmv_triton.py
@@ -0,0 +1,149 @@
import math

import torch
import triton

from vllm.model_executor.layers.lora import sgmv_triton as sgmv

MAX_TEST_POWER = 6


# duplicated from tests/lora/test_sgmv_triton.py so there isn't a dependency
# on the tests module
def setup_(S, R, H, dtype, repeats_per_lora=1):
S = math.ceil(S / repeats_per_lora) * repeats_per_lora
num_unique = S // repeats_per_lora
if R is None:
ranks = torch.randint(3, MAX_TEST_POWER, (S, ), device='cuda')
        ranks = 2**ranks  # random powers of 2 in [8, 2**MAX_TEST_POWER)
R = 2**(MAX_TEST_POWER - 1)
else:
ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32)
weights = torch.randn((num_unique, 1, H, R), device='cuda', dtype=dtype)
indices = torch.randint(0, num_unique, (num_unique, ), device='cuda')
repeats = torch.full((num_unique, ),
repeats_per_lora,
device='cuda',
dtype=torch.int32)
repeats = torch.cat([
torch.zeros((1, ), device='cuda', dtype=torch.int32),
repeats.cumsum(dim=-1)
])
return (weights, ranks, indices, repeats, num_unique, R, dtype)


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['S'], # argument names to use as an x-axis for the plot
x_vals=[16 * 2**i for i in range(3, 6)] +
[4096], # different possible values for `x_name`
line_arg=
'R', # argument name which corresponds to a different line in the plot
line_vals=[64, None], # possible values for `line_arg``
line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}'
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="ms", # label name for the y-axis
plot_name=
"sgmv", # name for the plot. Used as file name for saving the plot too.
args={},
))
def benchmark_repeats_expand(S, R, repeats_per_lora=1):
weights, ranks, indices, repeats, _, R, dtype = setup_(
S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)

buffer = torch.randn((S, R), device='cuda', dtype=torch.float32)
out = torch.randn((S, 4096), device='cuda', dtype=dtype)
ms = triton.testing.do_bench(lambda: sgmv.sgmv_expand(buffer,
weights,
out,
ranks,
indices,
repeats,
repeats_per_lora,
out_col_offset=0),
warmup=500,
rep=4000)
return ms


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['S'], # argument names to use as an x-axis for the plot
x_vals=[16 * 2**i for i in range(3, 6)] +
[4096], # different possible values for `x_name`
line_arg=
'R', # argument name which corresponds to a different line in the plot
line_vals=[64, None], # possible values for `line_arg``
line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}'
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="ms", # label name for the y-axis
plot_name=
"sgmv", # name for the plot. Used as file name for saving the plot too.
args={},
))
def benchmark_repeats_shrink(S, R, repeats_per_lora=1):
weights, ranks, indices, repeats, _, R, dtype = setup_(
S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)

x = torch.rand((S, 4096), device='cuda', dtype=dtype)
out = torch.zeros((S, R), device='cuda', dtype=torch.float32)
ms = triton.testing.do_bench(lambda: sgmv.sgmv_shrink(
x, weights, out, ranks, indices, repeats, repeats_per_lora),
warmup=500,
rep=4000)
return ms


if __name__ == '__main__':
    # NOTE: the random-rank benchmark samples ranks up to 2^MAX_TEST_POWER,
    # not up to the rank specified, so its results do not change when you
    # change the rank being tested
print('Times are in ms.')
print('-' * 40)
print('Expand | repeats [1]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=1)
print('-' * 40)
print('Shrink | repeats [1]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=1)

# print('-' * 40)
# print('Expand | repeats [8]')
# benchmark_repeats_expand.run(show_plots=False,
# print_data=True,
# repeats_per_lora=8)
# print('-' * 40)
# print('Shrink | repeats [8]')
# benchmark_repeats_shrink.run(show_plots=False,
# print_data=True,
# repeats_per_lora=8)

# # set repeats >= 16 for plaid mode
# # (tl.dot is applicable which makes it fast)
# print('-' * 40)
# print('Expand | repeats [16]')
# benchmark_repeats_expand.run(show_plots=False,
# print_data=True,
# repeats_per_lora=16)
# print('-' * 40)
# print('Shrink | repeats [16]')
# benchmark_repeats_shrink.run(show_plots=False,
# print_data=True,
# repeats_per_lora=16)

print('-' * 40)
print('Expand | repeats [32]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=32)
print('-' * 40)
print('Shrink | repeats [32]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=32)
print('-' * 40)
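The script above exercises sgmv_shrink and sgmv_expand in isolation. For orientation, here is a minimal sketch of how the two kernels compose into a full LoRA application (shrink the hidden states into a low-rank buffer with the A weights, then expand back to the hidden size with the B weights), using the same tensor shapes and argument order as setup_() and the benchmarks above. The lora_a/lora_b names, the concrete sizes, and the composition itself are illustrative assumptions, not part of this diff.

```python
import torch

from vllm.model_executor.layers.lora import sgmv_triton as sgmv

S, H, R = 128, 4096, 64  # tokens, hidden size, LoRA rank (assumed values)
repeats_per_lora = 1  # one token per LoRA slot, as in the "repeats [1]" runs
num_unique = S // repeats_per_lora

# Per-LoRA weights shaped as in setup_(): (num_unique, 1, H, R).
# lora_a / lora_b are hypothetical names for the shrink / expand weights.
lora_a = torch.randn((num_unique, 1, H, R), device='cuda', dtype=torch.bfloat16)
lora_b = torch.randn((num_unique, 1, H, R), device='cuda', dtype=torch.bfloat16)

ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32)
indices = torch.randint(0, num_unique, (num_unique, ), device='cuda')
repeats = torch.cat([
    torch.zeros((1, ), device='cuda', dtype=torch.int32),
    torch.full((num_unique, ), repeats_per_lora, device='cuda',
               dtype=torch.int32).cumsum(dim=-1)
])

x = torch.rand((S, H), device='cuda', dtype=torch.bfloat16)
buffer = torch.zeros((S, R), device='cuda', dtype=torch.float32)
out = torch.zeros((S, H), device='cuda', dtype=torch.bfloat16)

# Shrink x into the low-rank buffer, then expand the buffer into out.
sgmv.sgmv_shrink(x, lora_a, buffer, ranks, indices, repeats, repeats_per_lora)
sgmv.sgmv_expand(buffer, lora_b, out, ranks, indices, repeats,
                 repeats_per_lora, out_col_offset=0)
```

If saved plots are wanted in addition to the printed tables, triton.testing's run() also accepts a save_path argument, e.g. benchmark_repeats_expand.run(print_data=True, save_path='.', repeats_per_lora=1).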
2 changes: 2 additions & 0 deletions requirements-test.txt
@@ -14,6 +14,8 @@ peft
requests
ray
sentence-transformers # required for embedding
pandas
matplotlib # required by Triton's benchmarking functions
sparseml==1.8.0 # required for compressed-tensors
compressed-tensors==0.4.0 # required for compressed-tensors

2 changes: 1 addition & 1 deletion requirements-xpu.txt
@@ -7,5 +7,5 @@ torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-
intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl

triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl

60 changes: 43 additions & 17 deletions tests/lora/test_layers.py
@@ -189,6 +189,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

def create_random_embedding_layer():
embedding = VocabParallelEmbedding(vocab_size, 256)
@@ -219,8 +220,8 @@ def create_random_embedding_layer():
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)

@@ -257,8 +258,8 @@ def create_random_embedding_layer():
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

@@ -285,7 +286,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
lora_extra_vocab_size=177)
# verify weird extra vocab sizes work without too many tests
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

def create_random_embedding_layer():
embedding = VocabParallelEmbedding(vocab_size, 256)
@@ -349,8 +353,8 @@ def create_random_embedding_layer():
(embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

@@ -394,8 +398,8 @@ def create_random_embedding_layer():

original_inputs = deepcopy(inputs)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

@@ -420,7 +424,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
lora_extra_vocab_size=133)
# verify weird extra vocab sizes work without too many tests
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

def _pretest():
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
@@ -467,10 +474,12 @@ def _pretest():
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
lora_id_to_r,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)

lora_logits_processor.set_mapping(*mapping_info, )

lora_result = lora_logits_processor._get_logits(
@@ -493,7 +502,8 @@ def _pretest():
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += (input_ @ lora.lora_a @ lora.lora_b *
lora.scaling)[:, :logits_processor.org_vocab_size]
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
@@ -512,8 +522,8 @@ def _pretest():
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, vocab_size,
lora_config.lora_extra_vocab_size)
lora_logits_processor.set_mapping(*mapping_info, )

@@ -547,6 +557,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
# verify weird extra vocab sizes work without too many tests
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

def create_random_linear_parallel_layer():
if orientation == "row":
@@ -594,6 +606,7 @@ def create_random_linear_parallel_layer():
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
lora_id_to_r,
max_loras,
512,
lora_config.lora_extra_vocab_size,
@@ -611,6 +624,9 @@ def create_random_linear_parallel_layer():
expected_result = torch.cat(expected_results)

rtol, atol = TOLERANCES[lora_result.dtype]

diff = (lora_result - expected_result).abs()
print(f'diff max {diff.max()}, mean {diff.mean()}')
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
@@ -630,8 +646,9 @@ def create_random_linear_parallel_layer():
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r,
max_loras, 512,
lora_config.lora_extra_vocab_size)
lora_linear.set_mapping(*mapping_info, )

lora_result = lora_linear(torch.cat(inputs))[0]
@@ -657,7 +674,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
lora_extra_vocab_size=19)
# verify weird extra vocab sizes work without too many tests
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

def create_column_parallel_packed_layer():
if repeats == 2:
@@ -727,6 +747,7 @@ class FakeConfig:
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
lora_id_to_r,
max_loras,
512,
lora_config.lora_extra_vocab_size,
@@ -767,6 +788,7 @@ class FakeConfig:
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
lora_id_to_r,
max_loras,
512,
lora_config.lora_extra_vocab_size,
@@ -808,7 +830,10 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
lora_dtype=dtype)
lora_dtype=dtype,
lora_extra_vocab_size=538)
# verify weird extra vocab sizes work without too many tests
lora_id_to_r = {i + 1: 8 for i in range(max_loras)}

if rotary_dim is None:
rotary_dim = head_size
@@ -857,6 +882,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
lora_id_to_r,
max_loras,
512,
lora_config.lora_extra_vocab_size,