diff --git a/benchmarks/kernels/benchmark_sgmv_triton.py b/benchmarks/kernels/benchmark_sgmv_triton.py new file mode 100644 index 000000000000..13561b6734d9 --- /dev/null +++ b/benchmarks/kernels/benchmark_sgmv_triton.py @@ -0,0 +1,149 @@ +import math + +import torch +import triton + +from vllm.model_executor.layers.lora import sgmv_triton as sgmv + +MAX_TEST_POWER = 6 + + +# duplicated from tests/lora/test_sgmv_triton.py so there isn't a dependency +# on the tests module +def setup_(S, R, H, dtype, repeats_per_lora=1): + S = math.ceil(S / repeats_per_lora) * repeats_per_lora + num_unique = S // repeats_per_lora + if R is None: + ranks = torch.randint(3, MAX_TEST_POWER, (S, ), device='cuda') + ranks = 2**ranks # random powers of 2 between [8, MAX_TEST_POWER] + R = 2**(MAX_TEST_POWER - 1) + else: + ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32) + weights = torch.randn((num_unique, 1, H, R), device='cuda', dtype=dtype) + indices = torch.randint(0, num_unique, (num_unique, ), device='cuda') + repeats = torch.full((num_unique, ), + repeats_per_lora, + device='cuda', + dtype=torch.int32) + repeats = torch.cat([ + torch.zeros((1, ), device='cuda', dtype=torch.int32), + repeats.cumsum(dim=-1) + ]) + return (weights, ranks, indices, repeats, num_unique, R, dtype) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['S'], # argument names to use as an x-axis for the plot + x_vals=[16 * 2**i for i in range(3, 6)] + + [4096], # different possible values for `x_name` + line_arg= + 'R', # argument name which corresponds to a different line in the plot + line_vals=[64, None], # possible values for `line_arg`` + line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}' + ], # label name for the lines + styles=[('blue', '-'), ('green', '-')], # line styles + ylabel="ms", # label name for the y-axis + plot_name= + "sgmv", # name for the plot. Used as file name for saving the plot too. + args={}, + )) +def benchmark_repeats_expand(S, R, repeats_per_lora=1): + weights, ranks, indices, repeats, _, R, dtype = setup_( + S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora) + + buffer = torch.randn((S, R), device='cuda', dtype=torch.float32) + out = torch.randn((S, 4096), device='cuda', dtype=dtype) + ms = triton.testing.do_bench(lambda: sgmv.sgmv_expand(buffer, + weights, + out, + ranks, + indices, + repeats, + repeats_per_lora, + out_col_offset=0), + warmup=500, + rep=4000) + return ms + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['S'], # argument names to use as an x-axis for the plot + x_vals=[16 * 2**i for i in range(3, 6)] + + [4096], # different possible values for `x_name` + line_arg= + 'R', # argument name which corresponds to a different line in the plot + line_vals=[64, None], # possible values for `line_arg`` + line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}' + ], # label name for the lines + styles=[('blue', '-'), ('green', '-')], # line styles + ylabel="ms", # label name for the y-axis + plot_name= + "sgmv", # name for the plot. Used as file name for saving the plot too. 
+ args={}, + )) +def benchmark_repeats_shrink(S, R, repeats_per_lora=1): + weights, ranks, indices, repeats, _, R, dtype = setup_( + S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora) + + x = torch.rand((S, 4096), device='cuda', dtype=dtype) + out = torch.zeros((S, R), device='cuda', dtype=torch.float32) + ms = triton.testing.do_bench(lambda: sgmv.sgmv_shrink( + x, weights, out, ranks, indices, repeats, repeats_per_lora), + warmup=500, + rep=4000) + return ms + + +if __name__ == '__main__': + # NOTE: the random rank benchmark is random ranks up to 2^MAX_TEST_POWER, + # not random up to the rank specified, + # so it doesn't change when you change the rank you're testing + print('Times are in ms.') + print('-' * 40) + print('Expand | repeats [1]') + benchmark_repeats_expand.run(show_plots=False, + print_data=True, + repeats_per_lora=1) + print('-' * 40) + print('Shrink | repeats [1]') + benchmark_repeats_shrink.run(show_plots=False, + print_data=True, + repeats_per_lora=1) + + # print('-' * 40) + # print('Expand | repeats [8]') + # benchmark_repeats_expand.run(show_plots=False, + # print_data=True, + # repeats_per_lora=8) + # print('-' * 40) + # print('Shrink | repeats [8]') + # benchmark_repeats_shrink.run(show_plots=False, + # print_data=True, + # repeats_per_lora=8) + + # # set repeats >= 16 for plaid mode + # # (tl.dot is applicable which makes it fast) + # print('-' * 40) + # print('Expand | repeats [16]') + # benchmark_repeats_expand.run(show_plots=False, + # print_data=True, + # repeats_per_lora=16) + # print('-' * 40) + # print('Shrink | repeats [16]') + # benchmark_repeats_shrink.run(show_plots=False, + # print_data=True, + # repeats_per_lora=16) + + print('-' * 40) + print('Expand | repeats [32]') + benchmark_repeats_expand.run(show_plots=False, + print_data=True, + repeats_per_lora=32) + print('-' * 40) + print('Shrink | repeats [32]') + benchmark_repeats_shrink.run(show_plots=False, + print_data=True, + repeats_per_lora=32) + print('-' * 40) diff --git a/requirements-test.txt b/requirements-test.txt index a7604d2e1015..ff8352c996e9 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,6 +14,8 @@ peft requests ray sentence-transformers # required for embedding +pandas +matplotlib # comes from Triton benchmarking function sparseml==1.8.0 # required for compressed-tensors compressed-tensors==0.4.0 # required for compressed-tensors diff --git a/requirements-xpu.txt b/requirements-xpu.txt index 48d899ec70ed..c62abfcce147 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -7,5 +7,5 @@ torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch- intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl -triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7207af6b1a4b..465925b2dbf9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -189,6 +189,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) 
-> None: lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -219,8 +220,8 @@ def create_random_embedding_layer(): ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, vocab_size, lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info) @@ -257,8 +258,8 @@ def create_random_embedding_layer(): ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, vocab_size, lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) @@ -285,7 +286,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + lora_extra_vocab_size=177) + # verify weird extra vocab sizes work without too many tests + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -349,8 +353,8 @@ def create_random_embedding_layer(): (embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, vocab_size, lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) @@ -394,8 +398,8 @@ def create_random_embedding_layer(): original_inputs = deepcopy(inputs) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, vocab_size, lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) @@ -420,7 +424,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + lora_extra_vocab_size=133) + # verify weird extra vocab sizes work without too many tests + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} def _pretest(): linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, @@ -467,10 +474,12 @@ def _pretest(): mapping_info = convert_mapping( lora_mapping, id_to_index, + lora_id_to_r, max_loras, vocab_size, lora_config.lora_extra_vocab_size, ) + lora_logits_processor.set_mapping(*mapping_info, ) lora_result = lora_logits_processor._get_logits( @@ -493,7 +502,8 @@ def _pretest(): lm_head=linear, embedding_bias=None) result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += (input_ @ lora.lora_a @ lora.lora_b * + lora.scaling)[:, :logits_processor.org_vocab_size] expected_results.append(result) expected_result = torch.cat(expected_results) logits_processor.org_vocab_size = vocab_size @@ -512,8 +522,8 @@ def _pretest(): ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - mapping_info = convert_mapping(lora_mapping, id_to_index, 
max_loras, - vocab_size, + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, vocab_size, lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) @@ -547,6 +557,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, max_lora_rank=8, fully_sharded_loras=fully_shard, lora_dtype=torch.float16) + # verify weird extra vocab sizes work without too many tests + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} def create_random_linear_parallel_layer(): if orientation == "row": @@ -594,6 +606,7 @@ def create_random_linear_parallel_layer(): mapping_info = convert_mapping( lora_mapping, id_to_index, + lora_id_to_r, max_loras, 512, lora_config.lora_extra_vocab_size, @@ -611,6 +624,9 @@ def create_random_linear_parallel_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] + + diff = (lora_result - expected_result).abs() + print(f'diff max {diff.max()}, mean {diff.mean()}') assert torch.allclose(lora_result, expected_result, rtol=rtol, @@ -630,8 +646,9 @@ def create_random_linear_parallel_layer(): ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + mapping_info = convert_mapping(lora_mapping, id_to_index, lora_id_to_r, + max_loras, 512, + lora_config.lora_extra_vocab_size) lora_linear.set_mapping(*mapping_info, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -657,7 +674,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + lora_extra_vocab_size=19) + # verify weird extra vocab sizes work without too many tests + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} def create_column_parallel_packed_layer(): if repeats == 2: @@ -727,6 +747,7 @@ class FakeConfig: mapping_info = convert_mapping( lora_mapping, id_to_index, + lora_id_to_r, max_loras, 512, lora_config.lora_extra_vocab_size, @@ -767,6 +788,7 @@ class FakeConfig: mapping_info = convert_mapping( lora_mapping, id_to_index, + lora_id_to_r, max_loras, 512, lora_config.lora_extra_vocab_size, @@ -808,7 +830,10 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, long_lora_scaling_factors=scaling_factors, - lora_dtype=dtype) + lora_dtype=dtype, + lora_extra_vocab_size=538) + # verify weird extra vocab sizes work without too many tests + lora_id_to_r = {i + 1: 8 for i in range(max_loras)} if rotary_dim is None: rotary_dim = head_size @@ -857,6 +882,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, mapping_info = convert_mapping( lora_mapping, id_to_index, + lora_id_to_r, max_loras, 512, lora_config.lora_extra_vocab_size, diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py index 3415d36b7e34..bac4134587b1 100644 --- a/tests/lora/test_lora.py +++ b/tests/lora/test_lora.py @@ -55,17 +55,30 @@ def test_apply_lora(m, n, k, rank, dtype) -> None: lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T output = torch.zeros(k, m, device="cuda", dtype=dtype) + ranks = torch.full((8, ), + lora_a_stack.shape[2], + dtype=torch.int32, + device='cuda') + repeats = torch.arange(0, len(input) + 1, dtype=torch.int32, device='cuda') + _apply_lora( - input, lora_a_stack, lora_b_stack, + input, + lora_a_stack, + lora_b_stack, + 
output, torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), - output) + ranks, + repeats, + 1, + ) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora(input, lora_a_stack, lora_b_stack, - torch.full((len(input), ), -1, device="cuda"), output) + _apply_lora(input, lora_a_stack, lora_b_stack, output, + torch.full((len(input), ), -1, device="cuda"), ranks, repeats, + 1) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -122,19 +135,24 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T output = torch.zeros(k, m, device="cuda", dtype=dtype) + ranks = torch.full((8, ), + lora_a_stacks[0].shape[2], + dtype=torch.int32, + device='cuda') + repeats = torch.arange(0, len(input) + 1, dtype=torch.int32, device='cuda') _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, + input, lora_a_stacks, lora_b_stacks, output, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (m // 2, m // 2)) + device="cuda"), ranks, repeats, 1) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, output, torch.full((len(input), ), -1, device="cuda"), - output, (m // 2, m // 2)) + ranks, repeats, 1) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -206,19 +224,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + ranks = torch.full((8, ), + lora_a_stacks[0].shape[2], + dtype=torch.int32, + device='cuda') + repeats = torch.arange(0, len(input) + 1, dtype=torch.int32, device='cuda') _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, + input, lora_a_stacks, lora_b_stacks, output, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (qkv[0], qkv[1], qkv[2])) + device="cuda"), ranks, repeats, 1) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, output, torch.full((len(input), ), -1, device="cuda"), - output, (qkv[0], qkv[1], qkv[2])) + ranks, repeats, 1) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py deleted file mode 100644 index dbeb16cb21ad..000000000000 --- a/tests/lora/test_punica.py +++ /dev/null @@ -1,258 +0,0 @@ -# Based on code from https://github.com/punica-ai/punica - -import pytest -import torch - -import vllm.lora.punica as punica - - -def assert_close(a, b): - rtol, atol = { - torch.float16: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), - torch.float32: (None, None), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - -def _lora_ref_impl( - y_final: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, -): - y_stage_1 = torch.empty( - (x.size(0), wa_T_all.size(-2)), - dtype=torch.float32, - device=x.device, - ) - bs = x.shape[0] - s = torch.tensor(scale, 
dtype=torch.float32, device=x.device) - for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): - xi = x[i].unsqueeze(0).to(torch.float32) - wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - if wb_T_all is not None: - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, - -2).to(torch.float32) - - tmp = xi @ wa - y_stage_1[i] = tmp.squeeze(0) - y_final[i] += ((tmp @ wb).squeeze(0) * - s if wb_T_all is not None else y_stage_1[i]) - return y_final, y_stage_1 - - -H1 = H2 = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 102400, - 102656, - 128000, - 128256, -] -H2 = [64] + H2 -R = [1, 2, 4] -SEED = [0xabcdabcd987] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("r", R) -@pytest.mark.parametrize("seed", SEED) -@torch.inference_mode() -def test_lora_a_extra_shapes(dtype_str, h1, r, seed): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - bs = 32 - dtype = getattr(torch, dtype_str) - device = torch.device("cuda") - - wa_T_all = torch.randn(num_loras, - num_layers, - r, - h1, - dtype=dtype, - device=device) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype, device=device) - y = torch.randn(bs, r, dtype=dtype, device=device) - - y_ref = y.clone() - _lora_ref_impl( - y_ref, - x, - wa_T_all, - None, - indices, - layer_idx, - 1.0, - ) - - y_our = y.clone() - punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness(dtype_str, h1, h2, seed, device): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - - y_ref = y.clone() - _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) - - y_our = y.clone() - punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, - scale) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness_slice(dtype_str, h1, h2, seed, 
device): - if h2 % 3 != 0 or h2 // 3 not in H1: - pytest.skip("h2 must be divisible by 3 and in supported shapes") - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - s = h2 // 3 - - y_ref = y.clone() - _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale) - - y_our = y.clone() - punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale, 0, s) - punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale, s, s) - punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale, s * 2, s) - - assert_close(y_ref[:, :s], y_our[:, :s]) - assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) - assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_sgmv_triton.py b/tests/lora/test_sgmv_triton.py new file mode 100644 index 000000000000..bb8addf56e50 --- /dev/null +++ b/tests/lora/test_sgmv_triton.py @@ -0,0 +1,121 @@ +import math + +import pytest +import torch + +from vllm.model_executor.layers.lora import sgmv_triton as sgmv + +MAX_TEST_POWER = 6 +SEED = 42 + + +def assert_close(a, b, dtype, tl_dot=False): + rtol, atol = { + torch.float16: (5e-3, 5e-3) if not tl_dot else (1e-2, 7e-2), + torch.bfloat16: (3e-2, 2e-2) if not tl_dot else (3e-2, 1e-1), + torch.float32: (2e-3, 3e-4) if not tl_dot else (1e-2, 7e-2), + }[dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def setup_(S, R, H, dtype, repeats_per_lora=1): + S = math.ceil(S / repeats_per_lora) * repeats_per_lora + num_unique = S // repeats_per_lora + if R is None: + ranks = torch.randint(3, MAX_TEST_POWER, (S, ), device='cuda') + ranks = 2**ranks # random powers of 2 between [8, MAX_TEST_POWER] + R = 2**(MAX_TEST_POWER - 1) + else: + ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32) + weights = torch.randn((num_unique, 1, H, R), device='cuda', dtype=dtype) + indices = torch.randint(0, num_unique, (num_unique, ), device='cuda') + repeats = torch.full((num_unique, ), + repeats_per_lora, + device='cuda', + dtype=torch.int32) + repeats = torch.cat([ + torch.zeros((1, ), device='cuda', dtype=torch.int32), + repeats.cumsum(dim=-1) + ]) + return (weights, ranks, indices, repeats, num_unique, R, dtype) + + +@pytest.mark.parametrize("S", [16 * 2**i for i in range(3, 4)] + [4096]) +@pytest.mark.parametrize("R", [2**r for r in range(MAX_TEST_POWER)]) +@pytest.mark.parametrize("H", [64, 4096, 7491]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("repeats_per_lora", [1, 16]) +@pytest.mark.parametrize("seed", [SEED]) +@torch.inference_mode() 
+def test_correct(S, R, H, dtype, repeats_per_lora, seed): + torch.set_printoptions(precision=2, linewidth=1000, sci_mode=False) + torch.manual_seed(seed) + weights, ranks, indices, repeats, num_unique, R, dtype = (setup_( + S, R, H, dtype, repeats_per_lora)) + + buffer = torch.randn((S, R), device='cuda', dtype=torch.float32) + out_col_offset = 77 + out = torch.randn((S, H + out_col_offset), device='cuda', dtype=dtype) + ref_outs = [] + for ui in range(num_unique): + idx = indices[ui] + w = weights[idx, 0, :, :ranks[idx]].T.contiguous() + inp = buffer[repeats[ui]:repeats[ui + 1], :ranks[idx]].contiguous() + ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32) + ref_outs.append(ref_out) + + ref_out = torch.cat(ref_outs, dim=0) + # doing this apparently leads to incorrect results in the first row + # + out[:, out_col_offset:] + ref_out = (ref_out + + out[:, out_col_offset:].to(dtype=torch.float32)).to(dtype=dtype) + # but this does not (likely depends on torch version) + + sgmv.sgmv_expand(buffer, + weights, + out, + ranks, + indices, + repeats, + repeats_per_lora, + out_col_offset=out_col_offset) + + # diff = (ref_out - out[:, out_col_offset:]).abs() + # print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}') + # triton.language.dot, which is used for improved speed when + # rank and repeats are >= 16 + # gives larger differences from torch + assert_close(ref_out, + out[:, out_col_offset:], + dtype=dtype, + tl_dot=repeats_per_lora >= 9) + + weights = weights.permute(0, 1, 3, 2).contiguous() + x = torch.rand((S, H), device='cuda', dtype=dtype) + out = torch.zeros((S, R), device='cuda', dtype=torch.float32) + ref_outs = [] + for ui in range(num_unique): + idx = indices[ui] + w = weights[idx, 0, :ranks[idx], :].T.contiguous() + inp = x[repeats[ui]:repeats[ui + 1]].contiguous() + ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32) + ref_out = torch.cat([ + ref_out, + torch.zeros((ref_out.shape[0], R - ref_out.shape[-1]), + dtype=ref_out.dtype, + device='cuda') + ], + dim=-1) + ref_outs.append(ref_out) + + ref_out = torch.cat(ref_outs, dim=0) + ref_out += out + + sgmv.sgmv_shrink(x, weights, out, ranks, indices, repeats, + repeats_per_lora) + + # diff = (ref_out - out).abs() + # print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}') + assert_close(ref_out, out, dtype=dtype, tl_dot=repeats_per_lora >= 9) + torch.cuda.empty_cache() diff --git a/vllm/config.py b/vllm/config.py index 1ea288879680..f20f047f7a28 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1244,16 +1244,11 @@ class LoRAConfig: def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - possible_max_ranks = (8, 16, 32, 64) - possible_lora_extra_vocab_size = (0, 256, 512) + possible_max_ranks = (1, 2, 4, 8, 16, 32, 64, 128) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index d27171f72083..40b2385ef4d9 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -14,7 +14,8 @@ MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, 
RowParallelLinearWithLoRA) -from vllm.lora.punica import bgmv, dispatch_bgmv_low_level +from vllm.model_executor.layers.lora.sgmv_triton import (sgmv_expand, + sgmv_shrink) if TYPE_CHECKING: pass @@ -63,11 +64,15 @@ def apply(self, x: torch.Tensor, dtype=torch.float32, device=x.device) - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv_shrink(x, self.lora_a_stacked, buffer, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv_expand(buffer, self.lora_b_stacked, output, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0], 0, 1.0) # now have column partitioned output output = output.view(*out_orig_shape) @@ -108,17 +113,19 @@ def _mcp_apply(x, bias, layer): dtype=torch.float32, device=x.device) for idx in range(n): - bgmv(buffers[idx], x, layer.lora_a_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0) + sgmv_shrink(x, layer.lora_a_stacked[idx], buffers[idx], layer.ranks, + layer.indices[:layer.indices_len[0]], + layer.repeats[:layer.indices_len[0] + 1], + layer.max_repeats[0]) buffers = tensor_model_parallel_all_gather(buffers) left_offset = 0 for idx in range(n): shard_size = layer.lora_b_stacked[idx].shape[2] - dispatch_bgmv_low_level(output, buffers[idx], - layer.lora_b_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0, - left_offset, shard_size) + sgmv_expand(buffers[idx], layer.lora_b_stacked[idx], output, + layer.ranks, layer.indices[:layer.indices_len[0]], + layer.repeats[:layer.indices_len[0] + 1], + layer.max_repeats[0], left_offset, 1.0) left_offset += shard_size output = output.view(*out_orig_shape) @@ -194,11 +201,15 @@ def apply(self, x: torch.Tensor, dtype=torch.float32, device=x.device) - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv_shrink(x, self.lora_a_stacked, buffer, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv_expand(buffer, self.lora_b_stacked, output, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0], 0, 1.0) # now have column partitioned output output = output.view(*out_orig_shape) @@ -286,8 +297,10 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), dtype=torch.float32, device=x.device) - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv_shrink(x, self.lora_a_stacked, buffer, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -296,12 +309,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # remains is a standard all_reduce. 
User should be aware though that # the output is not the same as a normal row_parallel, it should be # reduced before being used - shard_size = self.lora_b_stacked.shape[2] - start_idx = self.tp_rank * shard_size - dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0, - start_idx, shard_size) + shard_size = self.lora_b_stacked.shape[2] + col_offset = self.tp_rank * shard_size + sgmv_expand(buffer, self.lora_b_stacked, output, self.ranks, + self.indices[:self.indices_len[0]], + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0], col_offset, 1.0) output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0a63f9ef012b..aadaec23f91c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -16,7 +16,7 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide -from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.lora.sgmv import add_lora, sgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -62,13 +62,10 @@ def dec(*args, **kwargs): return dec -def _apply_lora( - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - indices: torch.Tensor, - output: torch.Tensor, -): +def _apply_lora(x: torch.Tensor, lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, output: torch.Tensor, + indices: torch.LongTensor, ranks: torch.LongTensor, + repeats: torch.LongTensor, max_repeats: int): """Applies lora to each input. This method applies all loras to each input. It uses the @@ -79,16 +76,19 @@ def _apply_lora( Input shapes: x: (batch_size, hidden_dim) - lora_a_stacked: (num_loras, lora_rank, hidden_dim) - lora_b_stacked: (num_loras, output_dim, lora_rank) - indices: (batch_size) + lora_a_stacked: (num_loras, 1, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, 1, output_dim, lora_rank) output: (batch_size, output_dim) + indices: (num_lora_token_groups) + ranks: (num_lora_token_groups) + repeats: (num_lora_token_groups) """ org_output = output x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, ranks, + repeats, max_repeats, 0, 1.0) return output.view_as(org_output) @@ -96,9 +96,11 @@ def _apply_lora_packed_nslice( x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, ...], + indices: torch.Tensor, + ranks: torch.LongTensor, + repeats: torch.LongTensor, + max_repeats: int, ): """Applies lora to each input. 
@@ -113,23 +115,25 @@ def _apply_lora_packed_nslice( Input shapes: x: (batch_size, hidden_dim) - lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) - lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) - indices: (batch_size) + lora_a_stacked: 3 element tuple of + (num_loras, 1, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of + (num_loras, 1, output_dim, lora_rank) output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices + indices: (num_lora_token_groups) + ranks: (num_lora_token_groups) + repeats: (num_lora_token_groups) """ org_output = output x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) offset_left = 0 - for slice_idx in range(len(output_slices)): - add_lora_slice(output, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] + for slice_idx in range(len(lora_a_stacked)): + add_lora(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, ranks, repeats, + max_repeats, offset_left, 1.0) + offset_left += lora_b_stacked[slice_idx].shape[2] return output.view_as(org_output) @@ -184,6 +188,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, @@ -269,6 +276,9 @@ def create_lora_weights( self.indices: torch.Tensor self.indices_len: List[int] self.embeddings_indices: torch.Tensor + self.ranks: torch.Tensor + self.repeats: torch.Tensor + self.max_repeats: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -306,6 +316,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, @@ -315,6 +328,9 @@ def set_mapping( self.indices = base_indices self.embeddings_indices = embeddings_indices self.indices_len = indices_len + self.ranks = ranks + self.repeats = repeats + self.max_repeats = max_repeats def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 @@ -336,8 +352,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1) - bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + sgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], self.ranks, + self.repeats[:self.indices_len[0] + 1], self.max_repeats[0], 0, + 1.0) return full_output.view_as(full_output_org) @classmethod @@ -432,6 +450,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, @@ -440,17 +461,17 @@ def set_mapping( ): self.indices = base_indices self.indices_len = indices_len + self.repeats = repeats + self.max_repeats = max_repeats + self.ranks = ranks def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = 
self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + _apply_lora(x, self.lora_a_stacked, self.lora_b_stacked, output, + self.indices[:self.indices_len[0]], self.ranks, + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) return output def forward(self, input_): @@ -597,14 +618,11 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - (self.output_dim, self.output_dim), - ) + _apply_lora_packed_nslice(x, self.lora_a_stacked, self.lora_b_stacked, + output, self.indices[:self.indices_len[0]], + self.ranks, + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) return output @classmethod @@ -857,14 +875,11 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - self.output_slices, - ) + _apply_lora_packed_nslice(x, self.lora_a_stacked, self.lora_b_stacked, + output, self.indices[:self.indices_len[0]], + self.ranks, + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) return output @classmethod @@ -959,6 +974,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, @@ -967,16 +985,16 @@ def set_mapping( ): self.indices = base_indices self.indices_len = indices_len + self.repeats = repeats + self.max_repeats = max_repeats + self.ranks = ranks def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + _apply_lora(x, self.lora_a_stacked, self.lora_b_stacked, output, + self.indices[:self.indices_len[0]], self.ranks, + self.repeats[:self.indices_len[0] + 1], + self.max_repeats[0]) return output def forward(self, input_): @@ -1088,9 +1106,10 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - if 32000 < self.base_layer.vocab_size > 128512: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 128512") + # deprecated by Triton + # if 32000 < self.base_layer.vocab_size > 128512: + # raise ValueError("When using LoRA, vocab size must be " + # "32000 >= vocab_size <= 128512") self.lora_a_stacked = torch.zeros( ( max_loras, @@ -1159,6 +1178,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, @@ -1168,6 +1190,10 @@ def set_mapping( self.indices = sampler_indices self.indices_padded = sampler_indices_padded self.indices_len = indices_len + size = max(base_indices.shape[0], sampler_indices.shape[0]) + self.repeats = torch.arange(0, size + 1, device='cuda') + self.max_repeats = [1] + self.ranks = ranks def 
_get_logits( self, @@ -1226,13 +1252,10 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - _apply_lora( - hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[1]], - logits, - ) + _apply_lora(hidden_states, self.lora_a_stacked, self.lora_b_stacked, + logits, self.indices[:self.indices_len[1]], self.ranks, + self.repeats[:self.indices_len[1] + 1], + self.max_repeats[0]) # Remove paddings in vocab (if any). logits = logits[:, :self.base_layer.vocab_size] @@ -1311,6 +1334,9 @@ def set_lora( def set_mapping( self, base_indices: torch.Tensor, + repeats: torch.Tensor, + max_repeats: List[int], + ranks: torch.Tensor, sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 689835def83d..d0b29b5ce2f3 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -18,6 +18,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.layers.lora.sgmv_triton import MAX_REPEATS_PER_BLOCK from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.utils import LRUCache, is_pin_memory_available @@ -38,20 +39,56 @@ class LongContextLoRAContext: offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) +def group_and_chunk_indices(indices: torch.Tensor): + """ + For sgmv, this first groups the loras in indices by their id + (assuming they are consecutive). The number of tokens for each lora + is counted in repeats (num times the lora is *repeated*) + + Then, it splits these into groups of MAX_REPEATS_PER_BLOCK + remainder + (the kernel launches a row of blocks for each value in indices, + or each lora, which can easily be too large and cause the kernel + to run out of shared memory, so it is expected to be + broken into chunks beforehand) + + Args: + indices: tensor [total #tokens in the batch] + mapping each token to a lora index + Returns: + indices: tensor [#lora token groups] + mapping each group to the lora index + repeats: tensor [#lora token groups + 1] a cumulative sum of the + number of tokens for each lora token group, with 0 at the beginning + """ + unique, repeats = indices.unique_consecutive(return_counts=True) + num_chunks = (repeats + MAX_REPEATS_PER_BLOCK - 1) // MAX_REPEATS_PER_BLOCK + indices = unique.repeat_interleave(num_chunks) + overcount = num_chunks * MAX_REPEATS_PER_BLOCK - repeats + repeats = torch.full_like(indices, MAX_REPEATS_PER_BLOCK) + repeats[num_chunks.cumsum(dim=0) - 1] -= overcount + repeats_ = torch.zeros((indices.shape[0] + 1, ), + device='cuda', + dtype=torch.long) + repeats_[1:] = repeats.cumsum(dim=0) + return indices, repeats_ + + def convert_mapping( mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], + lora_id_to_r: Dict[int, int], max_loras: int, vocab_size: int, extra_vocab_size: int, long_lora_context: Optional[LongContextLoRAContext] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: +) -> Tuple[torch.Tensor, torch.Tensor, List[int], torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int]]: """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. lora_index_to_id: List mapping LoRA ids to LoRA indices. 
+ lora_id_to_r: Dict mapping LoRA ids to LoRA rank max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. @@ -61,6 +98,12 @@ def convert_mapping( A tuple of tensors: base_indices: Tensor of shape [batch_size] mapping batch rows to LoRA indices. + repeats: Tensor [#lora token groups + 1] a cumulative + sum of the number of tokens for each lora token group, + with 0 at the beginning + max_repeats: List[1] caching the maximum of repeats, + so it isn't recalculated at every layer + ranks: Tensor [max_num_loras] mapping LoRA index to rank sampler_indices: Tensor of shape [batch_size] mapping requests to LoRA indices for sampler. For generation, this will be the same as base_indicies. For prefill, this will map requests @@ -85,6 +128,7 @@ def convert_mapping( index_mapping_indices: List[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() + lora_ranks = [8] * len(lora_index_to_id) long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: long_lora_offsets = torch.zeros(len(index_mapping_indices), @@ -101,6 +145,8 @@ def convert_mapping( if index_mapping_indices[i] > 0 else -1) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx + if lora_idx > 0: + lora_ranks[lora_idx] = lora_id_to_r[index_mapping_indices[i]] if long_lora_context: assert long_lora_offsets is not None lora_offset: int = long_lora_context.offsets_by_lora_id.get( @@ -114,6 +160,7 @@ def convert_mapping( assert long_lora_offsets is not None indices_list.append(long_lora_offsets) indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + ranks = torch.tensor(lora_ranks, device='cuda', dtype=torch.long) prompt_mapping_tensor = torch.tensor(prompt_mapping, device="cuda", dtype=torch.long) @@ -122,7 +169,7 @@ def convert_mapping( indices[2] * (vocab_size + extra_vocab_size) ]) embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] + base_indices, repeats = group_and_chunk_indices(indices[1]) sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 @@ -143,8 +190,11 @@ def convert_mapping( if long_lora_indices_len is not None: indices_len.append(long_lora_indices_len) - return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices, indices_len) + max_repeats = ((repeats[1:] - + repeats[:-1]).max().item() if repeats.numel() > 1 else 0) + return (base_indices, repeats, [max_repeats], ranks, sampler_indices, + sampler_indices_padded, embeddings_indices, long_lora_indices, + indices_len) def get_lora_id(): @@ -415,11 +465,19 @@ def __init__( assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.lora_id_to_r: Dict[int, int] = {} self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") + self.repeats = torch.empty(self.max_num_batched_tokens + 1, + dtype=torch.long, + device='cuda') + self.max_repeats = [1] + self.ranks = torch.empty(self.lora_slots, + dtype=torch.long, + device="cuda") self.sampler_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") @@ -502,6 
+560,8 @@ def _deactivate_lora(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) self.lora_index_to_id[index] = None + if lora_id in self.lora_id_to_r: + del self.lora_id_to_r[lora_id] except ValueError: pass @@ -531,6 +591,7 @@ def _set_long_lora_context(self, lora: LoRAModel): def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora + self.lora_id_to_r[lora.id] = lora.rank self._set_long_lora_context(lora) def add_lora(self, lora: LoRAModel) -> bool: @@ -562,13 +623,17 @@ def pin_lora(self, lora_id: int) -> bool: # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_offsets_tensor, + (base_indices, repeats, max_repeats, ranks, sampler_indices, + sampler_indices_padded, embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, - self.lora_slots + 1, self.vocab_size, + self.lora_id_to_r, self.lora_slots + 1, + self.vocab_size, self.lora_config.lora_extra_vocab_size, self.long_lora_context) self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.repeats[:repeats.shape[0]].copy_(repeats) + self.max_repeats[:] = max_repeats + self.ranks.copy_(ranks) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( sampler_indices_padded) @@ -631,7 +696,9 @@ def _create_lora_modules(self): self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) - new_module.set_mapping(self.base_indices, self.sampler_indices, + new_module.set_mapping(self.base_indices, self.repeats, + self.max_repeats, self.ranks, + self.sampler_indices, self.sampler_indices_padded, self.embeddings_indices, self.long_lora_indices, self.indices_len) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py deleted file mode 100644 index 64f87a4b2c69..000000000000 --- a/vllm/lora/punica.py +++ /dev/null @@ -1,207 +0,0 @@ -# Based on code from https://github.com/punica-ai/punica - -from typing import Optional - -import torch - -from vllm import _custom_ops as ops -from vllm.platforms import current_platform - - -def _check_punica_support(): - if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): - return - - if current_platform.get_device_capability() < (8, 0): - raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") - else: - raise ImportError( - "punica LoRA kernels could not be imported. If you built vLLM " - "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") - - -def bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, -): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight - matrices. - indicies: Shape: `[B]`. Indices of the weight matrices. - layer_idx: Layer index of the weight matrices. - scale: Scaling factor. 
- """ - _check_punica_support() - - ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) - - -def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, indicies: torch.LongTensor, - layer_idx: int, scale: float, y_offset: int, - y_slice_size: int): - """ - Same as `bgmv` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. - - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of - all of the transposed LoRA matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. - """ - _check_punica_support() - - ops.dispatch_bgmv_low_level( - y, - x, - w_t_all, - indicies, - layer_idx, - scale, - x.size(1), - y_slice_size, - y_offset, - ) - - -def add_lora(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - *, - buffer: Optional[torch.Tensor] = None): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - buffer: Optional. Shape: `[B, R]`. Temporary buffer. - """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) - - -def add_lora_slice(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - y_offset: int, - y_slice_size: int, - *, - buffer: Optional[torch.Tensor] = None): - """ - Same as `add_lora` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. - - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. 
- """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv_low_level( - buffer, - x, - wa_t_all, - indicies, - layer_idx, - 1.0, - x.size(1), - buffer.size(1), - 0, - ) - ops.dispatch_bgmv_low_level( - y, - buffer, - wb_t_all, - indicies, - layer_idx, - scale, - buffer.size(1), - y_slice_size, - y_offset, - ) diff --git a/vllm/lora/sgmv.py b/vllm/lora/sgmv.py new file mode 100644 index 000000000000..abdaeaf06b49 --- /dev/null +++ b/vllm/lora/sgmv.py @@ -0,0 +1,99 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +from vllm.model_executor.layers.lora import sgmv_triton + + +def sgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indices: torch.LongTensor, + ranks: torch.LongTensor, + repeats: torch.LongTensor, + max_repeats: int, + out_col_offset: int = 0, + scale: float = 1.0, +): + """ + Semantics: + y[i, out_col_offset : out_col_offset + w_t_all.shape[2]] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], 0, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indices: Shape: `[B]`. Indices of the weight matrices. + ranks: rank of the LoRA for each group of 32 tokens + remainder + the LoRA applies to for each LoRA in the batch + repeats: similar to ranks, but number of tokens for the LoRA group + max_repeats: repeats.max(), just cached so it isn't recomputed + out_col_offset: for sgmv_expand/LoRA B, offset output along hidden out + scale: Scaling factor. + """ + h_out, h_in = w_t_all.shape[-2:] + if h_out <= h_in: + sgmv_triton.sgmv_shrink(x, w_t_all, y, ranks, indices, repeats, + max_repeats) + else: + sgmv_triton.sgmv_expand(x, w_t_all, y, ranks, indices, repeats, + max_repeats, out_col_offset, scale) + + +def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indices: torch.LongTensor, + ranks: torch.LongTensor, + repeats: torch.LongTensor, + max_repeats: int, + out_col_offset: int = 0, + scale: float = 1.0, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indices: Shape: `[B]`. Indices of the LoRA weights. + ranks: rank of the LoRA for each group of 32 tokens + remainder + the LoRA applies to for each LoRA in the batch + repeats: similar to ranks, but number of tokens for the LoRA group + max_repeats: repeats.max(), just cached so it isn't recomputed + out_col_offset: for sgmv_expand/LoRA B, offset output along hidden out + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. 
+ """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + sgmv( # LoRA A shrink + buffer, x, wa_t_all, indices, ranks, repeats, max_repeats) + sgmv( # LoRA B expand + y, buffer, wb_t_all, indices, ranks, repeats, max_repeats, + out_col_offset, scale) diff --git a/vllm/model_executor/layers/lora/__init__.py b/vllm/model_executor/layers/lora/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/lora/sgmv_triton.py b/vllm/model_executor/layers/lora/sgmv_triton.py new file mode 100644 index 000000000000..9169703db947 --- /dev/null +++ b/vllm/model_executor/layers/lora/sgmv_triton.py @@ -0,0 +1,311 @@ +import torch +import triton +import triton.language as tl + +# generally faster than 16, but can be lowered to 16 to reduce the +# shared memory required by the kernel. +MAX_REPEATS_PER_BLOCK = 32 + + +@triton.autotune(configs=[ + triton.Config({'BLOCK_SIZE_H_IN': 32}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_IN': 32}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_IN': 32}, num_warps=8), + triton.Config({'BLOCK_SIZE_H_IN': 64}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_IN': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_IN': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_H_IN': 128}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_IN': 128}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_IN': 128}, num_warps=8), +], + key=['R', 'H', 'BLOCK_SIZE_INPUT_PER_LORA'], + restore_value=['o_ptr']) +@triton.jit +def sgmv_shrink_multi_lora_rank( + # Same arguments as below, some renamed + x_ptr, + w_ptr, + o_ptr, + ranks, + indices, + repeats, + S, + R: tl.constexpr, + H, + stride_xs, + stride_xh, + stride_wl, + stride_wr, + stride_wh, + stride_os, + stride_or, + # Meta-parameters + BLOCK_SIZE_INPUT_PER_LORA: tl.constexpr, + BLOCK_SIZE_H_IN: tl.constexpr): + """ + The shrink side of the lora, very similar implementation to expand, but + uses the split-k strategy as in punica. 
+ """ + # grid will be [num_unique, h out // block size h out] + h_id, lora_id = tl.program_id(axis=0), tl.program_id(axis=1) + idx = tl.load(indices + lora_id) + if idx < 0: + return + rank = tl.load(ranks + idx) + repeats_0, repeats_1 = (tl.load(repeats + lora_id), + tl.load(repeats + lora_id + 1)) + + n_inputs = repeats_1 - repeats_0 + input_range = tl.arange(0, BLOCK_SIZE_INPUT_PER_LORA) + offs_xs = repeats_0 + input_range + rank_range = tl.arange(0, R) + offs_h = h_id * BLOCK_SIZE_H_IN + tl.arange(0, BLOCK_SIZE_H_IN) + offs_os = offs_xs + + w_ptrs = (w_ptr + idx * stride_wl + offs_h[:, None] * stride_wh + + rank_range[None, :] * stride_wr) + w = tl.load(w_ptrs, + mask=(offs_h[:, None] < H) & (rank_range[None, :] < rank), + other=0.0).to(dtype=tl.float32) # [H_OUT, R] + + # tl.dot works only on sizes >= 16 + if BLOCK_SIZE_INPUT_PER_LORA >= 16 and R >= 16: + x_ptrs = (x_ptr + offs_xs[:, None] * stride_xs + + offs_h[None, :] * stride_xh) + # [next pow 2 inputs for this lora, R] + x = tl.load(x_ptrs, + mask=(input_range[:, None] < n_inputs) & + (offs_h[None, :] < H), + other=0.0).to(dtype=tl.float32) + + o_ptrs = (o_ptr + offs_os[:, None] * stride_os + + rank_range[None, :] * stride_or) + tl.atomic_add(o_ptrs, + tl.dot(x, w), + mask=(input_range[:, None] < n_inputs) & + (rank_range[None, :] < rank)) + else: + for i in range(n_inputs): + x_ptrs = x_ptr + (repeats_0 + i) * stride_xs + offs_h * stride_xh + o_ptrs = (o_ptr + (repeats_0 + i) * stride_os + + rank_range * stride_or) + x = tl.load(x_ptrs, mask=offs_h < H, + other=0.0).to(dtype=tl.float32) + tl.atomic_add(o_ptrs, + tl.sum(x[:, None] * w, axis=0), + mask=rank_range < rank) + + +@torch.inference_mode() +def sgmv_shrink(x, weights, out, ranks, indices, repeats, max_repeats): + ''' + weights shape: (max_loras, 1, out, in) + Tokens for a LoRA (repeats) should be split into groups of + MAX_REPEATS_PER_BLOCK for load balancing and shared memory constraints. + This should be done at the beginning of the forward pass, so it isn't + repeated every call. 
+ + max rank in ranks should not be greater than buffer.shape[2] + weights.shape[-2] shouldn't be larger than out.shape[-1] (hidden dim) + buffer.shape[0] == out.shape[0] (sequence length) + + buffer, weights and out should be contiguous + ''' + S, H = x.shape + R = out.shape[-1] + + BLOCK_SIZE_INPUT_PER_LORA = triton.next_power_of_2(max_repeats) + grid = lambda META: (triton.cdiv(H, META['BLOCK_SIZE_H_IN']), len(repeats) + - 1) + sgmv_shrink_multi_lora_rank[grid]( + x, + weights, + out, + ranks, + indices, + repeats, + S, + R, + H, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(2), + weights.stride(3), + out.stride(0), + out.stride(1), + BLOCK_SIZE_INPUT_PER_LORA=BLOCK_SIZE_INPUT_PER_LORA, + ) + return out + + +@triton.autotune(configs=[ + triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=8), + triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=1), + triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=4), + triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=8), +], + key=['R', 'H', 'BLOCK_SIZE_INPUT_PER_LORA'], + restore_value=['o_ptr']) +@triton.jit +def sgmv_expand_multi_lora_rank( + # NOTE: Inputs MUST be grouped by lora + # Pointers to buffer, weight page and output respectively + b_ptr, + w_ptr, + o_ptr, + # indices a tensor of [num unique loras in seq] + # repeats [num unique loras in seq + 1] + # indices contains, for each group of inputs, the unique lora idx + # repeats, r such that sum(r)=seq_length, repeats=cumsum(r). + # Cumulative sum of how many inputs are using the same lora, + # starting at 0 + ranks, + indices, + repeats, + # optional output column offset + out_col_offset, + scale, + # Dimensions, sequence length/batch, max rank, hidden out + S, + R: tl.constexpr, + H, + # row, col stride for each + stride_bs, + stride_br, + stride_wl, + stride_wh, + stride_wr, + stride_os, + stride_oh, + # Meta-parameters + BLOCK_SIZE_INPUT_PER_LORA: tl.constexpr, + BLOCK_SIZE_H_OUT: tl.constexpr): + """ + The punica expand kernel in Triton. Can take advantage of tl.dot() for + increased speed when the rank and number of inputs are larger than 16. + i.e. 
prefill or grouped + """ + # grid will be [num_unique, h out // block size h out] + h_id, lora_id = tl.program_id(axis=0), tl.program_id(axis=1) + idx = tl.load(indices + lora_id) + if idx < 0: + return + rank = tl.load(ranks + idx) + repeats_0, repeats_1 = (tl.load(repeats + lora_id), + tl.load(repeats + lora_id + 1)) + + n_inputs = repeats_1 - repeats_0 + input_range = tl.arange(0, BLOCK_SIZE_INPUT_PER_LORA) + offs_bs = repeats_0 + input_range + rank_range = tl.arange(0, R) + offs_wh = h_id * BLOCK_SIZE_H_OUT + tl.arange(0, BLOCK_SIZE_H_OUT) + + # compare transpose after vs transpose ptrs + w_ptrs = (w_ptr + idx * stride_wl + rank_range[:, None] * stride_wr + + offs_wh[None, :] * stride_wh) + + offs_os = offs_bs + offs_oh = offs_wh + + w = tl.load(w_ptrs, + mask=(rank_range[:, None] < rank) & (offs_wh[None, :] < H), + other=0.0).to(dtype=tl.float32) # [R, H_OUT] + + # tl.dot works only on sizes >= 16 + if BLOCK_SIZE_INPUT_PER_LORA >= 16 and R >= 16: + b_ptrs = (b_ptr + offs_bs[:, None] * stride_bs + + rank_range[None, :] * stride_br) + buffer = tl.load(b_ptrs, + mask=(input_range[:, None] < n_inputs) & + (rank_range[None, :] < rank), + other=0.0) # [next pow 2 inputs for this lora, R] + buffer *= scale + + o_ptrs = (o_ptr + offs_os[:, None] * stride_os + + (offs_oh[None, :] + out_col_offset) * stride_oh) + accumulator = tl.load(o_ptrs, + mask=(input_range[:, None] < n_inputs) & + (offs_oh[None, :] < H), + other=0.0).to(dtype=tl.float32) + accumulator += tl.dot(buffer, w) + + tl.store(o_ptrs, + accumulator, + mask=(input_range[:, None] < n_inputs) & + (offs_oh[None, :] < H)) + else: + for i in range(n_inputs): + b_ptrs = b_ptr + (repeats_0 + + i) * stride_bs + rank_range * stride_br + o_ptrs = (o_ptr + (repeats_0 + i) * stride_os + + (offs_oh + out_col_offset) * stride_oh) + out = tl.load(o_ptrs, mask=offs_oh < H, + other=0.0).to(dtype=tl.float32) + buffer = tl.load(b_ptrs, mask=rank_range < rank, + other=0.0).to(dtype=tl.float32) + buffer *= scale + + out += tl.sum(buffer[:, None] * w, axis=0) + tl.store(o_ptrs, out, mask=offs_oh < H) + + +@torch.inference_mode() +def sgmv_expand(buffer, + weights, + out, + ranks, + indices, + repeats, + max_repeats, + out_col_offset=0, + scale=1.0): + ''' + weights shape: (max_loras, 1, out, in) + Tokens for a LoRA (repeats) should be split into groups of + MAX_REPEATS_PER_BLOCK for load balancing and shared memory constraints. + This should be done at the beginning of the forward pass, so it isn't + repeated every call. + + max rank in ranks should not be greater than buffer.shape[2] + buffer.shape[0] == out.shape[0] (sequence length) + out_col_offset + weights.shape[-2] can't be greater than out.shape[-1] + + buffer, weights and out should be contiguous + ''' + assert out_col_offset + weights.shape[-1] <= out.shape[-1], ( + f"Output column offset {out_col_offset} with output dim " + + f"{weights.shape[-1]} is too high for given output tensor {out.shape}") + S, R = buffer.shape + H = weights.shape[-2] + + BLOCK_SIZE_INPUT_PER_LORA = triton.next_power_of_2(max_repeats) + grid = lambda META: (triton.cdiv(H, META['BLOCK_SIZE_H_OUT']), len(repeats) + - 1) + sgmv_expand_multi_lora_rank[grid]( + buffer, + weights, + out, + ranks, + indices, + repeats, + out_col_offset, + scale, + S, + R, + H, + buffer.stride(0), + buffer.stride(1), + weights.stride(0), + weights.stride(2), + weights.stride(3), + out.stride(0), + out.stride(1), + BLOCK_SIZE_INPUT_PER_LORA=BLOCK_SIZE_INPUT_PER_LORA, + ) + return out
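As a closing illustration of `out_col_offset` in `sgmv_expand` (a minimal sketch with assumed shapes, echoing the column-slice semantics the removed `add_lora_slice` helper used to provide; the packed output layout here is only an example):

    import torch
    from vllm.model_executor.layers.lora import sgmv_triton as sgmv

    # One rank-16 LoRA applied to 8 tokens; its expand output lands in
    # columns [1024, 1024 + 512) of a wider, packed output tensor.
    buffer = torch.randn(8, 16, dtype=torch.float32, device='cuda')
    weights = torch.randn(1, 1, 512, 16, dtype=torch.bfloat16, device='cuda')
    out = torch.zeros(8, 2048, dtype=torch.bfloat16, device='cuda')
    ranks = torch.tensor([16], dtype=torch.int32, device='cuda')
    indices = torch.tensor([0], dtype=torch.long, device='cuda')
    repeats = torch.tensor([0, 8], dtype=torch.int32, device='cuda')
    sgmv.sgmv_expand(buffer, weights, out, ranks, indices, repeats,
                     max_repeats=8, out_col_offset=1024, scale=1.0)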