diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 03e2267a1b4e..4cc9c70a6adb 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -57,6 +57,7 @@ steps: agents: queue: tpu_queue_postmerge commands: + - "yes | docker system prune -a" - "git fetch --all" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ." - "docker push vllm/vllm-tpu:nightly" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d3c07cdda454..84ee991f5659 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -293,6 +293,7 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile @@ -302,6 +303,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - label: PyTorch Fullgraph Smoke Test # 9min + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile @@ -312,6 +314,7 @@ steps: - pytest -v -s compile/piecewise/test_toy_llama.py - label: PyTorch Fullgraph Test # 18min + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile @@ -436,6 +439,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min + torch_nightly: true source_file_dependencies: - vllm/ - tests/models diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90ed492d992a..5ecd7b70ea54 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,7 +46,7 @@ repos: rev: 0.6.17 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128] files: ^requirements/test\.(in|txt)$ - repo: local hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 72740279d0e0..d530646cd78b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") - message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") @@ -250,9 +249,8 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") - # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - # Please keep this in sync with FetchContent_Declare line below. - set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use") + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "v3.9.1" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -270,7 +268,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.9.0 + GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. 
@@ -682,6 +680,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_PERMUTE_SRC}" + CUDA_ARCHS "${MOE_PERMUTE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") +endif() message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C @@ -690,6 +699,8 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index bcdbf6c7551a..c92ea43e8260 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str, score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, renormalize=False) def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a274537a6751..9407747f7843 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -115,8 +115,8 @@ def run(): from vllm.model_executor.layers.fused_moe import override_config with override_config(config): if use_deep_gemm: - topk_weights, topk_ids = fused_topk(x, input_gating, topk, - False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, False) return fused_experts( x, w1, @@ -442,8 +442,14 @@ def tune( hidden_size, search_space, is_fp16, topk) - with torch.cuda.device(self.device_id) if current_platform.is_rocm( - ) else nullcontext(): + need_device_guard = False + if current_platform.is_rocm(): + visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None) + if visible_device != f"{self.device_id}": + need_device_guard = True + + with torch.cuda.device( + self.device_id) if need_device_guard else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config( @@ -578,6 +584,15 @@ def main(args: argparse.Namespace): use_deep_gemm = bool(args.use_deep_gemm) + if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ: + # Ray will set ROCR_VISIBLE_DEVICES for device visibility + logger.warning( + "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." 
+ "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.") + val = os.environ["HIP_VISIBLE_DEVICES"] + os.environ["ROCR_VISIBLE_DEVICES"] = val + del os.environ["HIP_VISIBLE_DEVICES"] + ray.init() num_gpus = int(ray.available_resources()["GPU"]) workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py new file mode 100644 index 000000000000..937df9624651 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +from typing import Any, TypedDict + +import ray +import torch +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _moe_permute, _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +FP8_DTYPE = current_platform.fp8_dtype() + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_permute(num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + # output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + gating_output = torch.randn(num_iters, + num_tokens, + num_experts, + dtype=torch.float32) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + else: + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, + num_experts, None, align_block_size) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = 
sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def benchmark_unpermute(num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False) + + def prepare(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + # convert to fp16/bf16 as gemm output + return (permuted_hidden_states.to(dtype), first_token_off, + inv_perm_idx, m_indices) + else: + (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, + num_experts, None, align_block_size) + # convert to fp16/bf16 as gemm output + return (permuted_qhidden_states.to(dtype), a1q_scale, + sorted_token_ids, expert_ids, inv_perm) + + def run(input: tuple): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = input + moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, + inv_perm_idx, first_token_off, topk, num_experts, + num_experts) + else: + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = input + _moe_unpermute_and_reduce(output_hidden_states, + permuted_hidden_states, inv_perm, + topk_weights) + + # JIT compilation & warmup + input = prepare() + run(input) + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run(input) + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(seed) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. 
+ self.device_id = int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_customized_permute: bool = False, + ) -> tuple[dict[str, int], float]: + current_platform.seed_everything(self.seed) + + permute_time = benchmark_permute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute) + unpermute_time = benchmark_unpermute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute) + return permute_time, unpermute_time + + +def get_weight_block_size_safety(config, default_value=None): + + quantization_config = getattr(config, 'quantization_config', {}) + if isinstance(quantization_config, dict): + return quantization_config.get('weight_block_size', default_value) + return default_value + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + elif (config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM"): + E = config.n_routed_experts + topk = config.num_experts_per_tok + elif config.architectures[0] in [ + "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" + ]: + E = config.num_experts + topk = config.num_experts_per_tok + + else: + # Support for llama4 + config = config.get_text_config() + # Default: Mixtral. 
+ E = config.num_local_experts + topk = config.num_experts_per_tok + + hidden_size = config.hidden_size + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + use_customized_permute = args.use_customized_permute + + if args.batch_size is None: + batch_sizes = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: list[Any]) -> list[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + outputs = _distribute( + "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, + use_int8_w8a16, use_customized_permute) + for batch_size in batch_sizes]) + + for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}") + print(f"Permute time: {permute:.2f} us") + print(f"Unpermute time: {unpermute:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--dtype", + type=str, + choices=["auto", "fp8_w8a8", "int8_w8a16"], + default="auto") + parser.add_argument("--use-customized-permute", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu new file mode 100644 index 000000000000..76d5f0eab021 --- /dev/null +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -0,0 +1,133 @@ +#include +#include +#include +#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h" +#include "permute_unpermute_kernels/dispatch.h" +#include "core/registration.h" + +void moe_permute( + const torch::Tensor& input, // [n_token, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& token_expert_indicies, // [n_token, topk] + const std::optional& expert_map, // [n_expert] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& + permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] + torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& m_indices) { // [align_expand_m] + TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, + "topk_weights must be float32"); + TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, + "expert_first_token_offset must be int64"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, + "token_expert_indicies must be int32"); + TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, + "src_row_id2dst_row_id_map must be int32"); + 
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, + "expert_first_token_offset shape != n_local_expert+1") + TORCH_CHECK( + src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), + "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); + auto n_token = input.sizes()[0]; + auto n_hidden = input.sizes()[1]; + auto align_block_size_value = + align_block_size.has_value() ? align_block_size.value() : -1; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const long sorter_size = + CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); + auto sort_workspace = torch::empty( + {sorter_size}, + torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto permuted_experts_id = torch::empty_like(topk_ids); + auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + + CubKeyValueSorter sorter{}; + int64_t* valid_num_ptr = nullptr; + // pre-process kernel for expert-parallelism: + // no local expert id plus "n_expert" offset for priority to local expert + // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1] + // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id + // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids + // and map global expert id [2, 3] to local_expert id [0, 1] and map global + // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map + // operation is to make local expert high priority in following sort topk_ids + // and scan local expert_first_token_offset for each ep rank for next group + // gemm. + if (expert_map.has_value()) { + const int* expert_map_ptr = get_ptr(expert_map.value()); + valid_num_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + expert_map_ptr, n_expert, stream); + } + // expert sort topk expert id and scan expert id get expert_first_token_offset + sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indicies), + get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(expert_first_token_offset), n_token, + n_expert, n_local_expert, topk, sorter, + get_ptr(sort_workspace), stream); + + // dispatch expandInputRowsKernelLauncher + MOE_DISPATCH(input.scalar_type(), [&] { + expandInputRowsKernelLauncher( + get_ptr(input), get_ptr(permuted_input), + get_ptr(topk_weights), get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(src_row_id2dst_row_id_map), + get_ptr(expert_first_token_offset), n_token, valid_num_ptr, + n_hidden, topk, n_local_expert, align_block_size_value, stream); + }); + + // get m_indices and update expert_first_token_offset with align block + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); + if (align_block_size.has_value()) { + // update align_expert_first_token_offset + expert_first_token_offset.copy_(align_expert_first_token_offset); + } +} + +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] + int64_t n_expert, int64_t n_local_expert, int64_t 
topk, + torch::Tensor& hidden_states // [n_token, hidden] +) { + TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), + "topk_ids shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK( + permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), + "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + auto n_token = hidden_states.size(0); + auto n_hidden = hidden_states.size(1); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int64_t* valid_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + MOE_DISPATCH(hidden_states.scalar_type(), [&] { + finalizeMoeRoutingKernelLauncher( + get_ptr(permuted_hidden_states), + get_ptr(hidden_states), get_ptr(topk_weights), + get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), + n_token, n_hidden, topk, valid_ptr, stream); + }); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_permute", &moe_permute); + m.impl("moe_unpermute", &moe_unpermute); +} \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h new file mode 100644 index 000000000000..41932cdd85bc --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/dispatch.h @@ -0,0 +1,53 @@ +#pragma once +#include +#define MOE_SWITCH(TYPE, ...) \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \ + } + +#define MOE_DISPATCH_CASE(enum_type, ...) \ + case enum_type: { \ + using scalar_t = ScalarType2CudaType::type; \ + __VA_ARGS__(); \ + break; \ + } +#define MOE_DISPATCH_FLOAT_CASE(...) \ + MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + +#define MOE_DISPATCH(TYPE, ...) 
\ + MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__)) + +template +struct ScalarType2CudaType; + +template <> +struct ScalarType2CudaType { + using type = float; +}; +template <> +struct ScalarType2CudaType { + using type = half; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_bfloat16; +}; + +// #if __CUDA_ARCH__ >= 890 +// fp8 +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e5m2; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e4m3; +}; +// #endif \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu new file mode 100644 index 000000000000..aa353d0f0437 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -0,0 +1,229 @@ + +#include "moe_permute_unpermute_kernel.h" + +// CubKeyValueSorter definition begin +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +int CubKeyValueSorter::expertsToBits(int num_experts) { + // Max value we represent is V = num_experts + (num_experts - 1) = 2 * + // num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1 + return static_cast(log2(2 * num_experts - 1)) + 1; +} + +CubKeyValueSorter::CubKeyValueSorter(int const num_experts) + : num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {} + +void CubKeyValueSorter::updateNumExperts(int const num_experts) { + num_experts_ = num_experts; + num_bits_ = expertsToBits(num_experts); +} + +size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts) { + int num_bits = expertsToBits(num_experts); + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int, + null_int, null_int, num_key_value_pairs, 0, + num_bits); + + // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, + // 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same + // inputs + if (required_storage == 0) { + required_storage = 1; + } + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, + int const* keys_in, int* keys_out, + int const* values_in, int* values_out, + size_t const num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); + size_t actual_ws_size = workspace_size; + + TORCH_CHECK(expected_ws_size <= workspace_size, + "[CubKeyValueSorter::run] The allocated workspace is too small " + "to run this problem."); + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, + values_in, values_out, num_key_value_pairs, 0, + num_bits_, stream); +} +// CubKeyValueSorter definition end + +static inline size_t pad_to_multiple_of_16(size_t const& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, + int64_t const arr_length, + T const target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] >= target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Calculates the start offset of the tokens for a given expert. 
The last +// element is the total number of valid tokens +__global__ void computeExpertFirstTokenOffsetKernel( + int const* sorted_experts, int64_t const sorted_experts_len, + int const num_experts, int64_t* expert_first_token_offset) { + // First, compute the global tid. We only need 1 thread per expert. + int const expert = blockIdx.x * blockDim.x + threadIdx.x; + + // Note that expert goes [0, num_experts] (inclusive) because we want a count + // for the total number of active tokens at the end of the scan. + if (expert >= num_experts + 1) { + return; + } + expert_first_token_offset[expert] = + findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert); +} + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream) { + int const num_entries = num_experts + 1; + int const threads = std::min(1024, num_entries); + int const blocks = (num_entries + threads - 1) / threads; + + computeExpertFirstTokenOffsetKernel<<>>( + sorted_indices, total_indices, num_experts, expert_first_token_offset); +} + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream) { + int64_t const expanded_num_rows = static_cast(k) * num_rows; + // We need to use the full num_experts because that is the sentinel value used + // by topk for disabled experts + sorter.updateNumExperts(num_experts); + size_t const sorter_ws_size_bytes = pad_to_multiple_of_16( + sorter.getWorkspaceSize(expanded_num_rows, num_experts)); + sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row, + permuted_experts, source_rows, permuted_rows, expanded_num_rows, + stream); + computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows, + num_experts_per_node, expert_first_token_offset, + stream); +} + +__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, + const int* expert_map_ptr, + int num_experts) { + auto tidx = threadIdx.x; + auto bidx = blockIdx.x; + auto lidx = tidx & 31; + auto widx = tidx >> 5; + auto warp_count = (blockDim.x + 31) >> 5; + auto offset = bidx * blockDim.x; + auto bound = min(offset + blockDim.x, size); + extern __shared__ int smem_expert_map[]; + // store expert_map in smem + for (int i = tidx; i < num_experts; i += blockDim.x) { + smem_expert_map[i] = expert_map_ptr[i]; + } + __syncthreads(); + + // query global expert id in expert map. 
+ // if global expert id = -1 in exert map, plus n_expert + // else set global expert id = exert map[global expert id] + if (offset + tidx < bound) { + auto topk_id = topk_id_ptr[offset + tidx]; + auto local_expert_idx = smem_expert_map[topk_id]; + if (local_expert_idx == -1) { + topk_id += num_experts; + } else { + topk_id = local_expert_idx; + } + __syncwarp(); + topk_id_ptr[offset + tidx] = topk_id; + } +} +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream) { + int block = std::min(size, 1024); + int grid = (size + block - 1) / block; + int smem_size = (num_experts) * sizeof(int); + preprocessTopkIdKernel<<>>( + topk_id_ptr, size, expert_map_ptr, num_experts); +} + +template +__global__ void getMIndicesKernel(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, + int* m_indices, const int num_local_expert, + const int align_block_size) { + int eidx = blockIdx.x; + int tidx = threadIdx.x; + extern __shared__ int64_t smem_expert_first_token_offset[]; + for (int i = tidx; i <= num_local_expert; i += blockDim.x) { + smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + } + __syncthreads(); + auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; + auto first_token_offset = smem_expert_first_token_offset[eidx]; + int n_token_in_expert = last_token_offset - first_token_offset; + + if constexpr (ALIGN_BLOCK_SIZE) { + n_token_in_expert = (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + // round up to ALIGN_BLOCK_SIZE + int64_t accumulate_align_offset = 0; + for (int i = 1; i <= eidx + 1; i++) { + int n_token = smem_expert_first_token_offset[i] - + smem_expert_first_token_offset[i - 1]; + accumulate_align_offset = + accumulate_align_offset + (n_token + align_block_size - 1) / + align_block_size * align_block_size; + if (i == eidx) { + first_token_offset = accumulate_align_offset; + } + // last block store align_expert_first_token_offset + if (eidx == num_local_expert - 1 && threadIdx.x == 0) { + align_expert_first_token_offset[i] = accumulate_align_offset; + } + } + } + for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) { + // update m_indice with expert id + m_indices[first_token_offset + idx] = eidx; + } +} + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream) { + int block = 256; + int grid = num_local_expert; + int smem_size = sizeof(int64_t) * (num_local_expert + 1); + if (align_block_size == -1) { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } else { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } +} \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h new file mode 100644 index 000000000000..43c29721cd16 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -0,0 +1,95 @@ +#pragma once +// reference from tensorrt_llm moe kernel implementation archive in +// https://github.com/BBuf/tensorrt-llm-moe/tree/master + +#include +#include +#include "dispatch.h" +#include +#include +#include +#include 
"cutlass/numeric_size.h" +#include "cutlass/array.h" + +template +inline T* get_ptr(torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +template +inline const T* get_ptr(const torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts); + + void updateNumExperts(int const num_experts); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, + int* keys_out, int const* values_in, int* values_out, + size_t const num_key_value_pairs, cudaStream_t stream); + + private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream); + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream); + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream); + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and +// performs the final skip connection. 
+template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr); + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream); + +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream); + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream); + +#include "moe_permute_unpermute_kernel.inl" diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl new file mode 100644 index 000000000000..42441800fb11 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -0,0 +1,211 @@ +#pragma once + +template +__global__ void expandInputRowsKernel( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_dest_rows, int64_t const cols, int64_t k, + int num_local_experts, int align_block_size) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + int64_t expanded_dest_row = blockIdx.x; + int64_t const expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + int expert_id = sorted_experts[expanded_dest_row]; + + extern __shared__ int64_t smem_expert_first_token_offset[]; + int64_t align_expanded_row_accumulate = 0; + if constexpr (ALIGN_BLOCK_SIZE) { + // load g2s + for (int idx = threadIdx.x; idx < num_local_experts + 1; + idx += blockDim.x) { + smem_expert_first_token_offset[idx] = + __ldg(expert_first_token_offset + idx); + } + __syncthreads(); + int lane_idx = threadIdx.x & 31; + + if (lane_idx == 0) { + // set token_offset_in_expert = 0 if this expert is not local expert + int token_offset_in_expert = + expert_id >= num_local_experts + ? 
0 + : expanded_dest_row - smem_expert_first_token_offset[expert_id]; + int64_t accumulate_align_offset = 0; +#pragma unroll 1 + for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) { + auto n_token_in_expert = smem_expert_first_token_offset[eidx] - + smem_expert_first_token_offset[eidx - 1]; + accumulate_align_offset += (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + } + expanded_dest_row = accumulate_align_offset + token_offset_in_expert; + } + // lane0 shuffle broadcast align_expanded_dest_row + expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0); + } + + if (threadIdx.x == 0) { + assert(expanded_dest_row <= INT32_MAX); + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + static_cast(expanded_dest_row); + } + + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + // Load 128-bits per thread + constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits::value; + using DataElem = cutlass::Array; + + // Duplicate and permute rows + int64_t const source_k_rank = expanded_source_row / num_rows; + int64_t const source_row = expanded_source_row % num_rows; + + auto const* source_row_ptr = + reinterpret_cast(unpermuted_input + source_row * cols); + auto* dest_row_ptr = + reinterpret_cast(permuted_output + expanded_dest_row * cols); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = cols / ELEM_PER_THREAD; + + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream) { + int64_t const blocks = num_rows * k; + int64_t const threads = 256; + using FuncPtr = decltype(&expandInputRowsKernel); + FuncPtr func_map[2][2] = { + {&expandInputRowsKernel, + &expandInputRowsKernel}, + {&expandInputRowsKernel, + &expandInputRowsKernel}, + }; + bool is_check_skip = num_valid_tokens_ptr != nullptr; + bool is_align_block_size = align_block_size != -1; + auto func = func_map[is_check_skip][is_align_block_size]; + + int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); + + func<<>>( + unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, expert_first_token_offset, + num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, + align_block_size); +} + +template +__host__ __device__ constexpr static U arrayConvert(T const& input) { + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; +#pragma unroll + for (int i = 0; i < U::kElements; i++) { + u[i] = static_cast(input[i]); + } + return u; +} + +template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr) { + assert(orig_cols % 4 == 0); + int64_t const 
original_row = blockIdx.x; + int64_t const num_rows = gridDim.x; + auto const offset = original_row * orig_cols; + OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; + int64_t const num_valid = *num_valid_ptr; + + // Load 128-bits per thread, according to the smallest data type we read/write + constexpr int64_t FINALIZE_ELEM_PER_THREAD = + 128 / std::min(cutlass::sizeof_bits::value, + cutlass::sizeof_bits::value); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = cutlass::Array; + auto const* expanded_permuted_rows_v = + reinterpret_cast(expanded_permuted_rows); + auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); + +#pragma unroll + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + ComputeElem thread_output; + thread_output.fill(0); + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) { + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + int64_t const k_offset = original_row * k + k_idx; + float const row_scale = scales[k_offset]; + + // Check after row_rescale has accumulated + if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { + continue; + } + + auto const* expanded_permuted_rows_row_ptr = + expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; + + int64_t const expert_idx = expert_for_source_row[k_offset]; + + ComputeElem expert_result = arrayConvert( + expanded_permuted_rows_row_ptr[elem_index]); + thread_output = thread_output + row_scale * (expert_result); + } + + OutputElem output_elem = + arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; + } +} + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream) { + int64_t const blocks = num_rows; + int64_t const threads = 256; + bool const check_finished = num_valid_ptr != nullptr; + using FuncPtr = decltype(&finalizeMoeRoutingKernel); + FuncPtr func_map[2] = {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}; + auto* const kernel = func_map[check_finished]; + kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, + num_valid_ptr); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d0de42251f97..2a8b9bb39caa 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! 
workspace, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); + + m.def( + "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," + "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," + "int n_local_expert," + "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " + "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " + "m_indices)->()"); + m.def( + "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," + "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " + "expert_first_token_offset, int n_expert, int n_local_expert,int " + "topk, Tensor! hidden_states)->()"); // conditionally compiled so impl registration is in source file #endif diff --git a/docs/source/contributing/overview.md b/docs/source/contributing/overview.md index 7c4016cae1d3..89b31f0311e2 100644 --- a/docs/source/contributing/overview.md +++ b/docs/source/contributing/overview.md @@ -58,6 +58,12 @@ Therefore, we recommend developing with Python 3.12 to minimise the chance of yo Currently, the repository is not fully checked by `mypy`. ::: +:::{note} +Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU +platform to run unit tests locally, rely on the continuous integration system to run the tests for +now. +::: + ## Issues If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. 
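Note: the snippet below is an illustrative sketch, not part of this diff. It shows how the newly registered `moe_permute` / `moe_unpermute` ops are driven from Python, pieced together from the op schemas bound above and the call pattern in `benchmarks/kernels/benchmark_moe_permute_unpermute.py`; the shapes, dtypes, and expert counts are made-up example values.

```python
# Illustrative sketch only (not part of this diff): exercises the new
# moe_permute / moe_unpermute path the same way
# benchmarks/kernels/benchmark_moe_permute_unpermute.py does.
# Shapes, dtypes, and expert counts are example assumptions.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute, moe_unpermute)

num_tokens, hidden_size, num_experts, topk = 32, 4096, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float32, device="cuda")

# fused_topk call sites in this diff now unpack a third value,
# token_expert_indices, which moe_permute consumes.
topk_weights, topk_ids, token_expert_indices = fused_topk(
    hidden_states, gating_output, topk, renormalize=False)

# Permute tokens into expert-contiguous order. align_block_size is only
# needed when a block-aligned grouped GEMM consumes the permuted
# activations (the benchmark uses 128 for DeepGEMM); None skips padding.
(permuted_hidden_states, expert_first_token_offset, inv_perm_idx,
 m_indices) = moe_permute(hidden_states,
                          topk_weights=topk_weights,
                          topk_ids=topk_ids,
                          token_expert_indices=token_expert_indices,
                          topk=topk,
                          n_expert=num_experts,
                          n_local_expert=num_experts,
                          expert_map=None,
                          align_block_size=None)

# ... a grouped GEMM over permuted_hidden_states would run here ...

# Unpermute back to token order with the k-way weighted reduction,
# invoked the same way as in the benchmark above.
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, inv_perm_idx,
              expert_first_token_offset, topk, num_experts, num_experts)
```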
diff --git a/docs/source/features/quantization/fp8.md b/docs/source/features/quantization/fp8.md index b90bb49ef87e..95e105357bd3 100644 --- a/docs/source/features/quantization/fp8.md +++ b/docs/source/features/quantization/fp8.md @@ -30,6 +30,7 @@ from vllm import LLM model = LLM("facebook/opt-125m", quantization="fp8") # INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB result = model.generate("Hello, my name is") +print(result[0].outputs[0].text) ``` :::{warning} @@ -105,7 +106,8 @@ Load and run the model in `vllm`: ```python from vllm import LLM model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic") -model.generate("Hello my name is") +result = model.generate("Hello my name is") +print(result[0].outputs[0].text) ``` Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`): @@ -188,4 +190,5 @@ from vllm import LLM model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/") # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB result = model.generate("Hello, my name is") +print(result[0].outputs[0].text) ``` diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index c7c8aeb662a5..7ad46b7094ee 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -17,6 +17,7 @@ gptqmodel int4 int8 fp8 +modelopt quark quantized_kvcache torchao diff --git a/docs/source/features/quantization/modelopt.md b/docs/source/features/quantization/modelopt.md new file mode 100644 index 000000000000..001d18657dad --- /dev/null +++ b/docs/source/features/quantization/modelopt.md @@ -0,0 +1,78 @@ +# NVIDIA TensorRT Model Optimizer + +The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models. + +We recommend installing the library with: + +```console +pip install nvidia-modelopt +``` + +## Quantizing HuggingFace Models with PTQ + +You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory. + +Below is an example showing how to quantize a model using modelopt's PTQ API: + +```python +import modelopt.torch.quantization as mtq +from transformers import AutoModelForCausalLM + +# Load the model from HuggingFace +model = AutoModelForCausalLM.from_pretrained("") + +# Select the quantization config, for example, FP8 +config = mtq.FP8_DEFAULT_CFG + +# Define a forward loop function for calibration +def forward_loop(model): + for data in calib_set: + model(data) + +# PTQ with in-place replacement of quantized modules +model = mtq.quantize(model, config, forward_loop) +``` + +After the model is quantized, you can export it to a quantized checkpoint using the export API: + +```python +import torch +from modelopt.torch.export import export_hf_checkpoint + +with torch.inference_mode(): + export_hf_checkpoint( + model, # The quantized model. + export_dir, # The directory where the exported files will be stored. + ) +``` + +The quantized checkpoint can then be deployed with vLLM. 
As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, which is the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`, using vLLM: + +```python +from vllm import LLM, SamplingParams + +def main(): + + model_id = "nvidia/Llama-3.1-8B-Instruct-FP8" + # Ensure you specify quantization='modelopt' when loading the modelopt checkpoint + llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True) + + sampling_params = SamplingParams(temperature=0.8, top_p=0.9) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +if __name__ == "__main__": + main() +``` diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 08893f0e9595..f8af1ba60b12 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -129,7 +129,17 @@ The table below shows the compatibility of various quantization implementations * ❌ * ❌ * ❌ - +- * modelopt + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎︎ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ ::: - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index c75a990120e0..c2c28d5ae6ae 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -47,8 +47,7 @@ def get_mixed_modalities_query() -> QueryResult: "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"), "video": - VideoAsset(name="sample_demo_1.mp4", - num_frames=16).np_ndarrays, + VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, }, }, limit_mm_per_prompt={ @@ -66,7 +65,7 @@ def get_use_audio_in_video_query() -> QueryResult: "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" f"{question}<|im_end|>\n" f"<|im_start|>assistant\n") - asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16) + asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. 
" "Please launch this example with " diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 755e19bb2699..aca11f5c50ba 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1109,7 +1109,7 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question - video = VideoAsset(name="sample_demo_1.mp4", + video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays vid_questions = ["Why is this video funny?"] diff --git a/requirements/neuron.txt b/requirements/neuron.txt index 5f25bd0546e6..f8e3030834e2 100644 --- a/requirements/neuron.txt +++ b/requirements/neuron.txt @@ -2,5 +2,7 @@ -r common.txt # Dependencies for Neuron devices +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 torch-neuronx >= 2.5.0 neuronx-cc diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 199bcafe0bdd..e2711354ac10 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,5 +23,11 @@ runai-model-streamer-s3==0.11.0 tensorizer>=2.9.0 lm-eval==0.4.8 buildkite-test-collector==0.1.9 - lm-eval[api]==0.4.8 # required for model evaluation test + +# required for quantization test +bitsandbytes>=0.45.3 + +# required for minicpmo_26 test +vector_quantize_pytorch +vocos diff --git a/requirements/test.txt b/requirements/test.txt index d4c92f15025f..9a15d9a0d824 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -349,28 +349,28 @@ numpy==1.26.4 # transformers # tritonclient # vocos -nvidia-cublas-cu12==12.6.4.1 +nvidia-cublas-cu12==12.8.3.14 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-cupti-cu12==12.8.57 # via torch -nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-nvrtc-cu12==12.8.61 # via torch -nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.8.57 # via torch -nvidia-cudnn-cu12==9.5.1.17 +nvidia-cudnn-cu12==9.7.1.26 # via torch -nvidia-cufft-cu12==11.3.0.4 +nvidia-cufft-cu12==11.3.3.41 # via torch -nvidia-cufile-cu12==1.11.1.6 +nvidia-cufile-cu12==1.13.0.11 # via torch -nvidia-curand-cu12==10.3.7.77 +nvidia-curand-cu12==10.3.9.55 # via torch -nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusolver-cu12==11.7.2.55 # via torch -nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparse-cu12==12.5.7.53 # via # nvidia-cusolver-cu12 # torch @@ -378,13 +378,13 @@ nvidia-cusparselt-cu12==0.6.3 # via torch nvidia-nccl-cu12==2.26.2 # via torch -nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvjitlink-cu12==12.8.61 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.6.77 +nvidia-nvtx-cu12==12.8.55 # via torch opencv-python-headless==4.11.0.86 # via @@ -687,7 +687,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.0 +torch==2.7.0+cu128 # via # -r requirements/test.in # accelerate @@ -705,12 +705,12 @@ torch==2.7.0 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.0 +torchaudio==2.7.0+cu128 # via # -r requirements/test.in # encodec # vocos -torchvision==0.22.0 +torchvision==0.22.0+cu128 # via # -r 
requirements/test.in # timm diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 0b76779b3a75..b6b45d1cbe88 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -103,7 +103,8 @@ def test_compile_correctness( method = test_setting.method fullgraph = test_setting.fullgraph if cuda_device_count_stateless() != pp_size * tp_size: - pytest.skip("Not correct CUDA devices for the test.") + pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got " + f"{cuda_device_count_stateless()}") with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) diff --git a/tests/conftest.py b/tests/conftest.py index f02b5a8c0520..b1b4af86fab7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 - import json import os import tempfile -from collections import UserList from enum import Enum from typing import Any, Callable, Optional, TypedDict, TypeVar, Union @@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]: return prompts -class _ImageAssetPrompts(TypedDict): +class ImageAssetPrompts(TypedDict): stop_sign: str cherry_blossom: str -class _ImageAssetsBase(UserList[ImageAsset]): - pass - - -class _ImageAssets(_ImageAssetsBase): +class ImageTestAssets(list[ImageAsset]): def __init__(self) -> None: super().__init__([ @@ -75,7 +69,7 @@ def __init__(self) -> None: ImageAsset("cherry_blossom"), ]) - def prompts(self, prompts: _ImageAssetPrompts) -> list[str]: + def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ Convenience method to define the prompt for each test image. @@ -85,30 +79,27 @@ def prompts(self, prompts: _ImageAssetPrompts) -> list[str]: return [prompts["stop_sign"], prompts["cherry_blossom"]] -class _VideoAssetPrompts(TypedDict): - sample_demo_1: str - - -class _VideoAssetsBase(UserList[VideoAsset]): - pass +class VideoAssetPrompts(TypedDict): + baby_reading: str -class _VideoAssets(_VideoAssetsBase): +class VideoTestAssets(list[VideoAsset]): def __init__(self) -> None: super().__init__([ - VideoAsset("sample_demo_1.mp4"), + VideoAsset("baby_reading"), ]) - def prompts(self, prompts: _VideoAssetPrompts) -> list[str]: - return [prompts["sample_demo_1"]] + def prompts(self, prompts: VideoAssetPrompts) -> list[str]: + return [prompts["baby_reading"]] -class _AudioAssetsBase(UserList[AudioAsset]): - pass +class AudioAssetPrompts(TypedDict): + mary_had_lamb: str + winning_call: str -class _AudioAssets(_AudioAssetsBase): +class AudioTestAssets(list[AudioAsset]): def __init__(self) -> None: super().__init__([ @@ -116,13 +107,16 @@ def __init__(self) -> None: AudioAsset("winning_call"), ]) + def prompts(self, prompts: AudioAssetPrompts) -> list[str]: + return [prompts["mary_had_lamb"], prompts["winning_call"]] -IMAGE_ASSETS = _ImageAssets() -"""Singleton instance of :class:`_ImageAssets`.""" -VIDEO_ASSETS = _VideoAssets() -"""Singleton instance of :class:`_VideoAssets`.""" -AUDIO_ASSETS = _AudioAssets() -"""Singleton instance of :class:`_AudioAssets`.""" + +IMAGE_ASSETS = ImageTestAssets() +"""Singleton instance of :class:`ImageTestAssets`.""" +VIDEO_ASSETS = VideoTestAssets() +"""Singleton instance of :class:`VideoTestAssets`.""" +AUDIO_ASSETS = AudioTestAssets() +"""Singleton instance of :class:`AudioTestAssets`.""" @pytest.fixture(scope="function", autouse=True) @@ -270,17 +264,17 @@ def example_long_prompts() -> list[str]: @pytest.fixture(scope="session") -def image_assets() -> _ImageAssets: +def 
image_assets() -> ImageTestAssets: return IMAGE_ASSETS @pytest.fixture(scope="session") -def video_assets() -> _VideoAssets: +def video_assets() -> VideoTestAssets: return VIDEO_ASSETS @pytest.fixture(scope="session") -def audio_assets() -> _AudioAssets: +def audio_assets() -> AudioTestAssets: return AUDIO_ASSETS @@ -779,7 +773,7 @@ def __init__( def get_inputs( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, @@ -801,16 +795,18 @@ def get_inputs( if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - inputs.append( - TextPrompt(prompt=prompt, - multi_modal_data=multi_modal_data - if multi_modal_data else None)) + text_prompt_kwargs = { + ("prompt" if isinstance(prompt, str) else "prompt_embeds"): + prompt, + "multi_modal_data": multi_modal_data or None + } + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs def generate( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, @@ -836,7 +832,7 @@ def generate( output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -903,7 +899,7 @@ def generate_encoder_decoder_w_logprobs( def generate_greedy( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 8bd64923fe22..a5ba16898d89 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -2,16 +2,18 @@ import time from collections import deque +from typing import Optional from unittest.mock import MagicMock import pytest # noqa +import torch from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus from .utils import (append_new_token, append_new_token_seq, append_new_token_seq_group, create_dummy_prompt, @@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( ), "A partial prefix of C (4 tokens) should be prefilled, with the " "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "then be rounded down to 2 tokens on block size, thus 6 tokens in total." + + +def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): + """ + Test that the scheduler does not schedule batches with prompt tokens and + prompt embeddings co-mingled. 
+    """
+    block_size = 2
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+    )
+
+    # the even indexed inputs are passed in via embeddings,
+    # odds via token_ids
+    seq_length = 7
+    embedding_size = 5
+    num_seqs = 11
+    seq_tokens: list[list[int]] = []
+    seq_embeds: list[Optional[torch.Tensor]] = []
+    for i in range(num_seqs):
+        if i % 2:
+            seq_tokens.append(list(range(seq_length)))
+            seq_embeds.append(None)
+        else:
+            seq_tokens.append([0] * seq_length)
+            seq_embeds.append(torch.rand(embedding_size))
+
+    seq_and_seq_groups = [
+        create_dummy_prompt(f"{i}",
+                            prompt_tokens=seq_tokens[i],
+                            prompt_embeds=seq_embeds[i],
+                            block_size=block_size)
+        for i in range(len(seq_tokens))
+    ]
+
+    for _, seq_group in seq_and_seq_groups:
+        scheduler.add_seq_group(seq_group)
+
+    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
+        unfinished_seq_groups = [
+            seq_group for _, seq_group in seq_and_seq_groups
+            if not seq_group.is_finished()
+        ]
+        _, out = schedule_and_update_computed_tokens(scheduler)
+        assert len(out.scheduled_seq_groups) > 0
+        batch_is_prompt_embeds = out.scheduled_seq_groups[
+            0].seq_group.uses_prompt_embeds()
+        expected_scheduled_seq_groups = [
+            seq_group for seq_group in unfinished_seq_groups
+            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
+        ]
+
+        # We should have as many scheduled groups as possible, without mixing
+        assert len(out.scheduled_seq_groups) == min(
+            max_seq_group, len(expected_scheduled_seq_groups))
+        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
+                   batch_is_prompt_embeds
+                   for scheduled_seq_group in out.scheduled_seq_groups)
+
+        # Finish the scheduled groups
+        for scheduled_seq_group in out.scheduled_seq_groups:
+            for seq in scheduled_seq_group.seq_group.seqs:
+                seq.status = SequenceStatus.FINISHED_STOPPED
+        scheduler.free_finished_seq_groups()
diff --git a/tests/core/utils.py b/tests/core/utils.py
index ea18b879a317..84b0426b470b 100644
--- a/tests/core/utils.py
+++ b/tests/core/utils.py
@@ -5,9 +5,11 @@
 from collections.abc import Sequence as GenericSequence
 from typing import Any, Optional
 
+import torch
+
 from vllm import SamplingParams
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
-from vllm.inputs import EncoderDecoderInputs, token_inputs
+from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupMetadata)
@@ -19,6 +21,7 @@ def create_dummy_prompt(
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     prompt_tokens: Optional[list[int]] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
 ) -> tuple[Sequence, SequenceGroup]:
@@ -31,9 +34,13 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))
 
     prompt_str = " ".join([str(t) for t in prompt_tokens])
+    inputs = token_inputs(
+        prompt_token_ids=prompt_tokens,
+        prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
+            prompt_embeds=prompt_embeds)
     prompt = Sequence(
         int(request_id),
-        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        inputs=inputs,
         block_size=block_size,
     )
     seq_group = SequenceGroup(
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 16721ee9ce74..65471cb3af38 100644
--- a/tests/engine/test_arg_utils.py
+++
b/tests/engine/test_arg_utils.py @@ -106,6 +106,8 @@ class DummyConfigClass: """List with literal choices""" literal_literal: Literal[Literal[1], Literal[2]] = 1 """Literal of literals with default 1""" + json_tip: dict = field(default_factory=dict) + """Dict which will be JSON in CLI""" @pytest.mark.parametrize(("type_hint", "expected"), [ @@ -137,6 +139,9 @@ def test_get_kwargs(): assert kwargs["list_literal"]["choices"] == [1, 2] # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] + # dict should have json tip in help + json_tip = "\n\nShould be a valid JSON string." + assert kwargs["json_tip"]["help"].endswith(json_tip) @pytest.mark.parametrize(("arg", "expected"), [ diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 425f36984a33..f2cca65ae420 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -420,7 +420,8 @@ def test_fused_marlin_moe( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, False) torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py new file mode 100644 index 000000000000..dfcd61f77587 --- /dev/null +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE permute/unpermute kernel + +Run `pytest tests/kernels/test_moe_permute_unpermute.py`. +""" + +from typing import Optional + +import numpy as np +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) +from vllm.platforms import current_platform + +NUM_EXPERTS = [16, 64] +TOP_KS = [2, 4, 6, 8] +EP_SIZE = [1, 4, 16] +current_platform.seed_everything(0) + + +def torch_permute(hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: + n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] + if expert_map is not None: + is_local_expert = (expert_map[topk_ids] != -1) + not_local_expert = (expert_map[topk_ids] == -1) + topk_ids = is_local_expert * ( + topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), + stable=True) + dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] + + expert_first_token_offset = torch.zeros(n_local_expert + 1, + dtype=torch.int64, + device="cuda") + idx = 0 + for i in range(0, n_local_expert): + cnt = 0 + while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i: + cnt += 1 + idx += 1 + expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt + + _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) + valid_row_idx = [] + if align_block_size is None: + + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % + n_token, ...] 
+ permuted_row_size = permuted_hidden_states.shape[0] + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + m_indices[first_token_offset:last_token_offset] = i - 1 + src_row_id2dst_row_id_map = torch.arange( + 0, n_token * topk, device="cuda", + dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + return [ + permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices, valid_row_idx + ] + else: + permuted_row_size = (topk * n_token + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), + device="cuda", + dtype=hidden_states.dtype) + align_src_row_id2dst_row_id = torch.empty(n_token * topk, + device="cuda", + dtype=torch.int32) + align_expert_first_token_offset = torch.zeros_like( + expert_first_token_offset) + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + # get align_permuted_hidden_states, + # valid row_idx and align_expert_first_token_offset + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + n_token_in_expert = last_token_offset - first_token_offset + align_expert_first_token_offset[ + i] = align_expert_first_token_offset[ + i - 1] + (n_token_in_expert + align_block_size - + 1) // align_block_size * align_block_size + align_first_token_offset = align_expert_first_token_offset[i - 1] + align_last_token_offset = align_expert_first_token_offset[i] + dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ + first_token_offset:first_token_offset + + n_token_in_expert] % n_token + # store token in current expert with align_first_token_offset + permuted_hidden_states[align_first_token_offset:\ + align_first_token_offset+n_token_in_expert,\ + ...] = hidden_states[\ + dst_row_id2src_row_id_in_expert, ...] 
+ # set current expert m_indices + m_indices[align_first_token_offset:align_last_token_offset] = i - 1 + valid_row_idx += [ + i for i in range(align_first_token_offset, + align_first_token_offset + n_token_in_expert) + ] + # get align_src_row_id2dst_row_id + for i in range(n_token * topk): + eid = sorted_topk_ids[i] + if (eid >= n_local_expert): + # check token not in local expert + align_src_row_id2dst_row_id[ + i] = align_expert_first_token_offset[-1] + continue + first_token_offset = expert_first_token_offset[eid] + align_first_token_offset = align_expert_first_token_offset[eid] + token_offset = i - first_token_offset + align_src_row_id2dst_row_id[ + i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ + src2dst_idx].reshape((n_token, topk)) + return [ + permuted_hidden_states, align_expert_first_token_offset, + align_src_row_id2dst_row_id, m_indices, valid_row_idx + ] + + +def torch_unpermute(permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, topk: int, + n_expert: int) -> torch.Tensor: + # ignore invalid row + mask = torch.zeros(permuted_hidden_states.shape[0], + dtype=bool, + device="cuda") + mask[valid_row_idx] = True + permuted_hidden_states[~mask] = 0 + idx = src_row_id2dst_row_id_map.flatten()[ + token_expert_indices.flatten()].reshape(token_expert_indices.shape) + output = permuted_hidden_states[idx, ...] * topk_weights[..., None] + output = output.sum(dim=1).to(permuted_hidden_states.dtype) + return output + + +@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) +@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) +@pytest.mark.parametrize("n_expert", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("align_block_size", [None, 128]) +def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, + n_expert: int, ep_size: int, dtype: torch.dtype, + align_block_size: Optional[int]): + fill_invalid_expert = 0 + ep_rank = np.random.randint(0, ep_size) + expert_map = None + n_local_expert = n_expert + if (ep_size != 1): + n_local_expert, expert_map = determine_expert_map( + ep_size, ep_rank, n_expert) + expert_map = expert_map.cuda() + start_expert = n_local_expert * ep_rank + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) + gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, False) + gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( + hidden_states, + topk_ids, + token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + result0, result1, result2, result3 = moe_permute( + hidden_states, topk_weights, topk_ids, token_expert_indices, topk, + n_expert, n_local_expert, expert_map, align_block_size, + fill_invalid_expert) + + # check expert_first_token_offset + torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + # check src_row_id2dst_row_id_map + torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + # check mindice + torch.testing.assert_close(gold3, result3, atol=0, 
rtol=0) + # check permuted_hidden_states, only valid token + torch.testing.assert_close(gold0[valid_row_idx], + result0[valid_row_idx], + atol=0, + rtol=0) + + # add a random tensor to simulate group gemm + result0 = 0.5 * result0 + torch.randn_like(result0) + + result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, + topk, n_expert, n_local_expert) + gold4 = torch_unpermute(result0, topk_weights, topk_ids, + token_expert_indices, result2, valid_row_idx, topk, + n_local_expert) + + # check unpermuted hidden + torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/quantization/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py index 939b0e7157be..c30fe60becdf 100644 --- a/tests/kernels/quantization/test_awq_marlin.py +++ b/tests/kernels/quantization/test_awq_marlin.py @@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, False) marlin_output = torch.ops.vllm.fused_marlin_moe( a, qweight1, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c57e39f42506..38c7e461bb9c 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, M, K = a.shape N = w2.shape[-1] - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weight, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() @@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py index 5b9ea6dba401..dc948a48bf32 100644 --- a/tests/kv_transfer/test_disagg.py +++ b/tests/kv_transfer/test_disagg.py @@ -14,8 +14,8 @@ # Fixture to set up environment variables and teardown servers after tests @pytest.fixture(scope="module", autouse=True) def setup_servers(): - if torch.cuda.device_count() < 4: - pytest.skip("Skipping test: fewer than 4 GPUs available") + if torch.cuda.device_count() < 2: + pytest.skip("Skipping test: fewer than 2 GPUs available") # Set up environment variables VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index ab2898ffb2d0..fcd3fa036cfd 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import os +from typing import Optional + import pytest import torch @@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) + prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv( + "VLLM_USE_V1") == "0" else None + prompt_token_ids = [] + for prompt in example_prompts: + token_ids 
= hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + if prompt_embeds is not None: + prompt_embeds.append(hf_model.model.get_input_embeddings()( + token_ids).squeeze(0)) + with vllm_runner( model, tokenizer_name=model_info.tokenizer or model, @@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + if prompt_embeds is not None: + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) + if prompt_embeds is not None: + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) + if use_rocm_aiter: # this is to ensure that vllm engine # has deallocated the memory before running the next diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index b21c80bef927..44cdd6f44aa9 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -13,8 +13,8 @@ from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets, - _VideoAssets) +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, + VideoTestAssets, VllmRunner) from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal @@ -691,7 +691,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -716,7 +716,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -741,7 +741,7 @@ def test_image_embedding_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -763,7 +763,7 @@ def test_image_embedding_models(model_type: str, )) def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -814,7 +814,7 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: 
type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -840,7 +840,7 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -866,7 +866,8 @@ def test_image_embedding_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, + monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -889,7 +890,7 @@ def test_image_embedding_models_heavy(model_type: str, def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] diff --git a/tests/models/multimodal/generation/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py index 14b64393bf52..b8225f5f1243 100644 --- a/tests/models/multimodal/generation/test_florence2.py +++ b/tests/models/multimodal/generation/test_florence2.py @@ -9,7 +9,7 @@ from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner from ...utils import check_logprobs_close MODELS = ["microsoft/Florence-2-base"] @@ -118,7 +118,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, model: str, + image_assets: ImageTestAssets, model: str, size_factors: list[int], dtype: str, max_tokens: int, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index 7c14845ec54d..96c444441e3d 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -9,7 +9,8 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets +from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, + VllmRunner) from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -116,9 +117,9 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models(hf_runner, vllm_runner, model: str, + audio_assets: 
AudioTestAssets, dtype: str, max_model_len: int, + max_tokens: int, num_logprobs: int) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index 92c8155fe1e2..eec84751e450 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -29,7 +29,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB") image_stop = ImageAsset("stop_sign").pil_image.convert("RGB") images = [image_cherry, image_stop] - video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays + video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays inputs = [ ( diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index 1e09c8673dc3..99aa3c2d3bd9 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -14,8 +14,8 @@ from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, + PromptImageInput, VllmRunner) from ....quantization.utils import is_quant_method_supported from ....utils import (create_new_process_for_each_test, large_gpu_test, multi_gpu_test) @@ -90,7 +90,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str, def _get_inputs( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, *, size_factors: Optional[list[float]] = None, sizes: Optional[list[tuple[int, int]]] = None, @@ -126,7 +126,7 @@ def _get_inputs( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, size_factors: list[float], @@ -143,7 +143,7 @@ def run_test( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, sizes: list[tuple[int, int]], @@ -159,7 +159,7 @@ def run_test( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, size_factors: Optional[list[float]] = None, @@ -433,7 +433,7 @@ def test_models_distributed( @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') def test_bnb_regression( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, dtype: str, max_tokens: int, @@ -473,7 +473,7 @@ def test_bnb_regression( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) def test_explicit_implicit_prompt( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, dtype: str, max_tokens: int, diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index 0b27a4caf6eb..6be401b775ec 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -50,7 +50,7 @@ def qwen2_vl_chat_template(*query): }) VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "sample_demo_1": 
+ "baby_reading": qwen2_vl_chat_template( VIDEO_PLACEHOLDER, "Describe this video with a short sentence ", diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index 1d7de946a3f8..322d886a593d 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -11,13 +11,22 @@ from vllm.multimodal.audio import resample_audio_librosa from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, VllmRunner, _AudioAssets +from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner from ....utils import RemoteOpenAIServer from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" +AUDIO_PROMPTS = AUDIO_ASSETS.prompts({ + "mary_had_lamb": + "Transcribe this into English.", + "winning_call": + "What is happening in this audio clip?", +}) + +MULTI_AUDIO_PROMPT = "Describe each of the audios above." + AudioTuple = tuple[np.ndarray, int] VLLM_PLACEHOLDER = "<|audio|>" @@ -31,12 +40,6 @@ } -@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) -def audio(request): - from vllm.assets.audio import AudioAsset - return AudioAsset(request.param) - - def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: """Convert kwargs to CLI args.""" args = [] @@ -53,7 +56,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def server(request, audio_assets: _AudioAssets): +def server(request, audio_assets: AudioTestAssets): args = [ "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", "--limit-mm-per-prompt", @@ -199,15 +202,19 @@ def run_multi_audio_test( pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, - num_logprobs: int, vllm_kwargs: dict) -> None: +def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets, + dtype: str, max_tokens: int, num_logprobs: int, + vllm_kwargs: dict) -> None: + audio_inputs = [( + _get_prompt(1, audio, VLLM_PLACEHOLDER), + _get_prompt(1, audio, HF_PLACEHOLDER), + audio.audio_and_sample_rate, + ) for audio in audio_assets] - vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) - hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) run_test( hf_runner, vllm_runner, - [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)], + audio_inputs, MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -224,13 +231,12 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, - dtype: str, max_tokens: int, - num_logprobs: int, +def test_models_with_multiple_audios(vllm_runner, + audio_assets: AudioTestAssets, dtype: str, + max_tokens: int, num_logprobs: int, vllm_kwargs: dict) -> None: - vllm_prompt = _get_prompt(len(audio_assets), - "Describe each of the audios above.", + vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, @@ -245,7 +251,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, @pytest.mark.asyncio -async def test_online_serving(client, 
audio_assets: _AudioAssets): +async def test_online_serving(client, audio_assets: AudioTestAssets): """Exercises online serving with/without chunked prefill enabled.""" messages = [{ diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index bf5f87ebf984..e3ba955a96a6 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -11,7 +11,7 @@ from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) -from .....conftest import _ImageAssets, _VideoAssets +from .....conftest import ImageTestAssets, VideoTestAssets from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER, TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, ImageSizeWrapper, SizeType, VLMTestInfo) @@ -69,7 +69,7 @@ def get_model_prompts(base_prompts: Iterable[str], def build_single_image_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, tmp_path: Optional[PosixPath] = None): if test_info.prompt_formatter is None: @@ -116,7 +116,7 @@ def build_single_image_inputs(images, model_prompts, def build_multi_image_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, tmp_path: Optional[PosixPath] = None): if test_info.prompt_formatter is None: @@ -159,7 +159,7 @@ def build_multi_image_inputs(image_lists, model_prompts, def build_embedding_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, ): # These conditions will always be true if invoked through filtering, @@ -192,7 +192,7 @@ def build_embedding_inputs_from_test_info( def build_video_inputs_from_test_info( test_info: VLMTestInfo, - video_assets: _VideoAssets, + video_assets: VideoTestAssets, size_wrapper: ImageSizeWrapper, num_frames: int, ): diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index c856fb198b32..aa9d3901fa36 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -16,7 +16,7 @@ from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side -from .....conftest import HfRunner, ImageAsset, _ImageAssets +from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -238,14 +238,14 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput, ####### Functions for converting image assets to embeddings -def get_llava_embeddings(image_assets: _ImageAssets): +def get_llava_embeddings(image_assets: ImageTestAssets): return [asset.image_embeds for asset in image_assets] ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], - _ImageAssets]) -> str: + tmp_path: PosixPath, prompt: str, + assets: Union[list[ImageAsset], ImageTestAssets]) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image in its diff --git a/tests/models/multimodal/generation/vlm_utils/runners.py 
b/tests/models/multimodal/generation/vlm_utils/runners.py index 023df5f16188..34753121ea90 100644 --- a/tests/models/multimodal/generation/vlm_utils/runners.py +++ b/tests/models/multimodal/generation/vlm_utils/runners.py @@ -4,7 +4,8 @@ """ from pathlib import PosixPath -from .....conftest import HfRunner, VllmRunner, _ImageAssets, _VideoAssets +from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets, + VllmRunner) from . import builders, core from .types import ExpandableVLMTestArgs, VLMTestInfo @@ -14,7 +15,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs = builders.build_single_image_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper, tmp_path) @@ -37,7 +38,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs = builders.build_multi_image_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper, tmp_path) @@ -60,7 +61,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper) @@ -86,7 +87,7 @@ def run_video_test( test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, + video_assets: VideoTestAssets, ): assert test_case.size_wrapper is not None assert test_case.num_video_frames is not None diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 1ae61ea47229..56629323394d 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -15,7 +15,7 @@ from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets +from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets from ....utils import check_logprobs_close # meta image tag; will be replaced by the appropriate tag for the model @@ -85,7 +85,7 @@ class VLMTestInfo(NamedTuple): # Function for converting ImageAssets to image embeddings; # We need to define this explicitly for embedding tests - convert_assets_to_embeddings: Optional[Callable[[_ImageAssets], + convert_assets_to_embeddings: Optional[Callable[[ImageTestAssets], torch.Tensor]] = None # Exposed options for vLLM runner; we change these in a several tests, @@ -141,7 +141,7 @@ class VLMTestInfo(NamedTuple): # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[list[ImageAsset], _ImageAssets]], + Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], str]] = None # noqa: E501 # Allows configuring a test to run with custom inputs diff --git 
a/tests/models/multimodal/generation/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py similarity index 84% rename from tests/models/multimodal/generation/test_intern_vit.py rename to tests/models/multimodal/pooling/test_intern_vit.py index a842d14fee2e..76f9fbe02550 100644 --- a/tests/models/multimodal/generation/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -1,33 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - import pytest import torch import torch.nn as nn from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from ....conftest import _ImageAssets +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ....conftest import ImageTestAssets # we use snapshot_download to prevent conflicts between # dynamic_module and trust_remote_code for hf_runner DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] +@torch.inference_mode() def run_intern_vit_test( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, *, dtype: str, - distributed_executor_backend: Optional[str] = None, ): model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] img_processor = CLIPImageProcessor.from_pretrained(model) images = [asset.pil_image for asset in image_assets] pixel_values = [ - img_processor(images, return_tensors='pt').pixel_values.to(dtype) + img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype) for images in images ] @@ -36,14 +37,13 @@ def run_intern_vit_test( config.norm_type = "rms_norm" hf_model = AutoModel.from_pretrained(model, - torch_dtype=dtype, + torch_dtype=torch_dtype, trust_remote_code=True).to("cuda") hf_outputs_per_image = [ hf_model(pixel_value.to("cuda")).last_hidden_state for pixel_value in pixel_values ] - from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.intern_vit import InternVisionModel vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) @@ -51,7 +51,7 @@ def run_intern_vit_test( del hf_model cleanup_dist_env_and_memory() - vllm_model = vllm_model.to("cuda", dtype) + vllm_model = vllm_model.to("cuda", torch_dtype) vllm_outputs_per_image = [ vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values @@ -69,8 +69,7 @@ def run_intern_vit_test( "OpenGVLab/InternViT-300M-448px", "OpenGVLab/InternViT-6B-448px-V1-5", ]) -@pytest.mark.parametrize("dtype", [torch.half]) -@torch.inference_mode() +@pytest.mark.parametrize("dtype", ["half"]) def test_models(dist_init, image_assets, model_id, dtype: str) -> None: run_intern_vit_test( image_assets, diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 709a686577f3..37142b6dd36f 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -11,7 +11,7 @@ from vllm.multimodal.image import rescale_image_size from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -137,7 +137,7 @@ def _run_check( @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_factors: 
list[int], min_dynamic_patch: int, max_dynamic_patch: int, diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index f5b5cf6b5ba9..c35ce2f6ab29 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -5,7 +5,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -21,7 +21,7 @@ @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index 5ac47ecc5cc1..7ec81197a3db 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -11,7 +11,7 @@ from vllm.multimodal.image import rescale_image_size from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -94,7 +94,7 @@ def _run_check( @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 2bfc2785feb6..614f17dbbeda 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -6,7 +6,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.transformers_utils.tokenizer import encode_tokens -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -17,7 +17,7 @@ @pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) @pytest.mark.parametrize("tokenized_prompt", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict, num_imgs: int, diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py index 10de28ab54ce..9bd2b9887294 100644 --- a/tests/models/multimodal/processing/test_minimax_vl_01.py +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -7,14 +7,14 @@ from vllm.multimodal.parse import ImageSize from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"]) @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, num_imgs: int, ): diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index ed0d04c5c5f5..b53351544c45 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -4,7 +4,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -from 
....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -22,7 +22,7 @@ @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, int], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py index 797986adba4a..c6e272650e08 100644 --- a/tests/models/multimodal/processing/test_phi4mm.py +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -4,7 +4,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -22,7 +22,7 @@ @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, int], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index d8c2ca414d41..02abe1ca8b02 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -4,7 +4,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -19,7 +19,7 @@ @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_smolvlm.py b/tests/models/multimodal/processing/test_smolvlm.py index 56edc58a71ba..224d1bcedb96 100644 --- a/tests/models/multimodal/processing/test_smolvlm.py +++ b/tests/models/multimodal/processing/test_smolvlm.py @@ -5,7 +5,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -21,7 +21,7 @@ @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index c02c3d90e345..597c8e48fb64 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -7,7 +7,7 @@ from vllm.multimodal.image import rescale_image_size -from ...conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets +from ...conftest import IMAGE_ASSETS, ImageTestAssets, VllmRunner from ..utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -20,7 +20,7 @@ def run_awq_test( vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, source_model: str, quant_model: str, *, diff --git a/tests/multimodal/assets/image1.png b/tests/multimodal/assets/image1.png new file mode 100644 index 000000000000..17c7d4cdffe9 Binary files /dev/null and b/tests/multimodal/assets/image1.png differ diff 
--git a/tests/multimodal/assets/image2.png b/tests/multimodal/assets/image2.png
new file mode 100644
index 000000000000..0f13ce5d983d
Binary files /dev/null and b/tests/multimodal/assets/image2.png differ
diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py
new file mode 100644
index 000000000000..17b36b36888d
--- /dev/null
+++ b/tests/multimodal/test_hasher.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image, ImageDraw
+
+from vllm.multimodal.hasher import MultiModalHasher
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+assert ASSETS_DIR.exists()
+
+
+# NOTE: Images that are the same visually are allowed to have the same hash
+@pytest.mark.parametrize("mode_pair", [("1", "L"), ("RGBA", "CMYK")])
+def test_hash_collision_image_mode(mode_pair):
+    mode1, mode2 = mode_pair
+    image1 = Image.new(mode1, size=(10, 10), color=1)
+    image2 = Image.new(mode2, size=(10, 10), color=1)
+
+    hasher = MultiModalHasher
+    assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
+
+
+def test_hash_collision_image_palette():
+    # These images differ only in Image.palette._palette
+    image1 = Image.open(ASSETS_DIR / "image1.png")
+    image2 = Image.open(ASSETS_DIR / "image2.png")
+
+    hasher = MultiModalHasher
+    assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
+
+
+def test_hash_collision_image_transpose():
+    image1 = Image.new("1", size=(10, 20))
+    ImageDraw.Draw(image1).line([(0, 0), (10, 0)])
+
+    image2 = Image.new("1", size=(20, 10))
+    ImageDraw.Draw(image2).line([(0, 0), (0, 10)])
+
+    hasher = MultiModalHasher
+    assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
+
+
+def test_hash_collision_tensor_shape():
+    # The hash should be different though the data is the same when flattened
+    arr1 = torch.zeros((5, 10, 20, 3))
+    arr2 = torch.zeros((10, 20, 5, 3))
+
+    hasher = MultiModalHasher
+    assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
+
+
+def test_hash_collision_array_shape():
+    # The hash should be different though the data is the same when flattened
+    arr1 = np.zeros((5, 10, 20, 3))
+    arr2 = np.zeros((10, 20, 5, 3))
+
+    hasher = MultiModalHasher
+    assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 314ec90e34f9..1a20228765e8 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -3,6 +3,7 @@
 import importlib.util
 
 import pytest
+import torch
 
 DTYPE = ["bfloat16"]
@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
         print(output)
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.parametrize(
+    "pt_load_map_location",
+    [
+        "cuda:0",
+        # {"": "cuda"},
+    ])
+def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+                                                   pt_load_map_location):
+    """
+    Test loading the torchao int4wo-quantized opt-125m model with
+    pt_load_map_location.
+ """ + torch._dynamo.reset() + model_name = "jerryzh168/opt-125m-int4wo" + with vllm_runner(model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location=pt_load_map_location) as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + + assert output + print(output) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_config.py b/tests/test_config.py index f2155d954db0..7db95e3f6450 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,8 @@ import pytest -from vllm.config import ModelConfig, PoolerConfig, config, get_field +from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, + config, get_field) from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform @@ -410,3 +411,16 @@ def test_generation_config_loading(): override_generation_config=override_generation_config) assert model_config.get_diff_sampling_param() == override_generation_config + + +@pytest.mark.parametrize("pt_load_map_location", [ + "cuda", + { + "": "cuda" + }, +]) +def test_load_config_pt_load_map_location(pt_load_map_location): + load_config = LoadConfig(pt_load_map_location=pt_load_map_location) + config = VllmConfig(load_config=load_config) + + assert config.load_config.pt_load_map_location == pt_load_map_location diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index ee4e95856f23..9987688b02fa 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1165,3 +1165,80 @@ def test_kv_connector_handles_preemption(): # All memory should be freed since nothing is running. assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ == NUM_BLOCKS - 1 + + +def make_output(scheduler: Scheduler): + return ModelRunnerOutput( + req_ids=[req.request_id for req in scheduler.running], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(scheduler.running) + }, + sampled_token_ids=[[1000]] * len(scheduler.running), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. 
+ for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + # assert block._block_hash is None + # assert ( + # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + # ) == 0) + + +def test_memory_leak(): + """Test that we do not have a memory leak.""" + + scheduler = create_scheduler(enable_prefix_caching=True) + + NUM_REQUESTS = 5 + NUM_TOKENS = 10 + MAX_TOKENS = 10 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + + # Add each request. + for request in requests: + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Iterate until done. + while True: + scheduler_output = scheduler.schedule() + if len(scheduler.running) == 0: + break + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm no memory leak. + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/__init__.py b/tests/v1/kv_connector/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/run_accuracy_test.sh new file mode 100644 index 000000000000..9679a070525f --- /dev/null +++ b/tests/v1/kv_connector/run_accuracy_test.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +set -xe + +# Model to run. +MODEL_NAME=Qwen/Qwen3-0.6B + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Prefill instance. +CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \ + --port 8100 \ + --enforce-eager \ + --disable-log-requests \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + +# Decode instance. +CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \ + --port 8200 \ + --enforce-eager \ + --disable-log-requests \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# Proxy server. +python ${GIT_ROOT}/tests/v1/kv_connector/toy_proxy_server.py --port 8192 & + +# Run lm eval. 
+python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/test_accuracy.py
diff --git a/tests/v1/kv_connector/test_accuracy.py b/tests/v1/kv_connector/test_accuracy.py
new file mode 100644
index 000000000000..60878a664eb9
--- /dev/null
+++ b/tests/v1/kv_connector/test_accuracy.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+import lm_eval
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+NUM_CONCURRENT = 100
+TASK = "gsm8k"
+FILTER = "exact_match,strict-match"
+RTOL = 0.03
+EXPECTED_VALUE = 0.41
+
+
+def test_accuracy():
+    """Run the end-to-end accuracy test."""
+
+    model_args = (f"model={MODEL_NAME},"
+                  f"base_url=http://localhost:8192/v1/completions,"
+                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
+
+    results = lm_eval.simple_evaluate(
+        model="local-completions",
+        model_args=model_args,
+        tasks=TASK,
+    )
+
+    measured_value = results["results"][TASK][FILTER]
+    assert (measured_value - RTOL < EXPECTED_VALUE
+            and measured_value + RTOL > EXPECTED_VALUE
+            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/test_nixl_connector.py
new file mode 100644
index 000000000000..684823408c94
--- /dev/null
+++ b/tests/v1/kv_connector/test_nixl_connector.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
+    NixlConnectorMetadata)
+
+from .utils import create_request, create_scheduler, create_vllm_config
+
+
+def test_scheduler_worker_interface():
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    # 2 Full Blocks and 1 Half Block.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+
+    request = create_request(request_id=1,
+                             num_tokens=NUM_TOKENS,
+                             do_remote_prefill=True)
+    request_id = request.request_id
+
+    scheduler.add_request(request)
+
+    # Remote Prefill, triggers NixlConnectorMetadata.
+    scheduler_output = scheduler.schedule()
+    kv_connector_metadata = scheduler_output.kv_connector_metadata
+    assert kv_connector_metadata is not None
+    assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
+
+    assert len(kv_connector_metadata.requests) == 1
+    assert request_id in kv_connector_metadata.requests
+    req_meta = kv_connector_metadata.requests[request_id]
+
+    for block_id, block in zip(
+            req_meta.local_block_ids,
+            scheduler.kv_cache_manager.req_to_blocks[request_id]):
+        assert block_id == block.block_id
diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/test_remote_decode_lifecycle.py
new file mode 100644
index 000000000000..bfe97efeb3ee
--- /dev/null
+++ b/tests/v1/kv_connector/test_remote_decode_lifecycle.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+import copy
+
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.request import FinishReason, RequestStatus
+
+from .utils import (assert_scheduler_empty, create_model_runner_output,
+                    create_request, create_scheduler, create_vllm_config)
+
+
+def test_basic_lifecycle():
+    """Test lifecycle of a Remote Decode request."""
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    # 2 Full Blocks and 1 Half Block.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+
+    request = create_request(request_id=1,
+                             num_tokens=NUM_TOKENS,
+                             do_remote_decode=True)
+
+    scheduler.add_request(request)
+    request_id = request.request_id
+
+    # STEP (1): Prefill.
+    # (1a): schedule()
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert len(scheduler_output.scheduled_new_reqs) == 1
+
+    # (1b): execute_model()
+    model_runner_output = create_model_runner_output(reqs=[request])
+
+    # (1c): update_from_output()
+    engine_core_outputs = scheduler.update_from_output(scheduler_output,
+                                                       model_runner_output)
+
+    # Ensure the request is finished after 1 token.
+    assert request.is_finished()
+    assert request.status == RequestStatus.FINISHED_REMOTE_DECODE
+    output = engine_core_outputs.outputs[0]
+    assert output.finish_reason == FinishReason.REMOTE_DECODE
+    assert output.kv_transfer_params is not None
+
+    # Request freed in Scheduler and in Persistent Batch ...
+    assert request_id in scheduler.finished_req_ids
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 0
+
+    # ... but blocks should not be freed.
+    blocks = scheduler.kv_cache_manager.req_to_blocks[request_id]
+    for block in blocks:
+        assert block.ref_cnt == 1
+
+    # STEP (2): Send Finished to PB.
+    # (2a): schedule() - pass finished request to PB.
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.running) == 0
+    assert len(scheduler_output.finished_req_ids) == 1
+    assert request_id in scheduler_output.finished_req_ids
+    assert len(scheduler_output.scheduled_new_reqs) == 0
+    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert len(scheduler.finished_req_ids) == 0
+
+    # (2b): execute_model()
+    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
+
+    # (2c): update_from_output()
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # STEP (3): Finished sending.
+    # (3a): schedule() - pass finished request to PB.
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.running) == 0
+    assert len(scheduler_output.finished_req_ids) == 0
+    assert len(scheduler_output.scheduled_new_reqs) == 0
+    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert len(scheduler.finished_req_ids) == 0
+
+    # (3b): execute_model()
+    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
+    model_runner_output.finished_sending = [request_id]
+
+    # (3c): update_from_output()
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # Confirm we do not have any memory leaks after req lifecycle.
+    assert_scheduler_empty(scheduler)
diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py
new file mode 100644
index 000000000000..91fcbf53fb3d
--- /dev/null
+++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py
@@ -0,0 +1,274 @@
+# SPDX-License-Identifier: Apache-2.0
+import copy
+
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.request import FinishReason, RequestStatus
+
+from .utils import (assert_scheduler_empty, create_model_runner_output,
+                    create_request, create_scheduler, create_vllm_config)
+
+
+def test_basic_lifecycle():
+    """Test Remote Prefills Lifecycle."""
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    # 2 Full Blocks and 1 Half Block.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+    START_FREE_BLOCK_QUEUE_SIZE = (
+        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
+
+    request = create_request(request_id=1,
+                             num_tokens=NUM_TOKENS,
+                             do_remote_prefill=True)
+
+    scheduler.add_request(request)
+    request_id = request.request_id
+
+    # STEP (1):
+    # (1a): schedule()
+    scheduler_output = scheduler.schedule()
+
+    # Nothing running and empty scheduler output.
+    assert len(scheduler.running) == 0
+    assert len(scheduler_output.scheduled_new_reqs) == 0
+    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert len(scheduler_output.num_scheduled_tokens) == 0
+    assert scheduler_output.total_num_scheduled_tokens == 0
+
+    # Req waiting for KVs with no computed/scheduled toks ...
+    assert len(scheduler.waiting) == 1
+    assert request in scheduler.waiting
+    assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
+    assert (request.num_computed_tokens == 0)
+
+    # ... but should have (uncached) blocks allocated to it.
+    block_pool = scheduler.kv_cache_manager.block_pool
+    assert (block_pool.free_block_queue.num_free_blocks
+            < START_FREE_BLOCK_QUEUE_SIZE)
+    assert len(block_pool.cached_block_hash_to_block) == 0
+    for block in scheduler.kv_cache_manager.req_to_blocks[request_id]:
+        assert block._block_hash is None
+
+    # (1b): forward()
+    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
+
+    # (1c): update_from_output()
+    engine_core_outputs = scheduler.update_from_output(scheduler_output,
+                                                       model_runner_output)
+    assert len(engine_core_outputs.outputs) == 0
+
+    # STEP (2):
+    # (2a): schedule(): nothing happens!
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.waiting) == 1
+    assert len(scheduler.running) == 0
+
+    # (2b): forward(): request finishes recv.
+    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
+    model_runner_output.finished_recving = [request_id]
+
+    # (2c): update_from_output():
+    engine_core_outputs = scheduler.update_from_output(scheduler_output,
+                                                       model_runner_output)
+    assert len(scheduler.waiting) == 1
+    assert (request_id in scheduler.finished_recving_KV_req_ids)
+
+    # STEP (3):
+    # (3a): schedule(): this should actually schedule.
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+
+    # Confirm the blocks are actually allocated.
+    num_hashed_blocks = 0
+    for block in scheduler.kv_cache_manager.req_to_blocks[request_id]:
+        assert block.ref_cnt == 1
+        num_hashed_blocks += (1 if block._block_hash is not None else 0)
+    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
+
+    # Confirm the rest of the prompt is scheduled in this step.
+    scheduled_req = scheduler_output.scheduled_new_reqs[0]
+    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
+    num_computed_tokens = scheduled_req.num_computed_tokens
+    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
+    assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
+
+    # (3b): execute_model()
+    model_runner_output = create_model_runner_output([request])
+    # (3c): update_from_output()
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # Step (4): Hit EOS.
+ scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_interleaved_lifecycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = create_model_runner_output([request_local_a]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 2: Add a local and remote request. + scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 6: Hit EOS and free. 
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        [request_local_a, request_local_b, request_remote],
+        use_eos=True,
+    )
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    scheduler.schedule()
+    assert_scheduler_empty(scheduler)
+
+
+def test_no_spurious_prefix_caching():
+    """
+    With P/D, blocks can be allocated but uncomputed for
+    multiple engine steps. This test confirms that we do
+    not accidentally have cache hits against uncomputed
+    blocks.
+    """
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    # 2 and a half full external blocks.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+
+    # Both of these requests have prompts like [1,1,1,1,1, ...]
+    request_remote = create_request(
+        request_id=1,
+        num_tokens=NUM_TOKENS,
+        do_remote_prefill=True,
+        use_all_1s_for_prompt_tokens=True,
+    )
+
+    request_local = create_request(
+        request_id=2,
+        num_tokens=NUM_TOKENS,
+        do_remote_prefill=False,
+        use_all_1s_for_prompt_tokens=True,
+    )
+
+    # Schedule the remote prefill request. This should not
+    # cause any blocks to be cached.
+    scheduler.add_request(request_remote)
+    scheduler_output = scheduler.schedule()
+    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
+    assert len(scheduler.waiting) == 1
+
+    # Schedule the local prefill request. This should
+    # cause blocks to be cached, but separately from the remote blocks.
+    scheduler.add_request(request_local)
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    local_blocks = scheduler.kv_cache_manager.req_to_blocks[
+        request_local.request_id]
+    remote_blocks = scheduler.kv_cache_manager.req_to_blocks[
+        request_remote.request_id]
+
+    # Local should have cached blocks (but not all due to preallocate).
+    num_hashed_blocks = 0
+    for block in local_blocks:
+        assert block.ref_cnt == 1
+        num_hashed_blocks += (1 if block._block_hash is not None else 0)
+    assert num_hashed_blocks > 0
+
+    # Remote blocks should not be cached.
+    for block in remote_blocks:
+        assert block.ref_cnt == 1
+        assert block._block_hash is None
diff --git a/tests/v1/kv_connector/toy_proxy_server.py b/tests/v1/kv_connector/toy_proxy_server.py
new file mode 100644
index 000000000000..89e3c4493fdb
--- /dev/null
+++ b/tests/v1/kv_connector/toy_proxy_server.py
@@ -0,0 +1,145 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import os
+import uuid
+from contextlib import asynccontextmanager
+
+import httpx
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan context manager to handle startup and shutdown events.
+ """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['do_remote_decode'] = True + req_data["stream"] = False + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict, remote_block_ids: list[int], + remote_engine_id: str, request_id: str): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + req_data['do_remote_prefill'] = True + req_data["remote_block_ids"] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + + request_id = str(uuid.uuid4()) + + # Send request to prefill service + response = await send_request_to_service(app.state.prefill_client, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + remote_block_ids = response_json.get('remote_block_ids', []) + remote_engine_id = response_json.get('remote_engine_id', '') + + # Add these to the request data for the decoder + req_data['remote_block_ids'] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + app.state.decode_client, + "/completions", + req_data, + remote_block_ids=remote_block_ids, + remote_engine_id=remote_engine_id, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py new file mode 100644 index 000000000000..0387cd58ab0f --- /dev/null +++ b/tests/v1/kv_connector/utils.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.sampling_params import KVTransferParams, SamplingParams +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_KV_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. 
+ assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, +) -> Request: + """Make dummy request for testing.""" + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = KVTransferParams(do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_engine_id="remote_engine_id", + remote_block_ids=[1, 2, 3], + ) + else: + kv_transfer_params = None + + sampling_params = SamplingParams( + max_tokens=max_tokens, + kv_transfer_params=kv_transfer_params, + ) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + return Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = 
False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. + req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index b8ba69b0dd8f..a1bdea687a85 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module(): assert model_runner.attn_backend.__name__ == "TritonMLABackend" -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size): seq_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} + expected_input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * seq_len, + prompt_embeds=torch.rand(seq_len, 10), + ) + expected_input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len)) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size): seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping @@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size): assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + if expected_input_embeds_len == 0: + torch.testing.assert_close(input_tokens, input_positions) + assert input_embeds is None + else: + assert len(input_embeds) == expected_input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size): torch.testing.assert_close(actual, expected) -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + 
monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData.from_seqs(range(context_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * context_len, + prompt_embeds=torch.rand(context_len, 10), + ) + output_embed = torch.rand(10) + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len)) + output_embed = None seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. - seq_data.append_token_id(1, 0) + seq_data.append_token_id(1, 0, output_embed) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( - model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + slot_mapping = attn_metadata.slot_mapping + assert len(slot_mapping) == len(input_tokens) expected_bs = model_runner.vllm_config.pad_for_cudagraph( @@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size): # block table's first index corresponds to each batch, meaning in # decoding it is each token. assert attn_metadata.block_tables.shape[0] == len(input_tokens) - # Block table's second dim correspondsd to each token's block number. + # Block table's second dim corresponds to each token's block number. 
# It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) @@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - torch.allclose(input_tokens, input_positions) + if use_prompt_embeds: + expected_input_embeds_length = start_loc[-1] + assert len(input_embeds) == expected_input_embeds_length + assert expected_input_embeds_length <= expected_bs + else: + assert input_embeds is None # Verify Sampling expected_selected_token_indices = [] @@ -266,25 +307,27 @@ def test_empty_seq_group(): seq_group_metadata_list: list[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + assert input_tokens is None assert input_positions is None assert attn_metadata is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + assert input_tokens is None assert input_positions is None + assert input_embeds is None assert attn_metadata is None assert return_seq_lens is None @@ -299,9 +342,15 @@ def distributed_init(): ensure_model_parallel_initialized(1, 1) -@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, distributed_init): +@pytest.mark.parametrize('use_prompt_embeds', [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, + distributed_init, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size + expected_input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * seq_len, + prompt_embeds=torch.rand(seq_len, 10), + ) + expected_input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(seq_len), ) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(context_len)) - 
seq_data.append_token_id(1, 0) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * context_len, + prompt_embeds=torch.rand(context_len, 10), + ) + output_embed = torch.rand(10) + # This also iterates the expected input_embeds, because the model + # needs both the input and output embeddings passed into together + expected_input_embeds_len += 1 + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len), ) + output_embed = None + assert len(seq_data.prompt_token_ids) == context_len + seq_data.append_token_id(1, 0, output_embed) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) + if expected_input_embeds_len == 0: + assert input_embeds is None + else: + assert len(input_embeds) == expected_input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7bb01507ac2c..64f4310151cd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor, scale_a.shape * [1, 128] == a.shape scale_b.shape * [128, 128] == b.shape """ - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert bias is None or bias.shape[0] == b.shape[ 1] and bias.dtype == out_dtype @@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor, m = a.shape[0] n = b.shape[1] - if current_platform.is_rocm(): + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if current_platform.is_rocm() or not cutlass_compatible_b: triton_scaled_mm_module = importlib.import_module( "vllm.model_executor.layers.quantization.compressed_tensors." 
"triton_scaled_mm") diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 0203dc092a71..a21eb7f599fa 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -18,19 +18,25 @@ ASSET_DIR = "multimodal_asset" +AudioAssetName = Literal["winning_call", "mary_had_lamb"] + @dataclass(frozen=True) class AudioAsset: - name: Literal["winning_call", "mary_had_lamb"] + name: AudioAssetName + + @property + def filename(self) -> str: + return f"{self.name}.ogg" @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: - audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) def get_local_path(self) -> Path: - return get_vllm_public_assets(filename=f"{self.name}.ogg", + return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) @property diff --git a/vllm/assets/image.py b/vllm/assets/image.py index 2b1d258da9c7..d8cca9b74edd 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -10,10 +10,12 @@ VLM_IMAGES_DIR = "vision_model_images" +ImageAssetName = Literal["stop_sign", "cherry_blossom"] + @dataclass(frozen=True) class ImageAsset: - name: Literal["stop_sign", "cherry_blossom"] + name: ImageAssetName @property def pil_image(self) -> Image.Image: diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 133e18b68e25..bf06746a9ff6 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Literal, Optional +from typing import ClassVar, Literal, Optional import cv2 import numpy as np @@ -76,20 +76,31 @@ def video_to_pil_images_list(path: str, ] +VideoAssetName = Literal["baby_reading"] + + @dataclass(frozen=True) class VideoAsset: - name: Literal["sample_demo_1.mp4"] + name: VideoAssetName num_frames: int = -1 + _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = { + "baby_reading": "sample_demo_1.mp4", + } + + @property + def filename(self) -> str: + return self._NAME_TO_FILE[self.name] + @property def pil_images(self) -> list[Image.Image]: - video_path = download_video_asset(self.name) + video_path = download_video_asset(self.filename) ret = video_to_pil_images_list(video_path, self.num_frames) return ret @property def np_ndarrays(self) -> npt.NDArray: - video_path = download_video_asset(self.name) + video_path = download_video_asset(self.filename) ret = video_to_ndarrays(video_path, self.num_frames) return ret @@ -99,5 +110,5 @@ def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ - video_path = download_video_asset(self.name) + video_path = download_video_asset(self.filename) return librosa.load(video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py index 528df2e98679..4567893a9ef7 100644 --- a/vllm/attention/backends/cpu_mla.py +++ b/vllm/attention/backends/cpu_mla.py @@ -281,8 +281,7 @@ def _forward_prefill( # remove padding output = output.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]] - output = output.reshape(-1, self.num_heads * v.shape[-1]) - return self.o_proj(output)[0] + return output.reshape(-1, self.num_heads * v.shape[-1]) def _forward_decode( self, @@ -303,4 +302,4 @@ def _forward_decode( ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, decode_meta.block_tables, decode_meta.seq_lens_tensor) - return 
self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index d92177d58a48..37b20d0739f7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -367,9 +367,17 @@ def begin_forward(self, model_input): # scheduled while CUDA graph mode is enabled. We don't run graph in that # case. if use_cuda_graph and is_decode: - batch_size = model_input.input_tokens.shape[0] - state = (self.runner.graph_runners[model_input.virtual_engine] - [batch_size].attn_state) + if model_input.inputs_embeds is None: + batch_size = model_input.input_tokens.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, False)].attn_state) + else: + batch_size = model_input.inputs_embeds.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, True)].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( ) model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index 5d0c23093310..0e62748ddbee 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -239,4 +239,4 @@ def _forward_decode( causal=True, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 382a9a6d44d8..12d85b74244f 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -207,7 +207,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, RowParallelLinear, + LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -1032,12 +1032,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, rotary_emb: RotaryEmbedding, - # q_proj should be q_b_proj if q_lora_rank is not None, but from an - # attention backend perspective we rely on the layer to pass in the - # correct matrix - q_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear, - o_proj: RowParallelLinear, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -1055,9 +1050,7 @@ def __init__( self.rotary_emb = rotary_emb self.use_yarn_rope = isinstance(rotary_emb, DeepseekScalingRotaryEmbedding) - self.q_proj = q_proj self.kv_b_proj = kv_b_proj - self.o_proj = o_proj self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn @@ -1141,27 +1134,13 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, return attn_out, rest[0] return attn_out - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x)[0] - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.num_heads, self.qk_head_dim)\ - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - # Convert from (B, N, P) to (N, B, 
P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe + return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -1345,7 +1324,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - return self.o_proj(output.flatten(start_dim=-2))[0] + return output.flatten(start_dim=-2) @abstractmethod def _forward_decode( @@ -1360,7 +1339,7 @@ def _forward_decode( def forward( self, layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn + q: torch.Tensor, # query in unified attn k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, @@ -1391,27 +1370,32 @@ def forward( assert hasattr(attn_metadata, "input_positions") num_prefill_tokens: int = attn_metadata.num_prefill_tokens + q = q.view(-1, self.num_heads, self.qk_head_dim) - decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:] + decode_q = q[num_prefill_tokens:] decode_k_pe = k_pe[num_prefill_tokens:] decode_input_positions = \ attn_metadata.input_positions[num_prefill_tokens:] - prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens] + prefill_q = q[:num_prefill_tokens] prefill_k_pe = k_pe[:num_prefill_tokens] prefill_input_positions = \ attn_metadata.input_positions[:num_prefill_tokens] prefill_k_c_normed = k_c_normed[:num_prefill_tokens] if has_decode: - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_input_positions, decode_q_pe, decode_k_pe) if has_prefill: - prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( prefill_input_positions, prefill_q_pe, prefill_k_pe) @@ -1429,9 +1413,9 @@ def forward( output = torch.empty(attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens, - self.o_proj.output_size, - device=hidden_states_or_q_c.device, - dtype=hidden_states_or_q_c.dtype) + self.v_head_dim * self.num_heads, + device=q.device, + dtype=q.dtype) if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 6e695b78e0e1..2984bc1dad64 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -409,4 +409,4 @@ def _forward_decode( attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_lens) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 61e5c76d9fda..6945c2c6e29c 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -110,4 +110,4 @@ def _forward_decode( decode_meta.seq_lens_tensor, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 1b47581641b0..759b3d8536dd 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -289,7 +289,7 @@ def chunked_prefill_paged_decode( max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) assert _PARTITION_SIZE_ROCM % block_size == 0 - total_num_seq = query.shape[0] + total_num_seq = block_table.shape[0] tmp_output = torch.empty( size=(total_num_seq, num_query_heads, max_num_partitions, head_size), diff --git a/vllm/config.py b/vllm/config.py index c2995cacaeb6..5ca70f2f67b6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ import re import sys import textwrap +import uuid import warnings from collections import Counter from contextlib import contextmanager @@ -268,7 +269,7 @@ class ModelConfig: It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" rope_scaling: dict[str, Any] = field(default_factory=dict) - """RoPE scaling configuration in JSON format. For example, + """RoPE scaling configuration. For example, `{"rope_type":"dynamic","factor":2.0}`.""" rope_theta: Optional[float] = None """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE @@ -346,14 +347,13 @@ class ModelConfig: (stored in `~/.huggingface`).""" hf_overrides: HfOverrides = field(default_factory=dict) """If a dictionary, contains arguments to be forwarded to the Hugging Face - config. If a callable, it is called to update the HuggingFace config. When - specified via CLI, the argument must be a valid JSON string.""" + config. If a callable, it is called to update the HuggingFace config.""" mm_processor_kwargs: Optional[dict[str, Any]] = None """Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. Overrides for the multi-modal processor obtained from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. 
-    When specified via CLI, the argument must be a valid JSON string."""
+    """
     disable_mm_preprocessor_cache: bool = False
     """If `True`, disable caching of the multi-modal preprocessor/mapper (not
     recommended)."""
@@ -361,15 +361,14 @@ class ModelConfig:
     """Initialize non-default neuron config or override default neuron config
     that are specific to Neuron devices, this argument will be used to
     configure the neuron config that can not be gathered from the vllm
-    arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
-    the argument must be a valid JSON string."""
+    arguments. e.g. `{"cast_logits_dtype": "bloat16"}`."""
     pooler_config: Optional["PoolerConfig"] = field(init=False)
     """Pooler config which controls the behaviour of output pooling in pooling
     models."""
     override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
     """Initialize non-default pooling config or override default pooling config
     for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
-    When specified via CLI, the argument must be a valid JSON string."""
+    """
     logits_processor_pattern: Optional[str] = None
     """Optional regex pattern specifying valid logits processor qualified names
     that can be passed with the `logits_processors` extra completion argument.
@@ -385,8 +384,7 @@ class ModelConfig:
     """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
     used with `--generation-config auto`, the override parameters will be
     merged with the default config from the model. If used with
-    `--generation-config vllm`, only the override parameters are used.
-    When specified via CLI, the argument must be a valid JSON string."""
+    `--generation-config vllm`, only the override parameters are used."""
     enable_sleep_mode: bool = False
     """Enable sleep mode for the engine (only cuda platform is supported)."""
     model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
@@ -1556,14 +1554,23 @@ class LoadConfig:
     cache directory of Hugging Face."""
     model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
-    corresponding to the chosen load_format. This should be a JSON string that
-    will be parsed into a dictionary."""
+    corresponding to the chosen load_format."""
     ignore_patterns: Optional[Union[list[str], str]] = None
     """The list of patterns to ignore when loading the model. Default to
     "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
     """Whether to enable tqdm for showing progress bar when loading model
     weights."""
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
+    """
+    The map location for loading the PyTorch checkpoint, to support loading
+    checkpoints that can only be loaded on certain devices, e.g. "cuda"; this
+    is equivalent to {"": "cuda"}. Another supported format is mapping from
+    a different device, e.g. from GPU 1 to GPU 0: {"cuda:1": "cuda:0"}. Note
+    that when passed from the command line, the strings in the dictionary
+    need to be double-quoted for JSON parsing. For more details, see the docs
+    for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
+    """
     def compute_hash(self) -> str:
         """
@@ -2816,7 +2823,6 @@ class MultiModalConfig:
         "limit_mm_per_prompt")
     """
     The maximum number of input items allowed per prompt for each modality.
-    This should be a JSON string that will be parsed into a dictionary.
     Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt: @@ -3135,6 +3141,14 @@ def _get_and_verify_max_len( # derived length from the HF model config. if max_model_len is None: max_model_len = int(derived_max_model_len) + if current_platform.is_tpu(): + logger.warning( + "--max-model-len is not specified, " + "it's currently using model's default length %s, " + "which might be too large." + "Please input with --max-model-len based on your " + "request input length and output length, to avoid " + "unnecessary degradation.", max_model_len) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input @@ -3387,6 +3401,11 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + # Engine ID for the KV transfers. + # Note(tms): sticking this here so the engine_id is consistent between + # scheduler-side and worker-side of the KVConnector + engine_id: str = str(uuid.uuid4()) + # The device used by kv connector to buffer the KV cache. # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" @@ -3396,7 +3415,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 97d03d5e3b40..06d4ed470b20 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1071,6 +1071,7 @@ def _schedule_prefills( ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + using_prompt_embeds: bool = False waiting_queue = self.waiting @@ -1138,6 +1139,15 @@ def _schedule_prefills( waiting_queue.popleft() continue + # We cannot mix sequence groups that use prompt embeds and + # those that do not. 
+ if len(seq_groups) == 0: + using_prompt_embeds = seq_group.uses_prompt_embeds() + if using_prompt_embeds != seq_group.uses_prompt_embeds(): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id @@ -1295,17 +1305,39 @@ def _schedule_default(self) -> SchedulerOutputs: # Merge lists num_prefill_groups = len(prefills.seq_groups) + ignored_seq_groups_for_embeds = list[SequenceGroup]() if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + ignored_seq_groups_for_embeds.clear() else: scheduled_seq_groups = running_scheduled.decode_seq_groups + if len(scheduled_seq_groups) > 0: + using_prompt_embeds = scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + ignored_seq_groups_for_embeds.clear() + indices_ignored = list[int]() + for i, schedule_seq_group in enumerate(scheduled_seq_groups): + if using_prompt_embeds !=\ + schedule_seq_group.seq_group.uses_prompt_embeds(): + ignored_seq_groups_for_embeds.append( + schedule_seq_group.seq_group) + indices_ignored.append(i) + if len(ignored_seq_groups_for_embeds) > 0: + scheduled_seq_groups = [ + group for i, group in enumerate(scheduled_seq_groups) + if i not in indices_ignored + ] + else: + ignored_seq_groups_for_embeds.clear() + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) blocks_to_copy = running_scheduled.blocks_to_copy blocks_to_copy.extend(swapped_in.blocks_to_copy) ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(ignored_seq_groups_for_embeds) ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) return SchedulerOutputs( diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..54cb1871db3c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -105,3 +105,8 @@ def create_connector_v1( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..ca9e19156719 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -62,6 +62,20 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._vllm_config = vllm_config self._role = role + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + return + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + return set(), set() + @property def role(self) -> KVConnectorRole: return self._role @@ -188,6 +202,7 @@ def get_num_new_matched_tokens( @abstractmethod def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. 
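The two new optional hooks on `KVConnectorBase_V1` above, `register_kv_caches` and `get_finished`, give worker-side connectors a way to pre-register the allocated KV tensors and to report asynchronously completed transfers. A minimal sketch of how a worker might call them; the helper names here are illustrative and not part of this patch, and it assumes a v1 KV transfer group has already been initialized:

import torch

from vllm.distributed.kv_transfer import get_kv_transfer_group


def on_kv_cache_init(kv_caches: dict[str, torch.Tensor]) -> None:
    # Hand every layer's KV tensor to the connector so it can pre-register
    # the memory (e.g. NIXL registration) before any transfer is issued.
    get_kv_transfer_group().register_kv_caches(kv_caches)


def poll_finished_transfers() -> tuple[set[str], set[str]]:
    # (done_sending, done_recving): request ids whose KV send/recv finished
    # since the last poll. The base class defaults to two empty sets, so
    # connectors that transfer synchronously need no changes.
    return get_kv_transfer_group().get_finished()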
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..8f86b72f9cff 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -110,6 +110,7 @@ def get_num_new_matched_tokens( request, num_computed_tokens) def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 000000000000..6a75f1522461 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,642 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +import os +import time +import uuid +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +import msgspec +import torch +import zmq +from typing_extensions import Optional + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.sampling_params import KVTransferParams +from vllm.utils import round_down +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True): + engine_id: str + agent_metadata: bytes + # Base addr for each layer for KVs + # NOTE: we will need another list for TP>1 + kv_caches_base_addr: list[int] + num_blocks: int + + +class ReqMeta: + + def __init__( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_engine_id: str, + ): + self.local_block_ids = local_block_ids + self.remote_block_ids = remote_block_ids + self.remote_engine_id = remote_engine_id + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: KVTransferParams, + ): + assert request_id not in self.requests + assert kv_transfer_params.remote_engine_id is not None + assert kv_transfer_params.remote_block_ids is not None + + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params.remote_block_ids, + remote_engine_id=kv_transfer_params.remote_engine_id) + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + def get_num_new_matched_tokens(self, request: "Request", + num_computed_tokens: int) -> int: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + block_ids: list[int], + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, block_ids, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + return + + def 
wait_for_save(self): + """NixlConnector does not save explicitly.""" + return + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens(self, request: "Request", + num_computed_tokens: int) -> int: + """For remote prefill, allocate for all tokens.""" + + # NOTE: this function is called in the WAITING loop. + # So we should only have full blocks of computed tokens. + assert num_computed_tokens % self.block_size == 0 + + if request.do_remote_prefill: + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + return max(rounded_num_prompt_tokens - num_computed_tokens, 0) + + return 0 + + def update_state_after_alloc(self, request: "Request", + block_ids: list[int], + num_external_tokens: int): + if request.do_remote_prefill and num_external_tokens > 0: + self._reqs_need_recv[request.request_id] = (request, block_ids) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> list[agent_names] (1 per rank). + self._remote_agents: dict[str, list[str]] = {} + + # Metadata. + self.engine_id = engine_id + self.rank = 0 + + # KV Caches and nixl tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr + # For Local: base addr for *this* rank, each layer for K,V + # For Remote: base addr for *each* rank, each layer for K,V + # KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]], + # list[list[tuple[int, int]]]] + self.kv_caches_base_addr: dict[str, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + + # Map of tp_mult -> nixl_prepped_dlist_handle (int). + self.src_xfer_side_handles: dict[int, int] = {} + # Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: defaultdict[str, + dict[int, + int]] = defaultdict(dict) + # Map of engine_id -> num_blocks. + self.dst_num_blocks: dict[str, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. 
+ # [req_id -> list[handle]] + self._recving_transfers: dict[str, list[Any]] = defaultdict(list[Any]) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + + first_layer_name = next(iter(kv_caches)) + first_kv_cache = kv_caches[first_layer_name] + + # [2 (k and v), num_blocks, ...] + # TODO(tms): num_blocks will be in a different spot for MLA. + num_blocks = first_kv_cache.shape[1] + kv_elem_size = first_kv_cache[0].element_size() + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + self.block_len = kv_elem_size * math.prod(first_kv_cache.shape[-3:]) + + logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) + self.num_blocks = num_blocks + self.dst_num_blocks[self.engine_id] = num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + for layer_name in kv_caches: + for cache in kv_caches[layer_name]: + base_addr = cache.data_ptr() + region_len = num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + + self._registered_descs.append(descs) + + # THIS IS FOR DEV + _ctx = zmq.Context() # type: ignore + _side_channel = _ctx.socket(zmq.PAIR) # type: ignore + NIXL_ROLE = os.getenv("NIXL_ROLE") + + # FOR DEV, SENDER puts data in its KV caches so the RECVER can check it + if os.environ.get("VLLM_DEBUG_INITIAL_NIXL_PD_XFER") is not None: + n_blocks_to_send = min(4096, kv_caches[first_layer_name].shape[1]) + debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9 + logger.debug( + "Starting initial NIXL PD XFER: Total %s GB, Block len %s KB", + debug_xfer_gb, self.block_len / 1024) + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 100.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 200.0 + + for b in range(5): + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (0, b, 0, 0, 0), + kv_caches[first_layer_name][0, b, 0, 0, 0]) + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (1, b, 0, 0, 0), + kv_caches[first_layer_name][1, b, 0, 0, 0]) + remote_engine_id = None # HACK for debug send + + if NIXL_ROLE == "SENDER": + _side_channel.connect("tcp://localhost:5577") + _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore + metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + ) + encoded_data = msgspec.msgpack.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + _side_channel.send(encoded_data) + + logger.debug("WAITING ON RECV") + ack = _side_channel.recv() + logger.debug("GOT ACK %s", ack) + + 
elif NIXL_ROLE == "RECVER": + _side_channel.bind("tcp://localhost:5577") + _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata_bytes = _side_channel.recv() + metadata = decoder.decode(metadata_bytes) + + remote_engine_id = metadata.engine_id #HACK + + self.add_remote_agent(metadata) + logger.debug("SENDING ACK") + _side_channel.send(b"ack") + + else: + raise Exception("SET NIXL_ROLE to SENDER OR RECVER") + + # FOR DEV, SENDER puts data in its KV caches so the RECVER can check it + + if os.environ.get("VLLM_DEBUG_INITIAL_NIXL_PD_XFER") is not None: + initial_xfer_req_id = "initial_xfer_req_id" + + if NIXL_ROLE == "RECVER": + logger.debug("SENDING BLOCKS") + connector_metadata = NixlConnectorMetadata() + assert remote_engine_id is not None + xfer_params = KVTransferParams( + do_remote_decode=True, + do_remote_prefill=False, + remote_block_ids=list(range(n_blocks_to_send)), + remote_engine_id=remote_engine_id #HACK + ) + + connector_metadata.add_new_req(request_id=initial_xfer_req_id, + local_block_ids=list( + range(n_blocks_to_send)), + kv_transfer_params=xfer_params) + self.start_load_kv(connector_metadata) + + # Wait for Receive to complete + logger.debug("START RECEIVE XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + done = initial_xfer_req_id in finished[1] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Received. Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) + + if NIXL_ROLE == "SENDER": + # Wait for Send to complete + logger.debug("START SEND XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + done = initial_xfer_req_id in finished[0] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Sent. Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) + + # Put some different stuff in there + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 300.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 400.0 + + for b in range(5): + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (0, b, 0, 0, 0), + kv_caches[first_layer_name][0, b, 0, 0, 0]) + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (1, b, 0, 0, 0), + kv_caches[first_layer_name][1, b, 0, 0, 0]) + + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): + engine_id = nixl_agent_meta.engine_id + if engine_id in self._remote_agents: + return + + num_blocks = nixl_agent_meta.num_blocks + logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks)) + + agent_names = [ + self.nixl_wrapper.add_remote_agent(nixl_agent_meta.agent_metadata) + ] + + self._remote_agents[engine_id] = agent_names + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + + # NOTE: once we support heterogeneous TP, we will need maintain the + # src for each TP multiplier. + # NOTE(rob): Dynamo only supports D TP size > P TP size. + # https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501 + # Create descs and xfer side handles. 
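Before the descriptor-building code that follows, a toy illustration of the layout it assumes: each K or V tensor of each layer is registered as one NIXL region, every region is split into `num_blocks` fixed-size block descriptors, and a descriptor's index is region-major (`reg_id * num_blocks + block_id`, mirroring `_get_block_descs_ids` further down). The numbers below are made up purely for illustration:

# Toy sizes: e.g. 2 layers * (K, V) regions, 3 blocks per region.
num_regions, num_blocks = 4, 3


def block_desc_id(reg_id: int, block_id: int) -> int:
    # Region-major layout, as in _get_block_descs_ids().
    return reg_id * num_blocks + block_id


# A READ that copies remote block 2 into local block 0 pairs these ids,
# one (local_id, remote_id) pair per region:
pairs = [(block_desc_id(r, 0), block_desc_id(r, 2)) for r in range(num_regions)]
# -> [(0, 2), (3, 5), (6, 8), (9, 11)]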
+ tp_multiplier = 1 + dst_block_len = self.block_len // tp_multiplier + if tp_multiplier not in self.src_xfer_side_handles: + # Create descs and xfer side handles. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for i in range(tp_multiplier): + tp_multiplier_offset = tp_idx * dst_block_len + blocks_data.append( + (base_addr + block_offset + tp_multiplier_offset, + dst_block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handles[tp_multiplier] = ( + self.nixl_wrapper.prep_xfer_dlist("", descs)) + + # create dst xfer side handles + self.dst_num_blocks[engine_id] = num_blocks + blocks_data = [] + for base_addr in self.kv_caches_base_addr[engine_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append((base_addr + block_offset, dst_block_len, + self.rank * tp_multiplier)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, self.rank) + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id][tp_idx] = ( + self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][self.rank * tp_multiplier + + tp_idx], descs)) + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get requests that are done sending or recving.""" + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "get_finished: %s requests done sending " + "and %s requests done recving", len(done_sending), + len(done_recving)) + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """Get req_ids which got a remote xfer message.""" + + notified_req_ids: set[str] = set() + # TODO: handle the TP case (N notifies for TP=N). + # See: vllm/worker/worker_base.py L476 in DynamoPR. + for req_ids in self.nixl_wrapper.get_new_notifs().values(): + for req_id in req_ids: + assert req_id not in notified_req_ids + notified_req_ids.add(req_id.decode("utf-8")) + return notified_req_ids + + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. + Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # TODO ptarasiewicz: why abort is throwing errors? + # self.nixl_wrapper.release_xfer_handle(handle) + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + if len(running_reqs) == 0: + done_req_ids.add(req_id) + del transfers[req_id] + else: + transfers[req_id] = running_reqs + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). 
+ """ + for req_id, meta in metadata.requests.items(): + # NOTE: this is non-blocking + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + self._read_blocks( + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + dst_engine_id=meta.remote_engine_id, + request_id=req_id, + ) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # NOTE(rob): we could potentially do the rearranging during the load_kv! + + # Note(tms): The remote_block_ids only contain full computed blocks, + # while the local_block_ids are all blocks allocated for this request, + # so truncate the local_block_ids to account for this. + del local_block_ids[len(remote_block_ids):] + assert len(local_block_ids) == len(remote_block_ids) + + # NOTE(rob): this can cause the remote blocks to not be freed? + if len(local_block_ids) == 0: + return + + # TODO: support TP multipliers. + tp_multiplier = 1 + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + + # Read the data from the remote. + for i in range(tp_multiplier): + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, + "all", + local_block_ids, + i=None, #TODO: Enable both tp_multiplier and staging_ranges. + tp_multiplier=tp_multiplier, + staging_ranges=None) + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[ + dst_engine_id][i] + + # NOTE(rob): we use the request_id as the notify msg, so we + # must use the same request_id in both the p and d workers. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=request_id.encode("utf-8"), + ) + + # Call transfer to begin the async transfer + # We will check this is done in the next forward pass. 
+ self.nixl_wrapper.transfer(handle) + self._recving_transfers[request_id].append(handle) + + def _get_block_descs_ids(self, + engine_id, + region_ids, + block_ids, + i=None, + tp_multiplier=1, + staging_ranges=None): + + if region_ids == "all": + region_ids = range(self.num_regions) + if block_ids == "all": + block_ids = range(self.num_blocks) + + descs_ids = [] + + if i is not None: + raise NotImplementedError("Prefill and Decode instances must have " + "the same TP size.") + + num_blocks = self.dst_num_blocks[engine_id] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..94bf53d90c91 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -260,6 +260,7 @@ def get_num_new_matched_tokens( return num_tokens_to_check - num_computed_tokens def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3cafcb7c31f2..aefba620e189 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -64,6 +64,13 @@ def _optional_type(val: str) -> Optional[T]: return _optional_type +def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: + if not re.match("^{.*}$", val): + return str(val) + else: + return optional_type(json.loads)(val) + + @deprecated( "Passing a JSON argument as a string containing comma separated key=value " "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON " @@ -143,7 +150,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the help text for the field name = field.name - help = cls_docs[name] + help = cls_docs[name].strip() # Escape % for argparse help = help.replace("%", "%%") @@ -158,6 +165,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: type_hints.add(field.type) # Set other kwargs based on the type hints + json_tip = "\n\nShould be a valid JSON string." 
if contains_type(type_hints, bool): # Creates --no- and -- flags kwargs[name]["action"] = argparse.BooleanOptionalAction @@ -187,9 +195,14 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = human_readable_int elif contains_type(type_hints, float): kwargs[name]["type"] = float + elif contains_type(type_hints, + dict) and (contains_type(type_hints, str) or any( + is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): # Dict arguments will always be optional kwargs[name]["type"] = optional_type(json.loads) + kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): kwargs[name]["type"] = str @@ -371,6 +384,7 @@ class EngineArgs: reasoning_parser: str = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load + pt_load_map_location: str = LoadConfig.pt_load_map_location def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -491,6 +505,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=None, help='Name or path of the QLoRA adapter.') + load_group.add_argument('--pt-load-map-location', + **load_kwargs["pt_load_map_location"]) # Guided decoding arguments guided_decoding_kwargs = get_kwargs(DecodingConfig) @@ -883,12 +899,14 @@ def create_load_config(self) -> LoadConfig: if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" + return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, + pt_load_map_location=self.pt_load_map_location, ) def create_speculative_config( @@ -1423,8 +1441,8 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: # as the platform that vLLM is running on (e.g. the case of scaling # vLLM with Ray) and has no GPUs. In this case we use the default # values for non-H100/H200 GPUs. + from vllm.platforms import current_platform try: - from vllm.platforms import current_platform device_memory = current_platform.get_device_total_memory() except Exception: # This is only used to set default_max_num_batched_tokens @@ -1445,11 +1463,37 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: } default_max_num_seqs = 256 + # tpu specific default values. 
+ if current_platform.is_tpu(): + default_max_num_batched_tokens_tpu = { + UsageContext.LLM_CLASS: { + 'V6E': 2048, + 'V5E': 1024, + 'V5P': 512, + }, + UsageContext.OPENAI_API_SERVER: { + 'V6E': 1024, + 'V5E': 512, + 'V5P': 256, + } + } + use_context_value = usage_context.value if usage_context else None if (self.max_num_batched_tokens is None and usage_context in default_max_num_batched_tokens): - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context] + if current_platform.is_tpu(): + chip_name = current_platform.get_device_name() + if chip_name in default_max_num_batched_tokens_tpu[ + usage_context]: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens_tpu[ + usage_context][chip_name] + else: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens[usage_context] + else: + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, use_context_value) @@ -1513,7 +1557,7 @@ def _warn_or_fallback(feature_name: str) -> bool: def human_readable_int(value): """Parse human-readable integers like '1k', '2M', etc. Including decimal values with decimal multipliers. - + Examples: - '1k' -> 1,000 - '1K' -> 1,024 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6cc9b881464e..50da9679d5aa 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -489,9 +489,13 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - if self.tokenizer is not None: - tokenizer = await self.get_tokenizer_async(lora_request) - self._validate_token_prompt(prompt, tokenizer=tokenizer) + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. + prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] processed_inputs = await self.input_preprocessor.preprocess_async( prompt, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0930bae02e41..4398852daac9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,7 +30,7 @@ get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs -from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs +from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.logits_process import get_bad_words_logits_processors @@ -753,10 +753,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - if self.tokenizer is not None: - self._validate_token_prompt( - prompt, - tokenizer=self.get_tokenizer(lora_request=lora_request)) + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len processed_inputs = self.input_preprocessor.preprocess( prompt, @@ -776,27 +777,6 @@ def add_request( priority=priority, ) - def _validate_token_prompt(self, prompt: PromptType, - tokenizer: AnyTokenizer): - # Guard against out-of-vocab tokens. 
- # For some tokenizers, tokenizer.decode will happily return empty text - # for token ids that are out of vocab, and we don't detect token ids - # that are greater than the max token id before running the model. - # However, these token ids will later crash a cuda kernel at runtime - # with an index out of bounds error. This will crash the entire engine. - # This needs to happen before multimodal input pre-processing, which - # may add dummy tokens that aren't part of the tokenizer's - # vocabulary. - if is_token_prompt(prompt): - prompt_ids = prompt["prompt_token_ids"] - if len(prompt_ids) == 0: - # Empty prompt check is handled later - return - max_input_id = max(prompt_ids) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - "Token id {} is out of vocabulary".format(max_input_id)) - def _create_sequence_group_with_sampling( self, request_id: str, @@ -1267,11 +1247,13 @@ def _advance_to_next_step( if self.scheduler_config.is_multi_step: is_prefill_append = seq.data.get_num_uncomputed_tokens( ) == 0 - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) if not is_prefill_append: seq_group.update_num_computed_tokens(1) else: - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -2032,13 +2014,21 @@ def _validate_model_input( tokenizer = (None if self.tokenizer is None else self.tokenizer.get_lora_tokenizer(lora_request)) - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = prompt_inputs.get("prompt_token_ids", []) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + if prompt_inputs["type"] == "embeds": + pass else: raise ValueError(f"The {prompt_type} prompt cannot be empty") + if tokenizer is not None: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") + max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 126e7da70216..0f4c7517ebac 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -167,6 +167,7 @@ def _process_seq_outputs(self, seq: Sequence, sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] + output_embeds = [sample.output_embed for sample in valid_samples] # Truncate to max_tokens if necessary. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -190,11 +191,12 @@ def _process_seq_outputs(self, seq: Sequence, is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. 
- for output_token_id, output_logprob in zip(output_token_ids, - output_logprobs): + for output_token_id, output_logprob, output_embed in zip( + output_token_ids, output_logprobs, output_embeds): seq.append_token_id( token_id=output_token_id, logprobs=output_logprob, + token_embed=output_embed, ) if is_prefill_sampled_token: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4d96791a1f8a..b5b51bb25a86 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -119,7 +119,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, sample = outputs.samples[0] seq = seq_group.first_seq if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 5632e8ad446d..e9350612ee57 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -83,6 +83,9 @@ async def beam_search( else: processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) + if processed_inputs["type"] == "embeds": + raise NotImplementedError + prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_text = processed_inputs.get("prompt") multi_modal_data = processed_inputs.get("multi_modal_data") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0a302872d263..69523f36ffc4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -27,7 +27,7 @@ _validate_score_input_lens) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding.guided_fields import ( @@ -567,10 +567,12 @@ def create_tokens_prompt_from_beam( mm_kwargs["mm_processor_kwargs"] = prompt[ "mm_processor_kwargs"] - if is_token_prompt(prompt): + if "prompt_token_ids" in prompt: + prompt = cast(TokensPrompt, prompt) # Needed for mypy prompt_tokens = prompt["prompt_token_ids"] else: prompt_tokens = tokenizer.encode(prompt["prompt"]) + instances.append( BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 389557dfb7c3..4fb82f38e477 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -19,7 +19,8 @@ from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, - RequestOutputKind, SamplingParams) + KVTransferParams, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid, resolve_obj_by_qualname @@ -863,6 +864,22 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + do_remote_decode: bool = Field( + default=False, + description="KVTransfer parameters used for disaggregated serving.") + + do_remote_prefill: bool = Field( + default=False, + description="KVTransfer parameters used for disaggregated serving.") + 
+ remote_engine_id: Optional[str] = Field( + default=None, + description="Remote engine id for prefill-decode disaggregation.") + + remote_block_ids: Optional[list[int]] = Field( + default=None, + description="Remote block ids for prefill-decode disaggregation.") + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -960,6 +977,13 @@ def to_sampling_params( whitespace_pattern=self.guided_whitespace_pattern, ) + kv_transfer_params = KVTransferParams.from_optional( + do_remote_decode=self.do_remote_decode, + do_remote_prefill=self.do_remote_prefill, + remote_engine_id=self.remote_engine_id, + remote_block_ids=self.remote_block_ids, + ) + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -988,7 +1012,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + kv_transfer_params=kv_transfer_params, + ) @model_validator(mode="before") @classmethod @@ -1238,6 +1264,12 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + remote_engine_id: Optional[str] = Field( + default=None, + description="Remote engine id for prefill-decode disaggregation.") + remote_block_ids: Optional[list[int]] = Field( + default=None, + description="Remote block ids for prefill-decode disaggregation.") class CompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..42180b81119f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -476,12 +476,23 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage + kv_transfer_params = final_res_batch[0].kv_transfer_params + if kv_transfer_params is not None: + remote_engine_id = kv_transfer_params.remote_engine_id + remote_block_ids = kv_transfer_params.remote_block_ids + else: + remote_engine_id = None + remote_block_ids = None + + assert len(final_res_batch) == 1 return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + remote_engine_id=remote_engine_id, + remote_block_ids=remote_block_ids, ) def _create_completion_logprobs( diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c75d8f088c5b..c24ba0f45f9e 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,10 +11,6 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -101,16 +97,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. 
- trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -147,11 +133,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ca706e202836..9914a9dcffcc 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, + TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -21,7 +21,9 @@ "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "TokenInputs", + "EmbedsInputs", "token_inputs", + "embeds_inputs", "DecoderOnlyInputs", "EncoderDecoderInputs", "ProcessorInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 167189ed108e..86dbca180412 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast +import torch from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: @@ -63,12 +64,25 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +class EmbedsPrompt(TypedDict): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) Note that "singleton" is as opposed to a data structure which encapsulates multiple prompts, i.e. 
of the sort @@ -129,6 +143,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) - A single data structure containing both an encoder and a decoder prompt (:class:`ExplicitEncoderDecoderPrompt`) """ @@ -176,7 +191,35 @@ def token_inputs( return inputs -DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"] +class EmbedsInputs(TypedDict): + """Represents embeddings-based inputs.""" + + type: Literal["embeds"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +def embeds_inputs( + prompt_embeds: torch.Tensor, + cache_salt: Optional[str] = None, +) -> EmbedsInputs: + """Construct :class:`EmbedsInputs` from optional values.""" + inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) + + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -198,7 +241,7 @@ class EncoderDecoderInputs(TypedDict): """The inputs for the decoder portion.""" -SingletonInputs = Union[TokenInputs, "MultiModalInputs"] +SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 28e207de1fd3..d17122b48344 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -6,8 +6,9 @@ from vllm.utils import is_list_of -from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt) +from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -84,23 +85,51 @@ class ParsedTokensPrompt(TypedDict): content: TokensPrompt -def parse_singleton_prompt( - prompt: SingletonPrompt, -) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: +class ParsedEmbedsPrompt(TypedDict): + type: Literal['embeds'] + content: EmbedsPrompt + + +ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt, + ParsedTokensPrompt, ParsedEmbedsPrompt] + + +@overload +def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: + ... + + +def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: - return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. 
+ if "prompt_embeds" in prompt: + return ParsedEmbedsPrompt( + type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt_token_ids" in prompt: + return ParsedTokensPrompt( + type="tokens", content=prompt) # type: ignore[typeddict-item] elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) - - raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") - - -def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance(prompt, dict) and "prompt_token_ids" in prompt + raise TypeError( + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") def is_explicit_encoder_decoder_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 83e6907f8c49..97a2ce5c615e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -6,6 +6,7 @@ from typing_extensions import assert_never +from vllm import envs from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -13,12 +14,14 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, - is_explicit_encoder_decoder_prompt, parse_singleton_prompt) +from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, + EncoderDecoderInputs, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, embeds_inputs, token_inputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -137,13 +140,10 @@ def _prepare_decoder_input_ids_for_generation( """ Prepares `decoder_input_ids` for generation with encoder-decoder models. - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + Based on: + https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py + specifically, + `GenerationMixin._prepare_decoder_input_ids_for_generation()`. Arguments: @@ -180,6 +180,23 @@ def _apply_prompt_adapter( return prompt_token_ids + def _get_tokenization_kw( + self, + overrides: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + kwargs = dict[str, Any]() + + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + kwargs["add_special_tokens"] = False + + if overrides: + kwargs.update(overrides) + + return kwargs + def _tokenize_prompt( self, prompt: str, @@ -191,18 +208,11 @@ def _tokenize_prompt( corresponding token IDs. 
""" tokenizer = self.get_tokenizer_group() - if tokenization_kwargs is None: - tokenization_kwargs = {} + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - if self.model_config.hf_config.model_type == "whisper": - # For Whisper, special tokens should be provided by the user based - # on the task and language of their request. Also needed to avoid - # appending an EOS token to the prompt which disrupts generation. - tokenization_kwargs["add_special_tokens"] = False + encoder_config = self.model_config.encoder_config - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if encoder_config and encoder_config.get("do_lower_case", False): prompt = prompt.lower() return tokenizer.encode(prompt=prompt, @@ -217,18 +227,36 @@ async def _tokenize_prompt_async( ) -> list[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group() - if tokenization_kwargs is None: - tokenization_kwargs = {} + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - if self.model_config.hf_config.model_type == "whisper": - # For Whisper, special tokens should be provided by the user based - # on the task and language of their request. Also needed to avoid - # appending an EOS token to the prompt which disrupts generation. - tokenization_kwargs["add_special_tokens"] = False return await tokenizer.encode_async(prompt=prompt, lora_request=lora_request, **tokenization_kwargs) + def _get_mm_tokenizer( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return tokenizer_group.get_lora_tokenizer(lora_request) + + async def _get_mm_tokenizer_async( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return await tokenizer_group.get_lora_tokenizer_async(lora_request) + def _process_multimodal( self, prompt: Union[str, list[int]], @@ -241,13 +269,7 @@ def _process_multimodal( Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. 
""" - # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal input - if not self.tokenizer: - tokenizer = object() # Dummy - else: - tokenizer_group = self.get_tokenizer_group() - tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + tokenizer = self._get_mm_tokenizer(lora_request) mm_processor = self.mm_registry.create_processor(self.model_config, tokenizer=tokenizer) @@ -267,14 +289,7 @@ async def _process_multimodal_async( return_mm_hashes: bool = False, ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" - # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal input - if not self.tokenizer: - tokenizer = object() # Dummy - else: - tokenizer_group = self.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async( - lora_request) + tokenizer = await self._get_mm_tokenizer_async(lora_request) mm_processor = self.mm_registry.create_processor(self.model_config, tokenizer=tokenizer) @@ -284,28 +299,160 @@ async def _process_multimodal_async( return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, - ParsedTextPrompt, - ParsedTokensPrompt]): - prompt_text = None - prompt_token_ids = None - token_type_ids = None - cache_salt = None + def _process_embeds( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + if envs.VLLM_USE_V1: + raise ValueError("prompt_embeds is only available in V0.") + + prompt_embeds = parsed_content["prompt_embeds"] + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. 
+ if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return embeds_inputs(prompt_embeds=prompt_embeds, + cache_salt=parsed_content.get("cache_salt")) - if parsed_prompt["type"] == "str": - prompt_text = parsed_prompt["content"] + async def _process_embeds_async( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + return self._process_embeds(parsed_content) + + def _process_tokens( + self, + parsed_content: TokensPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) else: - cache_salt = parsed_prompt["content"].get("cache_salt") - if parsed_prompt["type"] == "text": - prompt_text = parsed_prompt["content"]["prompt"] - elif parsed_prompt["type"] == "tokens": - prompt_token_ids = parsed_prompt["content"].get( - "prompt_token_ids") - token_type_ids = parsed_prompt["content"].get("token_type_ids") - else: - assert_never(parsed_prompt) + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_tokens_async( + self, + parsed_content: TokensPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _process_text( + self, + parsed_content: TextPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_text_async( + self, + parsed_content: TextPrompt, + 
tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt - return prompt_text, prompt_token_ids, token_type_ids, cache_salt + return inputs def _prompt_to_llm_inputs( self, @@ -328,36 +475,31 @@ def _prompt_to_llm_inputs( * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) - prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ - self._get_prompt_data(parsed) - # If multimodal data is present, process and return immediately - if parsed["type"] != "str" and parsed["content"].get( - "multi_modal_data") is not None: - inputs = self._process_multimodal( - prompt_text if prompt_text is not None else prompt_token_ids, - parsed["content"]["multi_modal_data"], - parsed["content"].get("mm_processor_kwargs"), + if parsed["type"] == "embeds": + return self._process_embeds(parsed["content"]) + if parsed["type"] == "tokens": + return self._process_tokens( + parsed["content"], lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - return inputs - - if prompt_token_ids is None: - prompt_token_ids = self._tokenize_prompt( - prompt_text, + if parsed["type"] == "text": + return self._process_text( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return self._process_text( + TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - cache_salt=cache_salt, - ) + assert_never(parsed) async def _prompt_to_llm_inputs_async( self, @@ -366,49 +508,49 @@ async def _prompt_to_llm_inputs_async( lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> SingletonInputs: - """Async version of :meth:`_extract_prompt_components`.""" + """Async version of :meth:`_prompt_to_llm_inputs`.""" parsed = parse_singleton_prompt(prompt) - prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ - self._get_prompt_data(parsed) - - if parsed["type"] != "str" and parsed["content"].get( - "multi_modal_data") is not None: - inputs = await self._process_multimodal_async( - prompt_token_ids if prompt_text is None else prompt_text, - parsed["content"]["multi_modal_data"], - parsed["content"].get("mm_processor_kwargs"), + if parsed["type"] == "embeds": + return await self._process_embeds_async(parsed["content"]) + if parsed["type"] == "tokens": + return await self._process_tokens_async( + parsed["content"], lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) - if cache_salt is not 
None: - inputs["cache_salt"] = cache_salt - return inputs - - if prompt_token_ids is None: - prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + if parsed["type"] == "text": + return await self._process_text_async( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return await self._process_text_async( + TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - cache_salt=cache_salt, - ) + assert_never(parsed) def _build_enc_dec_llm_inputs( self, encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if (encoder_inputs["type"] == "token" - or encoder_inputs["type"] == "multimodal"): - pass - else: - assert_never(encoder_inputs) # type: ignore[arg-type] + if (encoder_inputs["type"] == "embeds" + or decoder_inputs and decoder_inputs["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") + + # Needed for mypy + encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], + encoder_inputs) + decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs) if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": @@ -421,73 +563,78 @@ def _build_enc_dec_llm_inputs( dec_token_ids = self._prepare_decoder_input_ids_for_generation( None) decoder_inputs = token_inputs(dec_token_ids) - elif (decoder_inputs["type"] == "token" - or decoder_inputs["type"] == "multimodal"): - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"]) - decoder_inputs["prompt_token_ids"] = dec_token_ids - + else: if "multi_modal_data" in decoder_inputs: raise ValueError("Multi-modal decoder inputs of encoder-" "decoder models are not supported yet") - else: - assert_never(encoder_inputs) # type: ignore[arg-type] + + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids return EncoderDecoderInputs( encoder=encoder_inputs, decoder=decoder_inputs, ) - def _separate_enc_dec_inputs_from_mm_processor_outputs( + def _split_enc_dec_mm_inputs( self, - inputs: SingletonInputs, + inputs: Union[SingletonInputs, MultiModalEncDecInputs], decoder_inputs_to_override: Optional[SingletonInputs] = None, ) -> tuple[SingletonInputs, SingletonInputs]: """ For encoder/decoder models only: Separate Encoder/Decoder inputs from a MultiModalEncDecInputs """ + if (inputs["type"] == "embeds" or decoder_inputs_to_override + and decoder_inputs_to_override["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") + + # Needed for mypy + inputs = cast( + Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs], + inputs, + ) + decoder_inputs_to_override = cast( + Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs_to_override, + ) + encoder_inputs: SingletonInputs decoder_inputs: SingletonInputs - if inputs["type"] == "multimodal": - # Multimodal data inputs - assert ("encoder_prompt" in inputs - and "encoder_prompt_token_ids" in inputs) + + if inputs["type"] == "multimodal": # Multimodal data inputs + if not ("encoder_prompt" in inputs + and 
"encoder_prompt_token_ids" in inputs): + raise RuntimeError("You should register an encoder-decoder " + "multi-modal processor for encoder-decoder " + "models.") inputs = cast(MultiModalEncDecInputs, inputs) + encoder_inputs = token_inputs( prompt=inputs["encoder_prompt"], prompt_token_ids=inputs["encoder_prompt_token_ids"], ) - if decoder_inputs_to_override is not None: - decoder_inputs = MultiModalInputs( - type="multimodal", - prompt=decoder_inputs_to_override.get("prompt", ""), - prompt_token_ids=decoder_inputs_to_override[ - "prompt_token_ids"], - mm_kwargs=inputs["mm_kwargs"], - mm_hashes=inputs["mm_hashes"], - mm_placeholders=inputs["mm_placeholders"], - ) - else: - decoder_inputs = MultiModalInputs( - type="multimodal", - prompt=inputs["prompt"], - prompt_token_ids=inputs["prompt_token_ids"], - mm_kwargs=inputs["mm_kwargs"], - mm_hashes=inputs["mm_hashes"], - mm_placeholders=inputs["mm_placeholders"], - ) - cache_salt = inputs.get("cache_salt") - if cache_salt is not None: + decoder_prompt_inputs = decoder_inputs_to_override or inputs + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=decoder_prompt_inputs.get("prompt", ""), + prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_hashes=inputs["mm_hashes"], + mm_placeholders=inputs["mm_placeholders"], + ) + if cache_salt := inputs.get("cache_salt"): decoder_inputs["cache_salt"] = cache_salt - elif inputs["type"] == "token": - # Text-only inputs + elif inputs["type"] == "token": # Text-only inputs encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) decoder_inputs = decoder_inputs_to_override or inputs else: assert_never(inputs) # type: ignore[arg-type] + return encoder_inputs, decoder_inputs def _process_encoder_decoder_prompt( @@ -541,8 +688,8 @@ def _process_encoder_decoder_prompt( # with explicit decoder prompt. if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) else: inputs = self._prompt_to_llm_inputs( prompt, @@ -551,11 +698,9 @@ def _process_encoder_decoder_prompt( if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + self._split_enc_dec_mm_inputs(inputs)) else: encoder_inputs = inputs - decoder_inputs = None return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) @@ -591,8 +736,8 @@ async def _process_encoder_decoder_prompt_async( # with explicit decoder prompt. 
if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) else: inputs = await self._prompt_to_llm_inputs_async( prompt, @@ -601,11 +746,9 @@ async def _process_encoder_decoder_prompt_async( if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + self._split_enc_dec_mm_inputs(inputs)) else: encoder_inputs = inputs - decoder_inputs = None return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) @@ -615,14 +758,13 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - if (prompt_inputs["type"] == "token" - or prompt_inputs["type"] == "multimodal"): + if "prompt_token_ids" in prompt_inputs: + prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], + prompt_inputs) # Needed for mypy prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, ) - else: - assert_never(prompt_inputs) # type: ignore[arg-type] return prompt_inputs diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..e539335d4dc7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, 
+ "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 62614a59cbe9..238808b226f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -71,8 +71,8 @@ def single_marlin_moe( E = w.shape[0] N = w.shape[2] // (num_bits // 2) - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) # This might not be an optimal config for a single MMM get_config_func = functools.partial(try_get_optimal_moe_config, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede77..c1edbda0dd22 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -854,7 +854,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -868,20 +868,19 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
topk_func = dispatch_topk_func() topk_weights, topk_ids = topk_func(topk_weights, topk_ids, - token_expert_indicies, + token_expert_indices, gating_output_float, renormalize) - del token_expert_indicies # Not used. Will be used in the future. - return topk_weights, topk_ids + return topk_weights, topk_ids, token_expert_indices # This is used by the Deepseek-V2 and Deepseek-V3 model @@ -1510,8 +1509,8 @@ def fused_moe( topk, renormalize, num_expert_group, topk_group) elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) else: topk_weights, topk_ids = custom_routing_function( hidden_states, gating_output, topk, renormalize) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3cdf3c97a7d3..35994c8ac6af 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -801,10 +801,11 @@ def select_experts(hidden_states: torch.Tensor, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) else: topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py new file mode 100644 index 000000000000..cdf7e31c1436 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + + +def moe_permute( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function expands and permutes the activation to gather non-contiguous tokens + for each expert. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - token_expert_indices (torch.Tensor): indices for the expanded hidden states. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of experts. + - n_local_expert (int): The number of experts in the current EP rank. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - align_block_size (Optional[int]): align group gemm block size for deepgemm + - fill_invalid_expert(int): fill expert id in m_indices for invalid expert + to work around DeepGemm's unsupported -1 in m_indices + Returns: + - permuted_hidden_states (torch.Tensor): permuted activation. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for standard grouped gemm.
If 'align_block_size' is enabled, + expert_first_token_offset will be aligned up to 'align_block_size'. + - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - m_indices: m_indices for grouped gemm in deepgemm, `m_indices[i]` records + the group which the j-th row of the LHS belongs to. + """ + n_token, n_hidden = hidden_states.shape + assert (n_hidden * hidden_states.element_size() + ) % 16 == 0, "permute kernel needs hidden dim aligned to 16B" + permuted_row_size = n_token * topk + if align_block_size is not None: + permuted_row_size = (permuted_row_size + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + m_indices = torch.full((permuted_row_size, ), + fill_invalid_expert, + dtype=torch.int32, + device=hidden_states.device) + expert_first_token_offset = torch.empty(n_local_expert + 1, + dtype=torch.int64, + device=hidden_states.device) + src_row_id2dst_row_id_map = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, + token_expert_indices, expert_map, n_expert, + n_local_expert, topk, align_block_size, + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + return (permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + + +def moe_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + expert_first_token_offset: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, +) -> torch.Tensor: + """ + This function unpermutes the expanded activation and reduces it back to + the original token order. + Parameters: + - permuted_hidden_states (torch.Tensor): permuted activation. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for grouped gemm. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of experts. + - n_local_expert (int): The number of experts in the current EP rank. + Returns: + - hidden_states (torch.Tensor): The reduced and unpermuted activation + tensor.
+ """ + n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1] + assert (n_hidden * permuted_hidden_states.element_size() + ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + hidden_states = torch.empty((n_token, n_hidden), + dtype=permuted_hidden_states.dtype, + device=permuted_hidden_states.device) + + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, + topk_ids, src_row_id2dst_row_id_map, + expert_first_token_offset, n_expert, + n_local_expert, topk, hidden_states) + return hidden_states diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 07d928b597ba..f7c885c2baa3 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -140,7 +140,7 @@ def get_quant_method(self, layer: torch.nn.Module, from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) if not check_moe_marlin_supports_layer(layer, self.group_size): - logger.warning_one( + logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 721e36af2b28..ae16a20cfaab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -34,6 +34,7 @@ class GPTQMarlinState(Enum): "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", + "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", ] @@ -71,6 +72,8 @@ def get_moe_method( return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8MoEMethod(quant_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -545,6 +548,138 @@ def apply( ) +class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. 
Found static input scales.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + hidden_size, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_int8_w8a8=True, + per_channel_quant=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 6ee3a2f1bbbb..b06c9579d63d 100644 --- 
a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -134,7 +134,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 70 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c7f9d95f4c2d..703d54b3bee6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -157,7 +157,7 @@ def get_quant_method(self, layer: torch.nn.Module, from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) if not check_moe_marlin_supports_layer(layer, self.group_size): - logger.warning_one( + logger.warning_once( f"Layer '{prefix}' is not supported by GPTQMoeMarlin. " "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2bf21a05c46d..047724129522 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -111,7 +111,7 @@ def apply_weights(self, # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. symmetric = azp_adj is None - x_q, x_s, x_zp = ops.scaled_int8_quant(x, + x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), i_s, i_zp, symmetric=symmetric) diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py index 5d28d327e8a2..e26ac4ea3d4c 100644 --- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -71,6 +71,15 @@ def _check_bitblas_supported( f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " "are supported.") + # Finally, check if bitblas is installed + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError("bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + return False, "BitBLAS is not installed." + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 98b06b6c2ae9..aaaf7a9e0a4c 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -85,6 +85,32 @@ def block_dequant( return x_dq_block +if current_platform.is_rocm(): + from triton.language import core + + # NOTE: This can be removed when hip.libdevice.round() is available. 
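For context on the kernel further down, a hedged pure-PyTorch reference of the per-token INT8 scheme it implements: the 1e-10 clamp and the 127 scale mirror `_per_token_quant_int8`, while the shapes are illustrative and torch.round's half-to-even tie-breaking may differ slightly from the device round used on GPU.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # x: (num_tokens, hidden); one scale per row (token).
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127
    x_q = torch.round(x * (127 / absmax)).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 128)
x_q, scale = per_token_quant_int8_ref(x)
# Dequantized values (x_q.float() * scale) differ from x by at most half a
# quantization step per element.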
+ @core.extern + def round_f32(arg0, _builder=None): + return core.extern_elementwise("", + "", [arg0], { + (core.dtype("fp32"), ): + ("llvm.round", core.dtype("fp32")), + (core.dtype("fp64"), ): + ("llvm.round", core.dtype("fp64")), + }, + is_pure=True, + _builder=_builder) + + @triton.jit + def round_int8(x): + return round_f32(x).to(tl.int8) +else: + + @triton.jit + def round_int8(x): + return tl.extra.cuda.libdevice.round(x).to(tl.int8) + + @triton.jit def _per_token_quant_int8( x_ptr, @@ -106,7 +132,7 @@ def _per_token_quant_int8( absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) scale_x = absmax / 127 x_q = x * (127 / absmax) - x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + x_q = round_int8(x_q) tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) tl.store(scale_ptr + row_id, scale_x) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1ee1332ac45e..9368992b24fe 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -110,6 +110,11 @@ class SamplerOutput( # 'broadcasted' to all other PP ranks for next step. sampled_token_ids_cpu: Optional[torch.Tensor] = None + # On-device tensor containing the sampled token embeddings (embeddings + # corresponding to the sampled token ids). Used when prompt embeddings are + # specified in lieu of prompt token ids or text. + sampled_token_embeds: Optional[torch.Tensor] = None + # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -183,7 +188,7 @@ def __init__(self): # Whether or not the SamplerOutput should have on-device tensors # containing the sampled token ids and probabilities. This is used by - # speculative decoding. + # speculative decoding and when prompt embeddings are specified. 
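A hedged illustration of what `sampled_token_embeds` carries: the sampled ids looked back up in the model's input embedding table so decoding can continue from embeddings when the request was submitted with `prompt_embeds`. The embedding layer and sizes here are made up; this is not the sampler's actual code path.

import torch

vocab_size, hidden_size = 32000, 4096                  # illustrative sizes
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

sampled_token_ids = torch.tensor([17, 923, 5])         # one sampled id per sequence
sampled_token_embeds = embed_tokens(sampled_token_ids)  # (num_seqs, hidden_size)
# This tensor corresponds to SamplerOutput.sampled_token_embeds on-device.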
self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index cb9100e35594..01f75db9ee86 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -384,6 +384,7 @@ def _get_weights_iterator( weights_iterator = pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) if current_platform.is_tpu(): @@ -890,6 +891,7 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): iterator = pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) for org_name, param in iterator: # mapping weight names from transformers to vllm while preserving diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 37a8491cf63d..10bc55ca5f7d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator( def pt_weights_iterator( hf_weights_files: List[str], use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( @@ -510,7 +511,9 @@ def pt_weights_iterator( disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, map_location="cpu", weights_only=True) + state = torch.load(bin_file, + map_location=pt_load_map_location, + weights_only=True) yield from state.items() del state diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index dfe8f20c70d6..c518efdb54f8 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -175,10 +175,8 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) do_normalize = self.top_k > 1 - topk_weights, topk_ids = fused_topk(hidden_states, - router_logits, - self.top_k, - renormalize=do_normalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=do_normalize) # topk_ids: (num_tokens, k) if self.is_quant: if 2 * num_tokens <= self.num_experts: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ffa5840b4604..ce86b9b2c4f0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -454,9 +454,7 @@ def __init__( qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, ) self.prefix = prefix @@ -468,17 +466,22 @@ def forward( hidden_states: torch.Tensor, ) -> torch.Tensor: if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) + q_c = self.q_a_proj(hidden_states)[0] + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] else: - hidden_states_or_q_c = hidden_states + q = self.q_proj(hidden_states)[0] kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) 
- return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_local_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] class DeepseekV2DecoderLayer(nn.Module): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index c42f19fee17d..904ff3210943 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -6,7 +6,8 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.config import ModelConfig, VllmConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear @@ -76,17 +77,19 @@ def forward( return hidden_states, residual +@support_torch_compile class LlamaModel(nn.Module): def __init__( self, *, - model_config: ModelConfig, + vllm_config: VllmConfig, start_layer_id: int = 0, prefix: str = "", ) -> None: super().__init__() - self.config = model_config.hf_config + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -119,8 +122,7 @@ def forward( hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) - if (hidden_states.shape[-1] != input_embeds.shape[-1]): - hidden_states = self.fc(hidden_states) + assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None hidden_states, residual = self.layers[0]( @@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0): nn.Module.__init__(self) - model_config = vllm_config.speculative_config.draft_model_config - self.config = model_config.hf_config - self.model = LlamaModel(model_config=model_config, + self.config = vllm_config. 
\ + speculative_config.draft_model_config.hf_config + self.model = LlamaModel(vllm_config=vllm_config, start_layer_id=start_layer_id, prefix="model") @@ -214,6 +216,13 @@ def compute_logits( logits_new[:, targets] = logits return logits_new + def combine_hidden_states( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # combine multiple auxiliary hidden states returned by eagle3 + return self.model.fc(hidden_states) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 16f5327ee79b..3791b92ecc2a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -354,9 +354,8 @@ def _get_prompt_updates( image_token_id = hf_config.image_token_index image_end_id = vocab[processor.image_end_token] - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) - encoder_info = PixtralHFEncoderInfo(vision_config) + assert isinstance(hf_config.vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(hf_config) def get_replacement(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index c9abe4142be5..6352ba236818 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -272,12 +272,8 @@ def _get_prompt_updates( image_token_id = hf_config.image_token_index image_end_id = vocab[processor.image_end_token] - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) - # Need to sneak in spatial_merge_size for Mistral3 - vision_config.spatial_merge_size = getattr(hf_config, - "spatial_merge_size", 1) - encoder_info = PixtralHFEncoderInfo(vision_config) + assert isinstance(hf_config.vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(hf_config) def get_replacement(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d756b3b8a7ca..7b11a616e585 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -916,8 +916,9 @@ def get_image_size(self) -> int: return self.vision_config.image_size def get_patch_size(self) -> int: - return (self.vision_config.patch_size * - self.vision_config.spatial_merge_size) + # spatial_merge_size is needed for Mistral3 + spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1) + return self.vision_config.patch_size * spatial_merge_size def get_patch_grid_length(self) -> int: image_size, patch_size = self.get_image_size(), self.get_patch_size() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 05e3b3f3ccdf..901d83ec5b9e 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -19,10 +19,11 @@ class VisionEncoderInfo(ABC, Generic[_C]): - def __init__(self, vision_config: _C) -> None: + def __init__(self, hf_config: _C) -> None: super().__init__() - self.vision_config = vision_config + self.hf_config = hf_config + self.vision_config = hf_config.vision_config @abstractmethod def get_num_image_tokens( @@ -57,18 +58,14 @@ def get_vision_encoder_info( from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig from .siglip import SiglipEncoderInfo, SiglipVisionConfig - vision_config = 
hf_config.vision_config - if isinstance(vision_config, CLIPVisionConfig): - return CLIPEncoderInfo(vision_config) - if isinstance(vision_config, PixtralVisionConfig): - # Need to sneak in spatial_merge_size for Mistral3 - vision_config.spatial_merge_size = getattr(hf_config, - "spatial_merge_size", 1) - return PixtralHFEncoderInfo(vision_config) - if isinstance(vision_config, SiglipVisionConfig): - return SiglipEncoderInfo(vision_config) - - msg = f"Unsupported vision config: {type(vision_config)}" + if isinstance(hf_config.vision_config, CLIPVisionConfig): + return CLIPEncoderInfo(hf_config) + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFEncoderInfo(hf_config) + if isinstance(hf_config.vision_config, SiglipVisionConfig): + return SiglipEncoderInfo(hf_config) + + msg = f"Unsupported vision config: {type(hf_config.vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 11665ef66753..53e289370a9f 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -31,16 +31,20 @@ def serialize_item(cls, obj: object) -> bytes: return obj.encode("utf-8") if isinstance(obj, bytes): return obj - if isinstance(obj, Image.Image): - return obj.tobytes() + if isinstance(obj, (int, float)): + return np.array(obj).tobytes() - # Convertible to NumPy arrays + if isinstance(obj, Image.Image): + return cls.item_to_bytes("image", np.array(obj.convert("RGBA"))) if isinstance(obj, torch.Tensor): - obj = obj.numpy() - if isinstance(obj, (int, float)): - obj = np.array(obj) + return cls.item_to_bytes("tensor", obj.numpy()) if isinstance(obj, np.ndarray): - return obj.tobytes() + return cls.item_to_bytes( + "ndarray", { + "dtype": obj.dtype.str, + "shape": obj.shape, + "data": obj.data.tobytes(), + }) logger.warning( "No serialization method found for %s. " @@ -53,14 +57,22 @@ def item_to_bytes( cls, key: str, obj: object, + ) -> bytes: + return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj)) + + @classmethod + def iter_item_to_bytes( + cls, + key: str, + obj: object, ) -> Iterable[tuple[bytes, bytes]]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): - yield from cls.item_to_bytes(f"{key}.{i}", elem) + yield from cls.iter_item_to_bytes(f"{key}.{i}", elem) elif isinstance(obj, dict): for k, v in obj.items(): - yield from cls.item_to_bytes(f"{key}.{k}", v) + yield from cls.iter_item_to_bytes(f"{key}.{k}", v) else: key_bytes = cls.serialize_item(key) value_bytes = cls.serialize_item(obj) @@ -71,7 +83,7 @@ def hash_kwargs(cls, **kwargs: object) -> str: hasher = blake3() for k, v in kwargs.items(): - for k_bytes, v_bytes in cls.item_to_bytes(k, v): + for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v): hasher.update(k_bytes) hasher.update(v_bytes) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index e8745a8f1f90..58168d0e850c 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1670,15 +1670,17 @@ def _validate_mm_placeholders( placeholders = mm_placeholders.get(modality, []) if len(placeholders) != item_count: + # NOTE: If you are a model developer, this can also arise from + # an inconsistency between `_call_hf_processor` and + # `_get_mm_fields_config` implementations raise RuntimeError( f"Expected there to be {item_count} prompt updates " f"corresponding to {item_count} {modality} items, but " f"instead found {len(placeholders)} prompt updates! 
" - "Either the prompt text has missing/incorrect tokens for " - "multi-modal inputs, or there is a problem with your " - "implementation of merged multi-modal processor for this " - "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_updates`).") + "This is likely because you forgot to include input " + "placeholder tokens (e.g., ``, `<|image_pad|>`) " + "in the prompt. If the model has a chat template, make " + "sure you have applied it before calling `LLM.generate`.") def _maybe_apply_prompt_updates( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index 65a6ed01451d..a5a6641973ac 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) @@ -103,6 +103,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -120,6 +121,7 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -133,11 +135,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index c5555aba1a3e..6a78e00a9049 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -406,12 +406,12 @@ def validate_request( """Raises if this request is unsupported on this platform""" def __getattr__(self, key: str): - device = getattr(torch, self.device_name, None) + device = getattr(torch, self.device_type, None) if device is not None and hasattr(device, key): return getattr(device, key) else: logger.warning("Current platform %s does not have '%s'" \ - " attribute.", self.device_name, key) + " attribute.", self.device_type, key) return None @classmethod diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index de097ab9af1b..ff63f9656c01 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -58,6 +58,15 @@ "excessive use of shared memory. 
If this happens, disable Triton FA " "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } +_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = { + "0x74a0": "AMD_Instinct_MI300A", + "0x74a1": "AMD_Instinct_MI300X", + "0x74b5": "AMD_Instinct_MI300X", # MI300X VF + "0x74a5": "AMD_Instinct_MI325X", + "0x74b9": "AMD_Instinct_MI325X", # MI325X VF + "0x74a9": "AMD_Instinct_MI300X_HF", + "0x74bd": "AMD_Instinct_MI300X_HF", +} # Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` if "HIP_VISIBLE_DEVICES" in os.environ: @@ -225,7 +234,11 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool: def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) handle = amdsmi_get_processor_handles()[physical_device_id] - return amdsmi_get_gpu_asic_info(handle)["market_name"] + asic_info = amdsmi_get_gpu_asic_info(handle) + device_name: str = asic_info["device_id"] + if device_name in _ROCM_DEVICE_ID_NAME_MAP: + return _ROCM_DEVICE_ID_NAME_MAP[device_name] + return asic_info["market_name"] @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d5923557a211..9c95e6d3fa08 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Optional, Union import torch +from tpu_info import device import vllm.envs as envs from vllm.inputs import ProcessorInputs, PromptType @@ -54,7 +55,8 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, @classmethod def get_device_name(cls, device_id: int = 0) -> str: - return "tpu" + chip_type, _ = device.get_local_chips() + return f"TPU {chip_type.name}" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 3ac5c5c3daab..6f10ba3d5fd3 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -27,6 +27,38 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 +# TODO(rob): make this per connector +class KVTransferParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + # TODO(rob): we can handle xPyD and direct KV block Xfer + remote_engine_id: Optional[str] = None + remote_block_ids: Optional[list[int]] = None + do_remote_decode: bool = False + do_remote_prefill: bool = False + + @staticmethod + def from_optional( + do_remote_decode: bool, + do_remote_prefill: bool, + remote_engine_id: Optional[str], + remote_block_ids: Optional[list[int]], + ) -> Optional["KVTransferParams"]: + if do_remote_decode and do_remote_prefill: + raise ValueError( + "Cannot do both remote prefill and remote decode.") + if do_remote_decode or do_remote_prefill: + return KVTransferParams( + do_remote_decode=do_remote_decode, + do_remote_prefill=do_remote_prefill, + remote_engine_id=remote_engine_id, + remote_block_ids=remote_block_ids, + ) + return None + + # maybe make msgspec? @dataclass class GuidedDecodingParams: @@ -186,9 +218,9 @@ class SamplingParams( logits_processors: list of functions that modify logits based on previously generated tokens, and optionally prompt tokens as a first argument. - truncate_prompt_tokens: If set to -1, will use the truncation size - supported by the model. If set to an integer k, will use only - the last k tokens from the prompt (i.e., left truncation). + truncate_prompt_tokens: If set to -1, will use the truncation size + supported by the model. 
If set to an integer k, will use only + the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). guided_decoding: If provided, the engine will construct a guided decoding logits processor from these parameters. Defaults to None. @@ -248,6 +280,9 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None + # Fields used for KVTransfer in disaggregated serving. + kv_transfer_params: Optional[KVTransferParams] = None + @staticmethod def from_optional( n: Optional[int] = 1, @@ -279,6 +314,7 @@ def from_optional( guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, + kv_transfer_params: Optional[KVTransferParams] = None, extra_args: Optional[dict[str, Any]] = None, ) -> "SamplingParams": if logit_bias is not None: @@ -321,6 +357,7 @@ def from_optional( guided_decoding=guided_decoding, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, + kv_transfer_params=kv_transfer_params, extra_args=extra_args, ) diff --git a/vllm/sequence.py b/vllm/sequence.py index a97409523c94..5bc9b8a6fc82 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct, _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + _prompt_embeds: Optional[torch.Tensor] = None + _output_embeds: Optional[torch.Tensor] = None + ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 _prompt_token_ids_tuple: tuple[int, @@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct, _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) + _cached_all_token_embeds: Optional[torch.Tensor] = None # It is used to get delta input. It is reset when `get_delta_and_reset` # is called. @@ -208,6 +212,8 @@ def from_prompt_token_counts( def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, + *, + prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": """ Construct a :class:`SequenceData` instance from prompt and output @@ -217,13 +223,15 @@ def from_seqs( prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr, + _prompt_embeds=prompt_embeds) output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + _output_token_ids=output_token_ids_arr, + _prompt_embeds=prompt_embeds) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -231,6 +239,8 @@ def __post_init__(self) -> None: self._prompt_token_ids_tuple: tuple[int, ...] 
= tuple( self._prompt_token_ids) self._update_cached_all_tokens() + if self._prompt_embeds is not None: + self._update_cached_all_token_embeds() def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) @@ -238,6 +248,13 @@ def _update_cached_all_tokens(self): self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + self._output_token_ids) + def _update_cached_all_token_embeds(self): + assert isinstance(self._prompt_embeds, torch.Tensor) + self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds + if self._output_embeds is not None: + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, self._output_embeds), dim=0) + @property def cumulative_logprob(self) -> float: return self._cumulative_logprob @@ -270,6 +287,15 @@ def output_token_ids(self, new_output_token_ids) self._update_cached_all_tokens() + @property + def output_embeds(self) -> Optional[torch.Tensor]: + return self._output_embeds + + @output_embeds.setter + def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: + self._output_embeds = new_output_token_embeds + self._update_cached_all_token_embeds() + @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. @@ -280,6 +306,15 @@ def output_token_ids_array(self) -> array: assert isinstance(self._output_token_ids, array) return self._output_token_ids + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: + self._prompt_embeds = prompt_embeds + self._update_cached_all_token_embeds() + @property def mrope_position_delta(self) -> Optional[int]: return self._mrope_position_delta @@ -288,11 +323,28 @@ def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta - def append_token_id(self, token_id: int, logprob: float) -> None: + def append_token_id(self, + token_id: int, + logprob: float, + token_embed: Optional[torch.Tensor] = None) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) self._cumulative_logprob += logprob + if token_embed is not None: + # token_embed must be 1-D (no batch or sequence dimensions) + assert token_embed.ndim == 1 + token_embed = token_embed.detach().cpu().unsqueeze(0) + if self._output_embeds is None: + self._output_embeds = token_embed + else: + self._output_embeds = torch.cat( + (self._output_embeds, token_embed), dim=0) + assert self._cached_all_token_embeds is not None + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, + token_embed.to(device=self._cached_all_token_embeds.device)), + dim=0) def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -306,6 +358,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> list[int]: return self._cached_all_token_ids + def get_token_embeddings(self) -> Optional[torch.Tensor]: + return self._cached_all_token_embeds + def get_prefix_token_ids( self, num_tokens: int ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: @@ -387,6 +442,8 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds.shape=" + f"{getattr(self._prompt_embeds, 'shape', None)}, " f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") @@ -425,7 +482,10 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData.from_seqs(self.prompt_token_ids) + self.data = SequenceData.from_seqs( + self.prompt_token_ids, + prompt_embeds=self.inputs["prompt_embeds"] + if self.inputs["type"] == "embeds" else None) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -448,14 +508,20 @@ def n_blocks(self) -> int: @property def prompt(self) -> Optional[str]: + if self.inputs["type"] == "embeds": + return None return self.inputs.get("prompt") @property def prompt_token_ids(self) -> list[int]: + if self.inputs["type"] == "embeds": + return [0] * len(self.inputs["prompt_embeds"]) return self.inputs["prompt_token_ids"] @property def token_type_ids(self) -> list[int]: + if self.inputs["type"] == "embeds": + return [] return self.inputs.get("token_type_ids", []) @property @@ -554,11 +620,14 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: dict[int, - Logprob]) -> None: + def append_token_id(self, + token_id: int, + logprobs: dict[int, Logprob], + token_embed: Optional[torch.Tensor] = None) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + self.data.append_token_id(token_id, logprobs[token_id].logprob, + token_embed) def get_len(self) -> int: return self.data.get_len() @@ -889,6 +958,10 @@ def __repr__(self) -> str: f"sampling_params={self.sampling_params}, " f"num_seqs={len(self.seqs)})") + def uses_prompt_embeds(self) -> bool: + """Returns True if the sequence group uses input embeds.""" + return any(seq.data.prompt_embeds is not None for seq in self.seqs) + class SequenceGroupMetadataDelta( msgspec.Struct, @@ -1043,10 +1116,14 @@ class SequenceOutput( parent_seq_id: int output_token: int logprobs: dict[int, Logprob] + output_embed: Optional[torch.Tensor] = None def __repr__(self) -> str: + output_embed_shape = \ + self.output_embed.shape if self.output_embed is not None else None return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " + f"output_embed.shape={output_embed_shape}" f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 24095ef2a567..a6276c563394 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -201,6 +201,9 @@ def execute_model( if self.prompt_adapter_config is not None: raise ValueError("TP1DraftModelRunner has no support for " "prompt_adapter_config") + if model_input.inputs_embeds is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "inputs_embeds") if model_input.multi_modal_kwargs: raise ValueError( "TP1DraftModelRunner has no support for multi_modal_kwargs" @@ -242,9 +245,16 @@ def execute_model( # Get model if use_cuda_graph: - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = (self.graph_runners[model_input.virtual_engine] - [graph_batch_size]) + if model_input.inputs_embeds is None: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + 
graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) if previous_hidden_states is not None: hidden_states = torch.cat([ @@ -281,6 +291,7 @@ def execute_model( self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 1146606e9a13..de57403d1b50 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -282,7 +282,8 @@ def _append_new_tokens( else: count += 1 - seq.append_token_id(token_id, token_logprob.logprob) + seq.append_token_id(token_id, token_logprob.logprob, + seq_output.output_embed) seq.update_num_computed_tokens(1) @staticmethod diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index fa29efbf6b2d..0a0c0a4bd178 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import sys import types from importlib.util import find_spec @@ -45,9 +44,4 @@ def __init__(self): super().__init__("triton.language") self.constexpr = None self.dtype = None - - sys.modules['triton'] = TritonPlaceholder() - sys.modules['triton.language'] = TritonLanguagePlaceholder() - -if 'triton' in sys.modules: - logger.info("Triton module has been replaced with a placeholder.") + self.int64 = None diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 217dcd7c33ac..f986d797f2b0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,9 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -273,13 +275,23 @@ def make_local_attention_virtual_batches( block_table_local +def _get_sliding_window_configs( + vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + """Get the set of all sliding window configs used in the model.""" + sliding_window_configs: set[Optional[tuple[int, int]]] = set() + layers = get_layers_from_vllm_config(vllm_config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, FlashAttentionImpl) + sliding_window_configs.add(layer.impl.sliding_window) + return sliding_window_configs + + class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): model_config = runner.model_config self.runner = runner - self.aot_schedule = (get_flash_attn_version() == 3) self.num_heads_q = model_config.get_num_attention_heads( runner.parallel_config) self.num_heads_kv = model_config.get_num_kv_heads( @@ -287,6 +299,11 @@ def __init__(self, runner: "GPUModelRunner"): self.headdim = model_config.get_head_size() self.page_size = self.runner.block_size + self.aot_schedule = (get_flash_attn_version() == 3) + # Sliding window size to be used with the AOT scheduler will be + # populated on first 
build() call. + self.aot_sliding_window: Optional[tuple[int, int]] = None + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False @@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + constant for all layers. We have to populate this on the first + build() call, once the layers have been constructed (it cannot + be populated in __init__). + if self.aot_schedule: + sliding_window_configs = _get_sliding_window_configs( + self.runner.vllm_config) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.aot_schedule: @@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, page_size=self.page_size, cu_seqlens_q=cu_query_lens, causal=causal, + window_size=self.aot_sliding_window, ) return None diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fd3be901f4c3..3e77555d7f94 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -200,7 +200,7 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, RowParallelLinear, + LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform @@ -597,12 +597,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, rotary_emb: RotaryEmbedding, - # q_proj should be q_b_proj if q_lora_rank is not None, but from an - # attention backend perspective we rely on the layer to pass in the - # correct matrix - q_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear, - o_proj: RowParallelLinear, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -625,9 +620,7 @@ def __init__( if current_platform.is_cuda(): self.rotary_emb = rotary_emb.forward_cuda - self.q_proj = q_proj self.kv_b_proj = kv_b_proj - self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn @@ -684,27 +677,13 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x)[0] - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.num_heads, self.qk_head_dim)\ - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, 
self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe + return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -874,7 +853,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - return self.o_proj(output.flatten(start_dim=-2))[0] + return output.flatten(start_dim=-2) @abstractmethod def _forward_decode( @@ -889,7 +868,7 @@ def _forward_decode( def forward( self, layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn + q: torch.Tensor, k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, @@ -908,7 +887,7 @@ def forward( # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_toks, ...] - hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] @@ -923,24 +902,29 @@ def forward( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + q = q.view(-1, self.num_heads, self.qk_head_dim) + decode_q = q[:num_decode_tokens] decode_k_pe = k_pe[:num_decode_tokens] - prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + prefill_q = q[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] if has_decode: assert attn_metadata.decode is not None - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None - prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 143bfe35bb5e..f18c9c8b6462 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -146,4 +146,4 @@ def _forward_decode( causal=True, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 8e7e4f10b81b..2e6b619db628 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -115,4 +115,4 @@ def _forward_decode( attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index cb13a5b7a02f..12c55be00375 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -165,6 +165,7 @@ def allocate_slots( num_tokens: int, new_computed_blocks: Optional[list[KVCacheBlock]] = None, num_lookahead_tokens: int = 0, + skip_cache_blocks: bool = False, ) -> Optional[list[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -176,8 +177,11 @@ def allocate_slots( new_computed_blocks: A list of new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. - This is used by spec decode proposers with kv-cache such + This is used by spec decode proposers with kv-cache such as eagle. + skip_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer + which will complete in a future step. Blocks layout: ----------------------------------------------------------------------- @@ -267,11 +271,43 @@ def allocate_slots( if not self.enable_caching: return new_blocks + if skip_cache_blocks: + # NOTE(rob): this assert is valid because we only call + skip_cache_blocks=True the first time the request is WAITING + during a P/D setup. + assert request.request_id not in self.num_cached_block + # NOTE(rob): this is necessary so we don't double + cache a block after it has finished recving. + self.num_cached_block[request.request_id] = len( + new_computed_blocks) + return new_blocks + + self.cache_blocks( + request=request, + num_tokens=num_tokens, + num_computed_tokens=num_computed_tokens, + new_computed_blocks=new_computed_blocks, + ) + return new_blocks + + def cache_blocks( + self, + request: Request, + num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: Optional[list[KVCacheBlock]] = None, + ): + if new_computed_blocks is None: + new_computed_blocks = [] + + req_blocks = self.req_to_blocks[request.request_id] + # Use `new_computed_blocks` for a new request, and `num_cached_block` # for a running request. num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we does + + # Speculated tokens might be rejected in the future, so we do + not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( @@ -289,7 +325,6 @@ def allocate_slots( self.num_cached_block[ request.request_id] = num_full_blocks_after_append - return new_blocks def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
@@ -364,7 +399,8 @@ def get_num_common_prefix_blocks( Returns: int: The number of common prefix blocks. """ - assert request.status == RequestStatus.RUNNING + assert request.status in (RequestStatus.RUNNING, + RequestStatus.FINISHED_REMOTE_DECODE) blocks = self.req_to_blocks[request.request_id] num_common_blocks = 0 for block in blocks: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ae7280a14706..b8e9e0db6362 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools import time from collections import defaultdict, deque from collections.abc import Iterable @@ -14,6 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.sampling_params import KVTransferParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -96,6 +98,9 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() + # Requests in states for tracking KV transfers for P/D disagg + self.finished_recving_kv_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -308,6 +313,27 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + # Skip request if the remote KV recv is still waiting + # for the requests to arrive. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + if request.request_id in self.finished_recving_kv_req_ids: + assert self.kv_cache_manager.enable_caching + # Now that the KVs have been recved, we can cache + # them and set num_computed_tokens. + self.kv_cache_manager.cache_blocks( + request, + num_tokens=0, + num_computed_tokens=(len(request.all_token_ids) - + 1)) + self.finished_recving_kv_req_ids.remove( + request.request_id) + request.status = RequestStatus.WAITING + self.kv_cache_manager.free(request) + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -344,9 +370,41 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens += num_external_tokens + if request.do_remote_prefill and num_external_tokens > 0: + # Allocate slots for the external tokens, but skip + # caching until after the KV transfer is done. + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_external_tokens, + computed_blocks, + skip_cache_blocks=True) + if new_blocks is None: + # Requests cannot be scheduled + break + + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], + num_external_tokens, + ) + # We should only trigger a KV transfer once per request. + request.do_remote_prefill = False + continue + # Number of tokens to be scheduled. 
# We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed requests, + # `request.num_prompt_tokens` to consider the resumed reqs, # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens if (0 < self.scheduler_config.long_prefill_token_threshold < @@ -385,6 +443,10 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], num_external_tokens, ) @@ -416,7 +478,7 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if encoder_inputs_to_schedule: + if not request.do_remote_prefill and encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. @@ -518,7 +580,8 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - self.requests[req_id].num_computed_tokens += num_scheduled_token + if req := self.requests.get(req_id): + req.num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output @@ -742,6 +805,30 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: + # Stop request after the first token if doing a remote_decode. + # NOTE(rob): req is not freed (or preempted) in the EngineCore + # until the xfer is done to ensure we do not free the KV blocks. + kv_transfer_params = None + # TODO(rob): edge case where we get a stop for stop_strings + # inside AsyncLLM. + if request.do_remote_decode and not stopped: + request.status = RequestStatus.FINISHED_REMOTE_DECODE + self._free_request(request, skip_free_blocks=True) + stopped = True + + # TODO(rob): do this on a per-Connector basis. + remote_blocks = [ + block.block_id for block in + self.kv_cache_manager.get_computed_blocks(request)[0] + ] + + engine_id = self.vllm_config.kv_transfer_config.engine_id + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_block_ids=remote_blocks, + remote_engine_id=engine_id, + ) + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -751,7 +838,10 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -759,9 +849,22 @@ def update_from_output( if not stopped: new_running.append(request) - # Return the cached request data to the queue so they can be reused. + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or []): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or []): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + + # Return the cached request data to the queue so they can + # be reused. Note: we cannot add stopped requests to this + # since they are already freed above! 
for req_data in scheduler_output.scheduled_cached_reqs: - self._cached_reqs_data[req_data.req_id].append(req_data) + # NOTE(rob): since we free stopped reqs above, adding stopped reqs + # to _cached_reqs_data will cause a memory leak. + if req_data.req_id not in self.finished_req_ids: + self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running engine_core_outputs = EngineCoreOutputs( @@ -810,15 +913,24 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> None: + def _free_request(self, + request: Request, + skip_free_blocks: bool = False) -> None: assert request.is_finished() - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) + if not skip_free_blocks: + self._free_blocks(request) + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e33d1a1e5dcd..ade51fff5a22 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -10,13 +10,13 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import KVTransferParams, SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors # These are possible values of RequestOutput.finish_reason, # so form part of the external API. 
-FINISH_REASON_STRINGS = ("stop", "length", "abort") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "remote_decode") class FinishReason(enum.IntEnum): @@ -28,11 +28,13 @@ class FinishReason(enum.IntEnum): stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached abort - aborted for another reason + remote_decode - request will be processed as a remote_decode """ STOP = 0 LENGTH = 1 ABORT = 2 + REMOTE_DECODE = 3 def __str__(self): return FINISH_REASON_STRINGS[self.value] @@ -105,6 +107,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + kv_transfer_params: Optional[KVTransferParams] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index f76c44cb8bca..1d98f15ebde3 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -6,7 +6,7 @@ from typing import Optional, Union from vllm.outputs import CompletionOutput, RequestOutput -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason @@ -146,6 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: KVTransferParams, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -167,13 +168,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: KVTransferParams, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -189,6 +192,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, ) def _new_completion_output( @@ -301,22 +305,22 @@ def process_outputs( 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. ****************** NOTE FOR DEVELOPERS ****************** vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from within the loop below. 
- + ********************************************************** """ @@ -337,6 +341,7 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason + kv_transfer_params = engine_core_output.kv_transfer_params req_state.is_prefilling = False @@ -352,7 +357,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..e8ce0df5ed8d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -100,12 +100,16 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( - req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, -) + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fde366d61c7d..42a787dff4e6 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -61,6 +61,15 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # Disaggregated serving related + self.do_remote_decode = ( + sampling_params.kv_transfer_params is not None + and sampling_params.kv_transfer_params.do_remote_decode) + self.do_remote_prefill = ( + sampling_params.kv_transfer_params is not None + and sampling_params.kv_transfer_params.do_remote_prefill) + self.kv_transfer_params = sampling_params.kv_transfer_params + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: @@ -150,6 +159,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered @@ -158,6 +168,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_REMOTE_DECODE = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: @@ -178,4 +189,5 @@ def get_finished_reason( RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, + RequestStatus.FINISHED_REMOTE_DECODE: FinishReason.REMOTE_DECODE } diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 81508c2e069b..07097d7da68f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,6 +10,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn 
import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -39,11 +40,9 @@ def __init__( self.hidden_size = vllm_config.model_config.get_hidden_size() - # TODO: make eagle3 compatible with cudagraph - self.use_cuda_graph = self.method != 'eagle3' and \ - (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) self.cudagraph_batch_sizes = list( reversed( @@ -90,6 +89,12 @@ def propose( batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] self.input_ids[:num_tokens - 1] = target_token_ids[1:] @@ -126,12 +131,7 @@ def propose( # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions - if self.method == 'eagle': - self.hidden_states[:num_tokens] = target_hidden_states - hidden_states = self.hidden_states - else: - # TODO: make eagle3 compatible with cuda graph - hidden_states = target_hidden_states + self.hidden_states[:num_tokens] = target_hidden_states with set_forward_context(attn_metadata, self.vllm_config, @@ -139,7 +139,7 @@ def propose( last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], - hidden_states=hidden_states[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -209,10 +209,7 @@ def propose( self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions - if self.method == 'eagle': - # TODO: make eagle3 compatible with cudagraph. - self.hidden_states[:batch_size] = hidden_states - hidden_states = self.hidden_states + self.hidden_states[:batch_size] = hidden_states # Run the model. 
with set_forward_context(attn_metadata, @@ -221,7 +218,7 @@ def propose( last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:input_batch_size], positions=self.positions[:input_batch_size], - hidden_states=hidden_states[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], @@ -314,12 +311,11 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - if self.method == 'eagle': - self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=self.hidden_states[:num_tokens], - ) + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97d8c91b4659..3a8dae04ee0a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import gc import time import weakref @@ -16,8 +17,9 @@ get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_pp_group, graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model @@ -1017,20 +1019,55 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) + + def maybe_setup_kv_connector(): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + def maybe_wait_for_save(): + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + kv_connector.wait_for_save() + + def maybe_get_finished() -> tuple[set[str], set[str]]: + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + return kv_connector.get_finished() + return set(), set() self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + maybe_setup_kv_connector() + maybe_wait_for_save() + finished_sending, finished_recving = maybe_get_finished() + # Return empty ModelRunnerOutput if there's no work to do. 
- return EMPTY_MODEL_RUNNER_OUTPUT + output = EMPTY_MODEL_RUNNER_OUTPUT + + if len(finished_sending) > 0 or len(finished_recving) > 0: + output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output # Prepare the decoder inputs. + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. @@ -1102,17 +1139,22 @@ def execute_model( with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - output = self.model( + maybe_setup_kv_connector() + + model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + maybe_wait_for_save() + finished_sending, finished_recving = maybe_get_finished() + if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = output + hidden_states, aux_hidden_states = model_output else: - hidden_states = output + hidden_states = model_output if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. @@ -1291,6 +1333,8 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) def generate_draft_token_ids( @@ -1743,6 +1787,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4df192a8727c..4864163b0de2 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, @@ -172,10 +173,17 @@ def execute_model( if (model_input.attn_metadata is not None and model_input.attn_metadata.prefill_metadata is None and model_input.attn_metadata.decode_metadata.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[ - model_input.virtual_engine][graph_batch_size] + if model_input.inputs_embeds is None: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) else: model_executable = self.model @@ -189,6 +197,7 @@ def execute_model( model_input.virtual_engine): hidden_or_intermediate_states = model_executable( 
input_ids=model_input.input_tokens, + inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 73e0eff9a8b7..85814e9af9e3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,7 +35,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, + get_sampler) from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase): additional fields. """ input_tokens: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -194,6 +198,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore + self.inputs_embeds = None # type: ignore self.input_positions[0].clear() # type: ignore self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore @@ -221,6 +226,7 @@ def __init__( # Input tokens and positions. 
input_tokens: Optional[List[List[int]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, input_positions: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, @@ -282,6 +288,8 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_tokens[seq_id].clear() + self.inputs_embeds = inputs_embeds + if input_positions: self.input_positions = input_positions else: @@ -356,6 +364,7 @@ def __init__( else: self.input_tokens = input_tokens or [] + self.inputs_embeds = inputs_embeds self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -401,6 +410,26 @@ def __post_init__(self): self.lora_index_mapping = [] self.lora_prompt_mapping = [] + def __repr__(self) -> str: + return (f"InterDataForSeqGroup(" + f"request_id={self.request_id}, " + f"seq_ids={self.seq_ids}, " + f"is_prompt={self.is_prompt}, " + f"block_tables={self.block_tables}, " + f"computed_block_nums={self.computed_block_nums}, " + f"n_seqs={self.n_seqs}, " + f"input_tokens={self.input_tokens}, " + f"inputs_embeds.shape=" + f"{getattr(self.inputs_embeds, 'shape', None)}, " + f"input_positions={self.input_positions}, " + f"token_types={self.token_types}, " + f"mrope_input_positions={self.mrope_input_positions}, " + f"seq_lens={self.seq_lens}, " + f"orig_seq_lens={self.orig_seq_lens}, " + f"query_lens={self.query_lens}, " + f"context_lens={self.context_lens}, " + f"multi_modal_kwargs={self.multi_modal_kwargs}") + def gen_inter_data_builder(self, num_seqs: int): return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( request_id="", @@ -511,13 +540,21 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_data.get_num_computed_tokens() # Compute tokens. - tokens = seq_data.get_token_ids()[context_len:seq_len] + if seq_data.prompt_embeds is None: + tokens = seq_data.get_token_ids()[context_len:seq_len] + prompt_embeds = None + else: + tokens = [0] * (seq_len - context_len) + prompt_embeds = seq_data.get_token_embeddings( + )[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.inputs_embeds = prompt_embeds inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.token_types[seq_idx].extend( token_types if token_types else []) @@ -822,15 +859,29 @@ def build(self) -> ModelInputForGPU: create on-device tensors. """ # Combine and flatten intermediate data. 
- input_tokens = [] - token_types = [] + input_tokens = list[int]() + inputs_embeds_lst = list[torch.Tensor]() + token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) + if inter_data.inputs_embeds is not None: + inputs_embeds_lst.append( + inter_data.inputs_embeds.to( + dtype=self.runner.model_config.dtype, + device=self.runner.device)) + inputs_embeds: Optional[torch.Tensor] + if len(inputs_embeds_lst) == 0: + inputs_embeds = None + else: + inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) + assert len(inputs_embeds) == len(input_tokens) - if not input_tokens: + if not input_tokens and inputs_embeds is None: # This may happen when all prefill requests hit # prefix caching and there is no decode request. return self.model_input_cls() @@ -980,6 +1031,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, + inputs_embeds=inputs_embeds, input_positions=input_positions_tensor, token_types=token_types_tensor, attn_metadata=attn_metadata, @@ -1029,7 +1081,8 @@ def __init__( self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size - self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + # + self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) ] self.graph_memory_pool: Optional[Tuple[ @@ -1466,6 +1519,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=self.device) + inputs_embeds = torch.zeros( + (max_batch_size, self.model_config.get_hidden_size()), + dtype=self.model_config.dtype, + device=self.device) if self.model_config.uses_mrope: input_positions = torch.tile(input_positions, (3, 1)).cuda(device=self.device) @@ -1503,15 +1560,22 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # memory usage of CUDA graph. for virtual_engine in range( self.parallel_config.pipeline_parallel_size): - # Only rank 0 should print progress bar during capture - cudagraph_capture_sizes = (tqdm( - self.vllm_config.compilation_config. + # We need to not only iterate over batch sizes, but also whether + # to use inputs_embeds or not, hence we use the cartesian + # product. + cudagraph_capture_sizes = self.vllm_config.compilation_config\ + .cudagraph_capture_sizes + cudagraph_inputs_embeds = (True, False) + compilation_cases = itertools.product( cudagraph_capture_sizes, - desc="Capturing CUDA graph shapes", - ) if get_tensor_model_parallel_rank() == 0 else - self.vllm_config.compilation_config. 
- cudagraph_capture_sizes) - for batch_size in cudagraph_capture_sizes: + cudagraph_inputs_embeds, + ) + # Only rank 0 should print progress bar during capture + if get_tensor_model_parallel_rank() == 0: + compilation_cases = tqdm( + list(compilation_cases), + desc="Capturing CUDA graph shapes") + for batch_size, use_inputs_embeds in compilation_cases: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, @@ -1542,6 +1606,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: capture_inputs = { "input_ids": input_tokens[:batch_size], + "inputs_embeds": + inputs_embeds[:batch_size] + if use_inputs_embeds else None, "positions": input_positions[..., :batch_size], "intermediate_inputs": @@ -1578,8 +1645,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: virtual_engine): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][batch_size] = ( - graph_runner) + self.graph_runners[virtual_engine][( + batch_size, use_inputs_embeds)] = graph_runner if self.lora_config: self._remove_dummy_loras() @@ -1711,8 +1778,9 @@ def execute_model( if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] + use_inputs_embeds = model_input.inputs_embeds is not None + model_executable = self.graph_runners[virtual_engine][( + graph_batch_size, use_inputs_embeds)] if previous_hidden_states is not None: previous_hidden_states = torch.cat([ previous_hidden_states, @@ -1763,6 +1831,7 @@ def execute_model( self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, @@ -1817,6 +1886,11 @@ def execute_model( model_input.async_callback() # Sample the next token. 
+ assert isinstance(self.sampler, Sampler) + orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor + if model_input.inputs_embeds is not None: + self.sampler.include_gpu_probs_tensor = True + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, @@ -1838,6 +1912,18 @@ def execute_model( output.model_forward_time = (orig_model_forward_time + model_forward_time) + if model_input.inputs_embeds is not None: + self.sampler.include_gpu_probs_tensor = \ + orig_include_gpu_probs_tensor + if output.sampled_token_ids is not None: + output.sampled_token_embeds = self.model.get_input_embeddings( + output.sampled_token_ids.squeeze(1)) + + for token_embed, sequence_group_output in zip( + output.sampled_token_embeds, output.outputs): + assert len(sequence_group_output.samples) == 1 + sequence_group_output.samples[0].output_embed = token_embed + if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1931,6 +2017,7 @@ def graph(self): def capture( self, input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor], positions: torch.Tensor, intermediate_inputs: Optional[IntermediateTensors], kv_caches: List[torch.Tensor], @@ -1947,6 +2034,7 @@ def capture( for _ in range(_NUM_WARMUP_ITERS): self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, positions=positions, intermediate_tensors=intermediate_inputs, **kwargs, @@ -1959,6 +2047,9 @@ def capture( with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_or_intermediate_states = self.model( input_ids=input_ids, + **({ + "inputs_embeds": inputs_embeds, + } if inputs_embeds is not None else {}), positions=positions, intermediate_tensors=intermediate_inputs, **kwargs, @@ -1986,6 +2077,9 @@ def capture( self.input_buffers = { "input_ids": input_ids, + **({ + "inputs_embeds": inputs_embeds, + } if inputs_embeds is not None else {}), "positions": positions, "kv_caches": @@ -2006,6 +2100,7 @@ def capture( def forward( self, input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor], positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], **kwargs, @@ -2020,6 +2115,9 @@ def forward( # so the shape is not padded, we need to copy partial only self.input_buffers["positions"][:positions.shape[0]].copy_( positions, non_blocking=True) + if inputs_embeds is not None: + self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_( + inputs_embeds, non_blocking=True) if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_( diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index cbd5e2060cad..fdb7353f2f9c 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -84,10 +84,17 @@ def execute_model( # explore how to leverage it. 
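+        # CUDA graphs are now captured per (batch_size, use_inputs_embeds)
+        # pair, so the lookup below must match how this batch provides its
+        # inputs.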
if (prefill_meta is None and decode_meta is not None and decode_meta.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] + if model_input.inputs_embeds is None: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) else: model_executable = self.model
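
The model_runner.py hunks above key each captured CUDA graph by the pair (batch_size, use_inputs_embeds), enumerate the capture cases with itertools.product, and look the runner up with the same key in both execute_model paths; keying on the extra boolean doubles the number of captured graphs per virtual engine, trading some capture time and graph memory for the ability to replay graphs for prompt-embedding batches. The standalone sketch below illustrates only that dispatch scheme; the names FakeGraphRunner, capture_all and select_runner are invented for the example and do not exist in vLLM.

# Standalone sketch (not part of the patch): all names below are invented for
# illustration only.
import itertools
from typing import Dict, Optional, Sequence, Tuple


class FakeGraphRunner:
    """Stand-in for a captured CUDA graph; it only remembers its capture key."""

    def __init__(self, batch_size: int, use_inputs_embeds: bool) -> None:
        self.batch_size = batch_size
        self.use_inputs_embeds = use_inputs_embeds

    def __repr__(self) -> str:
        return (f"FakeGraphRunner(batch_size={self.batch_size}, "
                f"use_inputs_embeds={self.use_inputs_embeds})")


def capture_all(
        capture_sizes: Sequence[int]
) -> Dict[Tuple[int, bool], FakeGraphRunner]:
    # Mirrors the cartesian product in capture_model(): every batch size is
    # captured twice, once for token-id inputs and once for prompt embeddings.
    runners: Dict[Tuple[int, bool], FakeGraphRunner] = {}
    for batch_size, use_inputs_embeds in itertools.product(
            capture_sizes, (True, False)):
        runners[(batch_size, use_inputs_embeds)] = FakeGraphRunner(
            batch_size, use_inputs_embeds)
    return runners


def select_runner(runners: Dict[Tuple[int, bool], FakeGraphRunner],
                  graph_batch_size: int,
                  inputs_embeds: Optional[object]) -> FakeGraphRunner:
    # Mirrors the lookup in execute_model(): the presence of inputs_embeds
    # decides which of the two captured graphs is replayed.
    return runners[(graph_batch_size, inputs_embeds is not None)]


if __name__ == "__main__":
    runners = capture_all([1, 2, 4, 8])
    print(select_runner(runners, 4, None))        # token-id graph
    print(select_runner(runners, 4, object()))    # prompt-embedding graph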