diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 07b898787eba..cae1bffe6a3a 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -50,6 +50,9 @@ docker run --privileged --net host --shm-size=16G -it \ && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \ && echo TEST_12 \ && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \ + # Disable the TPU LoRA tests until the feature is activated + # && echo TEST_13 \ + # && pytest -s -v /workspace/vllm/tests/tpu/lora/" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dc433f9dad26..b940f7190bb2 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -47,7 +47,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, diff --git a/tests/tpu/lora/__init__.py b/tests/tpu/lora/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py new file mode 100644 index 000000000000..21d7fce691c9 --- /dev/null +++ b/tests/tpu/lora/test_lora.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +# This file contains tests to ensure that LoRA works correctly on the TPU +# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct +# for this. The adapters are: +# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges +# from 1 to 4. + +# These adapters are trained using a standard huggingface peft training script, +# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run +# 100 training iterations with a training batch size of 100. + + +@pytest.fixture(scope="function", autouse=True) +def use_v1_only(monkeypatch: pytest.MonkeyPatch): + """ + Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1 + for all tests in this file + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + yield + + +def setup_vllm(num_loras: int) -> vllm.LLM: + return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=num_loras, + max_lora_rank=8) + + +def test_single_lora(): + """ + This test ensures we can run a single LoRA adapter on the TPU backend. + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=1. + """ + + llm = setup_vllm(1) + + prompt = "What is 1+1? \n" + + lora_request = LoRARequest( + "lora_adapter_1", 1, + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter") + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, + temperature=0), + lora_request=lora_request)[0].outputs[0].text + + answer = output.strip()[0] + + assert answer.isdigit() + assert int(answer) == 1 + + +def test_lora_hotswapping(): + """ + This test ensures we can run multiple LoRA adapters on the TPU backend, even + if we only have space to store 1. + + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. 
+ """ + + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = setup_vllm(1) + + prompt = "What is 1+1? \n" + + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + answer = output.strip()[0] + + assert answer.isdigit() + assert int(answer) == i + 1 + + +def test_multi_lora(): + """ + This test ensures we can run multiple LoRA adapters on the TPU backend, when + we have enough space to store all of them. + + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. + """ + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = setup_vllm(4) + + prompt = "What is 1+1? \n" + + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + + answer = output.strip()[0] + + assert answer.isdigit() + assert int(output.strip()[0]) == i + 1 diff --git a/tests/tpu/lora/test_pallas_kernels.py b/tests/tpu/lora/test_pallas_kernels.py new file mode 100644 index 000000000000..8bd47de50c34 --- /dev/null +++ b/tests/tpu/lora/test_pallas_kernels.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + +N_TOKENS = [16, 1024, 4096] +HIDDEN_SIZES = [1024, 2048, 4096] + +DTYPES = [torch.bfloat16] +NUM_LORA = [1, 4, 16] +RANKS = [32, 256, 512] + + +def generate_test_data(T, D, L, N, seed, dtype=torch.float32): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: torch.Tensor - shape (T, D) + loras: torch.Tensor - shape (N, 1, L, D) + idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) + + ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T + """ + torch.manual_seed(seed) + + inputs = torch.randn((T, D), device="xla", dtype=dtype) + loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) + idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") + + ref_output = ref_bgmv(inputs, loras, idxs) + return inputs, loras, idxs, ref_output + + +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): + selected_loras = loras[idxs] + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(axis=1) + + batch_size, output_size, input_size = selected_loras.shape + return (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", [0]) +def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): + if op_type == "expand": + D, L = L, D + + inputs, loras, idxs, ref_output = generate_test_data( + 
T, D, L, N, seed, dtype) + + # Run bgmv + output = torch.ops.xla.bgmv(inputs, loras, idxs) + + # Make sure we have no NaNs + assert not torch.any(torch.isnan(output)) + + # Compare with reference output + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) diff --git a/vllm/config.py b/vllm/config.py index 11e4e500aa09..0bbf588fb3e8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2694,8 +2694,8 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 """Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).""" - # This is a constant. - lora_vocab_padding_size: ClassVar[int] = 256 + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() long_lora_scaling_factors: Optional[tuple[float, ...]] = None """Specify multiple scaling factors (which can be different from base model scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters @@ -2723,6 +2723,7 @@ def compute_hash(self) -> str: factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) factors.append(self.long_lora_scaling_factors) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 41e1ec94145d..e195f8cf5e8e 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -16,6 +16,7 @@ MergedQKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) +from vllm.platforms import current_platform if TYPE_CHECKING: pass @@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): device=x.device, ) - layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + buffers = tensor_model_parallel_all_gather(buffers) - layer.punica_wrapper.add_expand(output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -292,7 +303,11 @@ def apply(self, device=x.device, ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -304,7 +319,7 @@ def apply(self, # NOTE offset are based on the rank. 
shard_size = self.lora_b_stacked[0].shape[2] offset_start = self.tp_rank * shard_size - self.punica_wrapper.add_expand( + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( output, buffer, self.lora_b_stacked, @@ -313,6 +328,10 @@ def apply(self, offset_start=offset_start, add_input=True, ) + + if not current_platform.can_update_inplace(): + output = lora_output + output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index d9de0f3cfeb3..6749ec16a097 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -261,10 +261,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - self.punica_wrapper.add_lora_embedding(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + return full_output.view_as(full_output_org) @classmethod @@ -410,10 +417,13 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + return output @property @@ -1133,15 +1143,23 @@ def _get_logits( torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - lora_logits[-1] = float("-inf") + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu(): + indices_padded = indices_padded[:logits.size(0)] + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): @@ -1151,10 +1169,13 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - # LogitsProcessorWithLoRA always using bgmv - self.punica_wrapper.add_lora_logits(logits, hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output # Remove paddings in vocab (if any). 
logits = logits[:, :self.base_layer.vocab_size] diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 000000000000..94062b05d916 --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) + +__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 000000000000..acbec0cfab9c --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + n_tokens = outputs.size(0) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + outputs = torch.cat( + (outputs, + torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), + device=outputs.device)), + dim=1) + + if add_inputs: + return output_tensor + outputs[:limit, :] + else: + return outputs[:limit, :] + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + output_tensor (torch.Tensor): (Unused) output tensor (placeholder). + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + scaling (float, optional): Scalar multiplier applied to the output. + """ + + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, + lora_indices_tensor) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. 
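        slice_offset (int): Column offset in output_tensor at which this
            slice's output is written.
        slice_size (int): Number of output columns produced for this slice.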
+ """ + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + n_tokens = outputs.size(0) + + outputs = torch.cat(( + torch.zeros((n_tokens, slice_offset), device=outputs.device), + outputs, + torch.zeros( + (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), + device=outputs.device), + ), + dim=1) + + if add_inputs: + return output_tensor + outputs + else: + return outputs diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py new file mode 100644 index 000000000000..35dc307539bf --- /dev/null +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +import functools + +import jax +import jax.numpy as jnp +import torch +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from torch.library import impl +from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, + make_kernel_from_pallas) + +# TODO: Tune these +TOKENS_BLOCK = 16 +LORA_RANK_BLOCK = 128 +DIM_BLOCK_SIZE = 128 + + +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, + acc_ref, mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def _bgmv( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, L, D) model dtype +) -> jax.Array: # (T, L) model dtype + T, D = inputs.shape + N, L, _ = loras.shape + + return pl.pallas_call( + kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK, + D // DIM_BLOCK_SIZE), + in_specs=[ + pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (0, j, k)), + ], + out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32), + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv")(idxs, inputs, loras) + + +def bgmv_shape_function(idxs, inputs, loras): + T, _ = inputs.shape + _, L, _ = loras.shape + + return [((T, L), inputs.dtype)] + + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) + + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + jax_import_guard() + kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + + T, _ = inputs.shape + _, L, D = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs + L1 = L + if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0: + L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK + + D1 = D + if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: + D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE + + T1 = T + if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: + T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK + + if D1 != D or L1 != L: + loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) + if D1 != D or T1 != T: + inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) + if T1 != T: + idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) + + return kernel(idxs, inputs, loras)[:T, :L] + + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 94fa3f27ab60..78866c51895b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -48,7 +48,7 @@ def add_shrink( lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. """ @@ -66,7 +66,7 @@ def add_expand( offset_start: int = 0, add_inputs=True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. """ @@ -80,7 +80,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. @@ -98,7 +98,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. """ @@ -114,7 +114,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ @@ -207,7 +207,8 @@ def _update_base_metadata( self._long_lora_indices.zero_() self.indices_len[:] = indices_len - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, token_nums, @@ -334,7 +335,7 @@ def update_metadata( long_lora_context) if mapping.is_prefill: # Update metadata required for prefill-related operators. - self._update_prefill_metada(self.token_lora_indices) + self._update_prefill_metadata(self.token_lora_indices) self.is_prefill = True else: self.is_prefill = False @@ -342,7 +343,7 @@ def update_metadata( @abstractmethod def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs) -> None: + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. 
@@ -369,7 +370,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -401,7 +402,7 @@ def add_lora_embedding(self, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -428,7 +429,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -463,7 +464,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 000000000000..37544c755d90 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink + +from .punica_base import PunicaWrapperBase + + +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which + # isn't supported by the TPU. So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. + self._token_lora_indices = self._token_lora_indices.to( + dtype=torch.int32) + self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to( + dtype=torch.int32) + + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + + def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: + return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + return self._embeddings_indices[:] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. 
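        Entries for requests without an active LoRA are mapped to the last
        LoRA slot (max_loras - 1) by convert_mapping; on TPU,
        LogitsProcessorWithLoRA additionally truncates these indices to the
        logits batch size before indexing.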
+ """ + return self._sampler_indices_padded[:] + + def shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + if self.no_lora: + return y + return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), + scale) + + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) + + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + torch.ops.xla.dynamo_set_buffer_donor_(y, True) + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): + y_s = y[slice_idx] + lora_s = lora_a_stacked[slice_idx] + y_s = self.shrink(y_s, x, lora_s, scale) + y[slice_idx, :, :] = y_s # type: ignore[index] + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> torch.Tensor: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + if lora_bias_stacked is not None: + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + y = self.expand_slice( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + y_total_size=sum(output_slices), + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + return y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + + # Embedding layer only needs the expand op + return self.expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> torch.Tensor: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will not be changed in-place. + x (torch.Tensor): Input tensor (T, E) + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + T = x.size(0) + buffer = torch.zeros( + (len(output_slices), T, r), + dtype=torch.float32, + device=x.device, + ) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. 
+ """ + if self.no_lora: + return y + + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, + scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + return y.view_as(y_org) + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias = torch.where(indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view_as(org_output) + + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self.batch_size].copy_( + token_lora_tensor[:self.batch_size]) + # TODO: .item() is extremely inefficient on TPU, so find a way around it + self.no_lora = torch.all(token_lora_tensor == -1).item() diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index dbc2d27c597f..f4e5542b177d 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,13 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 + embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, + embeddings_indices) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.where(sampler_indices_padded == -1, + max_loras - 1, sampler_indices_padded) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5df0e9d3d072..1ea99b7b2c3f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -332,6 +332,27 @@ def get_punica_wrapper(cls) -> str: """ raise NotImplementedError + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + """ + Return the platform specific values for (-inf, inf) + """ + return float("-inf"), float("inf") + + @classmethod + def can_update_inplace(cls) -> bool: + """ + Checks if the platform allows inplace memory updates + """ + return True + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + """ + Returns how much padding the LoRA logits need for kernels + """ + return 256 + @classmethod def 
get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8c968e7df3ef..2782a3866d76 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch from tpu_info import device @@ -67,6 +67,22 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + + @classmethod + def can_update_inplace(cls): + return False + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 1 + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f5626abb2a12..be059c30435c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import sanity_check_mm_encoder_outputs @@ -90,7 +91,7 @@ # The dummy_run should be comprehensive, ensuring all potential input shapes and # branch predictions are included as subgraph inputs to facilitate # pre-compilation. -class TPUModelRunner: +class TPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -568,6 +569,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device) seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) + attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -907,6 +919,11 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) + if self.lora_config is not None: + model = self.load_lora_model(model, self.model_config, + self.scheduler_config, + self.lora_config, self.device) + # Sync all pending XLA execution during model initialization and weight # loading. 
xm.mark_step() @@ -970,7 +987,10 @@ def _dummy_run(self, num_tokens: int) -> None: for layer_name in layer_names } - with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0): + with self.maybe_dummy_run_with_lora( + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + per_layer_attn_metadata, self.vllm_config, 0): out = self.model(input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index de676541effa..9eea26d85249 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -15,6 +15,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput @@ -82,6 +83,10 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V1 TPU backend doesn't support LoRA serving") + def init_device(self): os.environ["PJRT_DEVICE"] = "TPU" # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D @@ -211,6 +216,9 @@ def profile(self, is_start: bool = True): else: xp.stop_trace() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def load_model(self) -> None: self.model_runner.load_model() diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index bbcc4d59ae1c..4bb9bea022f9 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -54,6 +54,10 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V0 TPU backend doesn't support LoRA serving") + def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" torch.set_grad_enabled(False)
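For reference, a minimal end-to-end sketch of how the LoRA path added in this patch is driven on the v1 TPU backend, adapted from the new tests/tpu/lora/test_lora.py (the model name, adapter name, and engine arguments are the ones used in those tests; running it assumes a TPU host with VLLM_USE_V1=1):

# Sketch only: mirrors setup_vllm()/test_single_lora() from the new tests.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
               max_model_len=256,
               max_num_seqs=8,
               enable_lora=True,
               max_loras=1,
               max_lora_rank=8)

# This adapter was trained to make the model answer "1" to "What is 1+1?".
request = LoRARequest(
    "lora_adapter_1", 1,
    "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")

outputs = llm.generate("What is 1+1? \n",
                       sampling_params=vllm.SamplingParams(max_tokens=256,
                                                           temperature=0),
                       lora_request=request)
print(outputs[0].outputs[0].text)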