From 4730522a0d12b4e490bd192e248f5c868454acc7 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 17 Apr 2025 13:56:09 -0700 Subject: [PATCH 001/119] [Update] LMcache connector v1 implementation Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/factory.py | 5 + .../kv_connector/v1/lmcache_connector.py | 131 ++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 665ea2f5ba01..6532c101a4f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -100,3 +100,8 @@ def create_connector_v1( "SharedStorageConnector", "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 000000000000..e07f185f0dd8 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self._lmcache_engine.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self._lmcache_engine.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. 
This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, + **kwargs) + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self._lmcache_engine.wait_for_save() + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._lmcache_engine.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self._lmcache_engine.update_state_after_alloc(request, + num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + return self._lmcache_engine.build_connector_meta(scheduler_output) From 4162650e50342883192c6b9e075a5abc984a878b Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 17 Apr 2025 13:58:13 -0700 Subject: [PATCH 002/119] [Add] examples for disaggregated prefill Signed-off-by: ApostaC --- .../configs/lmcache-decoder-config.yaml | 13 ++ .../configs/lmcache-prefiller-config.yaml | 13 ++ examples/other/LMCache/disagg-example.sh | 67 ++++++ examples/other/LMCache/disagg_proxy_server.py | 193 ++++++++++++++++++ 4 files changed, 286 insertions(+) create mode 100644 examples/other/LMCache/configs/lmcache-decoder-config.yaml create mode 100644 examples/other/LMCache/configs/lmcache-prefiller-config.yaml create mode 100644 examples/other/LMCache/disagg-example.sh create mode 100644 examples/other/LMCache/disagg_proxy_server.py diff --git a/examples/other/LMCache/configs/lmcache-decoder-config.yaml b/examples/other/LMCache/configs/lmcache-decoder-config.yaml new file mode 100644 index 000000000000..c3f5a0ae69c0 --- /dev/null +++ b/examples/other/LMCache/configs/lmcache-decoder-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "receiver" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/other/LMCache/configs/lmcache-prefiller-config.yaml b/examples/other/LMCache/configs/lmcache-prefiller-config.yaml new file mode 100644 index 000000000000..8b0e82958a64 --- /dev/null +++ b/examples/other/LMCache/configs/lmcache-prefiller-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "sender" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh new file mode 100644 index 000000000000..7d81c0570e22 --- /dev/null +++ b/examples/other/LMCache/disagg-example.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=0 \ + vllm serve $MODEL \ + --port 8100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + +elif [[ $1 == "decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=1 \ + vllm serve $MODEL \ + --port 8200 \ + 
--disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + +elif [[ $1 == "proxy" ]]; then + # Proxy listens on port 9000 + python3 $SCRIPT_DIR/disagg_proxy_server.py \ + --host localhost \ + --port 9000 \ + --prefiller-host localhost \ + --prefiller-port 8100 \ + --decoder-host localhost \ + --decoder-port 8200 + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode, or proxy" + exit 1 +fi diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py new file mode 100644 index 000000000000..8db93bc8931b --- /dev/null +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. 
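+    On the prefill instance the request is capped at max_tokens=1 (and
+    max_completion_tokens=1 when present) so that it only produces the KV
+    cache rather than a full completion.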
+ """ + req_data = req_data.copy() + req_data['max_tokens'] = 1 + if 'max_completion_tokens' in req_data: + req_data['max_completion_tokens'] = 1 + + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. + """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) From 3ccd34c9394086304078a9c1ddfefe5c70340d6b Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 17 Apr 2025 17:26:48 -0700 Subject: [PATCH 003/119] [add] extra information about evns Signed-off-by: ApostaC --- examples/other/LMCache/disagg-example.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh index 7d81c0570e22..43b0b59c88f8 100644 --- a/examples/other/LMCache/disagg-example.sh +++ b/examples/other/LMCache/disagg-example.sh @@ -33,6 +33,10 @@ if [[ $1 == "prefiller" ]]; then --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": 
"producer1"}}' + # Potential Env vars and cmdline options + # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG + # --enforce-eager -- Enforce eager mode + elif [[ $1 == "decoder" ]]; then # Decoder listens on port 8200 decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml @@ -50,6 +54,10 @@ elif [[ $1 == "decoder" ]]; then --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + # Potential Env vars and cmdline options + # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG + # --enforce-eager -- Enforce eager mode + elif [[ $1 == "proxy" ]]; then # Proxy listens on port 9000 python3 $SCRIPT_DIR/disagg_proxy_server.py \ From 161010c3847204d441bcc6ec91709d324071e954 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 18 Apr 2025 15:38:00 -0400 Subject: [PATCH 004/119] Initial stubs for P/D scheduling changes Signed-off-by: Tyler Michael Smith --- .../kv_transfer/kv_connector/v1/base.py | 6 +++- .../v1/shared_storage_connector.py | 8 +++-- vllm/v1/core/sched/scheduler.py | 31 ++++++++++++++++++- vllm/v1/request.py | 3 ++ vllm/v1/worker/gpu_model_runner.py | 4 +++ 5 files changed, 47 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..a335f43d3ad3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -196,7 +196,9 @@ def update_state_after_alloc(self, request: "Request", @abstractmethod def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput, + sending_KV_req_ids: set[str], + waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -205,5 +207,7 @@ def build_connector_meta( Args: scheduler_output (SchedulerOutput): the scheduler output object. + sending_KV_req_ids (set[str]): Request IDs to send + waiting_KV_req_ids (set[str]): Request IDs to receive """ pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 1d2040784e6c..fb1f1e24da0a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -271,9 +271,9 @@ def update_state_after_alloc(self, request: "Request", self._requests_need_load[request.request_id] = request def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput, + sending_KV_req_ids: set[str], + waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: """Build the connector metadata for this step. This function should NOT modify any fields in the scheduler_output. @@ -281,6 +281,8 @@ def build_connector_meta( Args: scheduler_output (SchedulerOutput): the scheduler output object. + sending_KV_req_ids (set[str]): Request IDs to send + waiting_KV_req_ids (set[str]): Request IDs to receive """ meta = SharedStorageConnectorMetadata() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7e658d134cf7..75f42449c267 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -98,6 +98,10 @@ def __init__( # This is flushed at the end of each scheduling step. 
self.finished_req_ids: set[str] = set() + # Requests in states for tracking KV transfers for P/D disagg + self.sending_KV_req_ids: set[str] = set() + self.waiting_KV_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> CachedRequestData @@ -167,6 +171,21 @@ def schedule(self) -> SchedulerOutput: # For logging. scheduled_timestamp = time.monotonic() + # Check for new remote decode requests for P/D + if self.connector is not None: + self.waiting_KV_req_ids.update( + self.connector.receive_remote_decode_requests()) + + # Check if any P/D requests have finished sending or receiving + for req_id in list(self.sending_KV_req_ids): + if self.connector.done_sending_remote_decode_request(req_id): + self.sending_KV_req_ids.remove(req_id) + self.finished_req_ids.add(req_id) + for req_id in list(self.waiting_KV_req_ids): + if self.connector.done_waiting_remote_decode_request(req_id): + self.waiting_KV_req_ids.remove(req_id) + self.waiting.append(self.requests[req_id]) + # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: @@ -479,7 +498,9 @@ def schedule(self) -> SchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) + meta = self.connector.build_connector_meta(scheduler_output, + self.sending_KV_req_ids, + self.waiting_KV_req_ids) scheduler_output.kv_connector_metadata = meta # Advance the number of computed tokens for the request AFTER @@ -682,6 +703,7 @@ def update_from_output( # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. + # TODO: What if we detect we're done here when doing P/D disagg? stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) @@ -718,6 +740,13 @@ def update_from_output( # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors + if self.connector is not None and request.do_remote_decode: + stopped = True + + self.sending_KV_req_ids.add(req_id) + self.connector.send_remote_decode_request( + self.kv_cache_manager.req_to_blocks[req_id]) + self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde5..7c7803560bc8 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -61,6 +61,9 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D disagg related + self.do_remote_decode = False + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ac0701c45986..00026731d516 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -991,6 +991,10 @@ def execute_model( ) -> Union[ModelRunnerOutput, torch.Tensor]: # Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): + # Background KV cache transfers can happen here, + # since kv_connector_metadata has the req_ids to send/receive. + # Not sure I like doing it here since this does not have to do + # with model execution but this way we don't do a separate rpc. 
get_kv_transfer_group().bind_connector_metadata( scheduler_output.kv_connector_metadata) From 1f708e96fd5ef64ee207a81961d7283f09d48131 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 19 Apr 2025 14:26:00 -0400 Subject: [PATCH 005/119] Updates Signed-off-by: Tyler Michael Smith --- .../kv_transfer/kv_connector/v1/base.py | 31 +++++++------ .../kv_connector/v1/lmcache_connector.py | 21 ++++++--- .../v1/shared_storage_connector.py | 41 +++++++++-------- vllm/forward_context.py | 21 --------- vllm/v1/core/sched/output.py | 5 ++- vllm/v1/core/sched/scheduler.py | 44 +++++++++++++------ vllm/v1/request.py | 1 + vllm/v1/worker/gpu_model_runner.py | 28 ++++++++---- 8 files changed, 108 insertions(+), 84 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index a335f43d3ad3..fc67d118070e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -6,7 +6,7 @@ The class provides the following primitives: Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. - get_num_new_matched_tokens() - get number of new tokens + get_num_new_matched_tokens() - get number of new tokens that exist in the remote KV cache update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. @@ -70,7 +70,7 @@ def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. - This function should be called by the model runner every time + This function should be called by the model runner every time before the model execution. The metadata will be used for runtime KV cache loading and saving. @@ -82,7 +82,7 @@ def bind_connector_metadata( def clear_connector_metadata(self) -> None: """Clear the connector metadata. - This function should be called by the model runner every time + This function should be called by the model runner every time after the model execution. """ self._connector_metadata = KVConnectorMetadata() @@ -114,9 +114,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -126,7 +126,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -138,13 +138,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: """ - Start saving a layer of KV cache from vLLM's paged buffer + Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. 
@@ -174,14 +174,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ pass @@ -196,9 +196,7 @@ def update_state_after_alloc(self, request: "Request", @abstractmethod def build_connector_meta( - self, scheduler_output: SchedulerOutput, - sending_KV_req_ids: set[str], - waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -211,3 +209,10 @@ def build_connector_meta( waiting_KV_req_ids (set[str]): Request IDs to receive """ pass + + # These return true for now since they are not async + def is_request_done_sending(self, req_id: str) -> bool: + raise NotImplementedError + + def is_request_done_receiving(self, req_id: str) -> bool: + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..3b64c14361a4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -39,9 +39,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ self._lmcache_engine.start_load_kv(forward_context, **kwargs) @@ -50,7 +50,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -61,13 +61,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: """ - Start saving the a layer of KV cache from vLLM's paged buffer + Start saving the a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -96,14 +96,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( @@ -129,3 +129,10 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. 
""" return self._lmcache_engine.build_connector_meta(scheduler_output) + + # These return true for now since they are not async + def is_request_done_sending(self, req_id: str) -> bool: + return True + + def is_request_done_receiving(self, req_id: str) -> bool: + return True diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index fb1f1e24da0a..9e4ce253b618 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -85,7 +85,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: - """Start loading the KV cache from the connector buffer to vLLM's + """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. Args: @@ -93,7 +93,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. """ attn_metadata = forward_context.attn_metadata @@ -106,13 +106,13 @@ def inject_kv_into_layer( """Inject the KV cache into the layer. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not using MLA, [num_pages, page_size, xxx] otherwise. src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape + slot_mapping (torch.Tensor): the slot mapping. In shape [num_tokens]. """ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape @@ -168,8 +168,8 @@ def inject_kv_into_layer( def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's - paged buffer. - + paged buffer. + This interface will be useful for layer-by-layer pipelining. Args: @@ -179,12 +179,12 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: - """Start saving the KV cache of the layer from vLLM's paged buffer + """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -229,14 +229,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. 
""" @@ -271,9 +271,7 @@ def update_state_after_alloc(self, request: "Request", self._requests_need_load[request.request_id] = request def build_connector_meta( - self, scheduler_output: SchedulerOutput, - sending_KV_req_ids: set[str], - waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: """Build the connector metadata for this step. This function should NOT modify any fields in the scheduler_output. @@ -281,8 +279,6 @@ def build_connector_meta( Args: scheduler_output (SchedulerOutput): the scheduler output object. - sending_KV_req_ids (set[str]): Request IDs to send - waiting_KV_req_ids (set[str]): Request IDs to receive """ meta = SharedStorageConnectorMetadata() @@ -333,6 +329,13 @@ def build_connector_meta( self._requests_need_load.clear() return meta + # These return true for now since they are not async + def is_request_done_sending(self, req_id: str) -> bool: + return True + + def is_request_done_receiving(self, req_id: str) -> bool: + return True + # ============================== # Helper functions # ============================== @@ -355,7 +358,7 @@ def _generate_foldername_debug( input_ids: torch.Tensor, create_folder=False, ) -> str: - """Generate a folder name based on the hash of the bytes of the input + """Generate a folder name based on the hash of the bytes of the input ids. """ input_ids_bytes = input_ids.numpy().tobytes() @@ -370,7 +373,7 @@ def _generate_filename_debug( layer_name: str, input_ids: torch.Tensor, ) -> str: - """Generate a file name based on the layer name and the hash + """Generate a file name based on the layer name and the hash of the bytes of the input ids. """ foldername = self._generate_foldername_debug(input_ids, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 06790d8ee2f8..5b030c126b3f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,10 +11,6 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -103,16 +99,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. - trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -149,11 +135,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. 
- if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 1d3f1f41f8fb..6263743d9710 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -126,3 +126,6 @@ class SchedulerOutput: # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None + sending_KV_req_ids: set[str] = field(default_factory=set) + receiving_KV_req_ids: set[str] = field(default_factory=set) + new_KV_requests_to_send: list[NewRequestData] = field(default_factory=list) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e86cc469828a..d3e562594aa1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -99,8 +99,9 @@ def __init__( self.finished_req_ids: set[str] = set() # Requests in states for tracking KV transfers for P/D disagg + self.waiting_to_send_KV_req_ids: set[str] = set() self.sending_KV_req_ids: set[str] = set() - self.waiting_KV_req_ids: set[str] = set() + self.receiving_KV_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. @@ -172,18 +173,25 @@ def schedule(self) -> SchedulerOutput: scheduled_timestamp = time.monotonic() # Check for new remote decode requests for P/D + new_KV_requests_to_send: list[Request] = [] if self.connector is not None: - self.waiting_KV_req_ids.update( - self.connector.receive_remote_decode_requests()) + # TODO: Receive request over ZMQ + # self.receiving_KV_req_ids.update( + # self.connector.receive_remote_decode_requests()) # Check if any P/D requests have finished sending or receiving + for req_id in list(self.waiting_to_send_KV_req_ids): + self.sending_KV_req_ids.add(req_id) + self.waiting_to_send_KV_req_ids.remove(req_id) + new_KV_requests_to_send.append(self.requests[req_id]) + for req_id in list(self.sending_KV_req_ids): - if self.connector.done_sending_remote_decode_request(req_id): + if self.connector.is_request_done_sending(req_id): self.sending_KV_req_ids.remove(req_id) self.finished_req_ids.add(req_id) - for req_id in list(self.waiting_KV_req_ids): - if self.connector.done_waiting_remote_decode_request(req_id): - self.waiting_KV_req_ids.remove(req_id) + for req_id in list(self.receiving_KV_req_ids): + if self.connector.is_request_done_receiving(req_id): + self.receiving_KV_req_ids.remove(req_id) self.waiting.append(self.requests[req_id]) # First, schedule the RUNNING requests. @@ -498,11 +506,19 @@ def schedule(self) -> SchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. 
Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output, - self.sending_KV_req_ids, - self.waiting_KV_req_ids) + meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # TODO: encapsulate these in the KV connector metadata + scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids + scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids + new_KV_to_send_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_block_ids[req.request_id]) + for req in new_KV_requests_to_send + ] + scheduler_output.new_KV_requests_to_send = new_KV_to_send_reqs_data + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. The scheduler_output of the current step has to include the @@ -742,10 +758,10 @@ def update_from_output( if self.connector is not None and request.do_remote_decode: stopped = True - - self.sending_KV_req_ids.add(req_id) - self.connector.send_remote_decode_request( - self.kv_cache_manager.req_to_blocks[req_id]) + self.waiting_to_send_KV_req_ids.add(req_id) + # TODO: Add ZMQ request + #self.connector.send_remote_decode_request( + # self.kv_cache_manager.req_to_blocks[req_id]) self.scheduled_req_ids.remove(req_id) if not stopped: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7c7803560bc8..60b4ee739fec 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -154,6 +154,7 @@ class RequestStatus(enum.IntEnum): WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() RUNNING = enum.auto() + SENDING_KV = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered # as a finished status. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3a774f8fc7a2..c780bbe40934 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,8 +15,9 @@ from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_pp_group, graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -998,14 +999,22 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - # Background KV cache transfers can happen here, - # since kv_connector_metadata has the req_ids to send/receive. - # Not sure I like doing it here since this does not have to do - # with model execution but this way we don't do a separate rpc. - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) + + def maybe_setup_kv_connector(): + # Update KVConnector with the KVConnector metadata forward(). 
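+            # Binding must happen before the model's forward pass so the
+            # worker knows which requests need KV loads/saves this step.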
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + if get_forward_context().attn_metadata is not None: + kv_connector.start_load_kv(get_forward_context()) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: @@ -1078,6 +1087,7 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): + maybe_setup_kv_connector() hidden_states = self.model( input_ids=input_ids, positions=positions, From 038f2f81288d7e45f787afdad9f465699211a151 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Sun, 20 Apr 2025 13:15:12 -0400 Subject: [PATCH 006/119] Rs branch (#3) * updated Signed-off-by: rshaw@neuralmagic.com --- .../openai_completion_client.py | 15 ++-- .../openai_completion_client2.py | 39 ++++++++++ examples/other/LMCache/disagg-example.sh | 6 +- examples/other/LMCache/disagg_proxy_server.py | 8 +-- vllm/entrypoints/openai/protocol.py | 20 +++++- vllm/outputs.py | 6 +- vllm/sampling_params.py | 32 +++++++++ vllm/v1/core/sched/output.py | 2 +- vllm/v1/core/sched/scheduler.py | 72 +++++++++++-------- vllm/v1/engine/__init__.py | 7 +- vllm/v1/engine/output_processor.py | 12 +++- vllm/v1/request.py | 9 ++- vllm/v1/worker/gpu_model_runner.py | 6 ++ 13 files changed, 178 insertions(+), 56 deletions(-) create mode 100644 examples/online_serving/openai_completion_client2.py diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 6ab7619bff19..f1e6175268de 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -4,7 +4,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:9000/v1" def main(): @@ -14,18 +14,17 @@ def main(): base_url=openai_api_base, ) - models = client.models.list() - model = models.data[0].id + # models = client.models.list() + # model = models.data[0].id # Completion API stream = False completion = client.completions.create( - model=model, - prompt="A robot may not injure a human being", + model="meta-llama/Llama-3.1-8B-Instruct", + prompt= + "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. 
This means", # noqa: E501 echo=False, - n=2, - stream=stream, - logprobs=3) + stream=stream) print("-" * 50) print("Completion results:") diff --git a/examples/online_serving/openai_completion_client2.py b/examples/online_serving/openai_completion_client2.py new file mode 100644 index 000000000000..fb6d0d120b5e --- /dev/null +++ b/examples/online_serving/openai_completion_client2.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8100/v1" + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + # models = client.models.list() + # model = models.data[0].id + + # Completion API + stream = False + completion = client.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + prompt="The quick brown job jumped", + echo=False, + stream=stream) + + print("-" * 50) + print("Completion results:") + if stream: + for c in completion: + print(c) + else: + print(completion) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh index 43b0b59c88f8..89f1c753e887 100644 --- a/examples/other/LMCache/disagg-example.sh +++ b/examples/other/LMCache/disagg-example.sh @@ -25,10 +25,9 @@ if [[ $1 == "prefiller" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=0 \ + CUDA_VISIBLE_DEVICES=6 \ vllm serve $MODEL \ --port 8100 \ - --disable-log-requests \ --enforce-eager \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' @@ -46,10 +45,9 @@ elif [[ $1 == "decoder" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=1 \ + CUDA_VISIBLE_DEVICES=7 \ vllm serve $MODEL \ --port 8200 \ - --disable-log-requests \ --enforce-eager \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index 8db93bc8931b..2639409a1522 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -88,13 +88,12 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, Send a request to a service using a persistent client. """ req_data = req_data.copy() - req_data['max_tokens'] = 1 - if 'max_completion_tokens' in req_data: - req_data['max_completion_tokens'] = 1 - + req_data['do_remote_decode'] = True + req_data["stream"] = False headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() + return response @@ -104,6 +103,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str, Asynchronously stream the response from a service using a persistent client. 
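+    On the decode instance the request is flagged with do_remote_prefill=True
+    (set below), so it is treated as a remote prefill that reuses the KV cache
+    produced by the prefill instance.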
""" headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + req_data['do_remote_prefill'] = True async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8d2ab29d221e..c56a15af1367 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -17,7 +17,8 @@ from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, - RequestOutputKind, SamplingParams) + KVTransferParams, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid, resolve_obj_by_qualname @@ -807,6 +808,14 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + do_remote_decode: bool = Field( + default=False, + description="KVTransfer parameters used for disaggregated serving.") + + do_remote_prefill: bool = Field( + default=False, + description="KVTransfer parameters used for disaggregated serving.") + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -904,6 +913,11 @@ def to_sampling_params( whitespace_pattern=self.guided_whitespace_pattern, ) + kv_transfer_params = KVTransferParams.from_optional( + do_remote_decode=self.do_remote_decode, + do_remote_prefill=self.do_remote_prefill, + ) + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -932,7 +946,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + kv_transfer_params=kv_transfer_params, + ) @model_validator(mode="before") @classmethod diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823..06206ea3e1ec 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) @@ -103,6 +103,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. 
""" def __init__( @@ -120,6 +121,7 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -133,11 +135,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for completion in self.outputs: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ed99664947..2504b0367b04 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -26,6 +26,33 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 +# TODO(rob): make this per connector +class KVTransferParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + # TODO(rob): we can handle xPyD and direct KV block Xfer + # remote_instance_id: Optional[str] = None + # remote_block_ids: Optional[list[int]] = None + do_remote_decode: bool = False + do_remote_prefill: bool = False + + @staticmethod + def from_optional(do_remote_decode: bool, + do_remote_prefill: bool) -> Optional["KVTransferParams"]: + if do_remote_prefill and do_remote_prefill: + raise ValueError( + "Cannot do both remote prefill and remote decode.") + elif do_remote_decode or do_remote_prefill: + return KVTransferParams( + do_remote_decode=do_remote_decode, + do_remote_prefill=do_remote_prefill, + ) + else: + return None + + # maybe make msgspec? @dataclass class GuidedDecodingParams: @@ -237,6 +264,9 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None + # Fields used for KVTransfer in disaggregated serving. 
+ kv_transfer_params: Optional[KVTransferParams] = None + @staticmethod def from_optional( n: Optional[int] = 1, @@ -268,6 +298,7 @@ def from_optional( guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, + kv_transfer_params: Optional[KVTransferParams] = None, extra_args: Optional[dict[str, Any]] = None, ) -> "SamplingParams": if logit_bias is not None: @@ -310,6 +341,7 @@ def from_optional( guided_decoding=guided_decoding, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, + kv_transfer_params=kv_transfer_params, extra_args=extra_args, ) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6263743d9710..297a2d2a1355 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -128,4 +128,4 @@ class SchedulerOutput: kv_connector_metadata: Optional[KVConnectorMetadata] = None sending_KV_req_ids: set[str] = field(default_factory=set) receiving_KV_req_ids: set[str] = field(default_factory=set) - new_KV_requests_to_send: list[NewRequestData] = field(default_factory=list) + new_KV_req_ids_to_send: list[str] = field(default_factory=list) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d3e562594aa1..000375e6a533 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.sampling_params import KVTransferParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -173,18 +174,8 @@ def schedule(self) -> SchedulerOutput: scheduled_timestamp = time.monotonic() # Check for new remote decode requests for P/D - new_KV_requests_to_send: list[Request] = [] + new_KV_req_ids_to_send: list[str] = [] if self.connector is not None: - # TODO: Receive request over ZMQ - # self.receiving_KV_req_ids.update( - # self.connector.receive_remote_decode_requests()) - - # Check if any P/D requests have finished sending or receiving - for req_id in list(self.waiting_to_send_KV_req_ids): - self.sending_KV_req_ids.add(req_id) - self.waiting_to_send_KV_req_ids.remove(req_id) - new_KV_requests_to_send.append(self.requests[req_id]) - for req_id in list(self.sending_KV_req_ids): if self.connector.is_request_done_sending(req_id): self.sending_KV_req_ids.remove(req_id) @@ -193,6 +184,10 @@ def schedule(self) -> SchedulerOutput: if self.connector.is_request_done_receiving(req_id): self.receiving_KV_req_ids.remove(req_id) self.waiting.append(self.requests[req_id]) + for req_id in list(self.waiting_to_send_KV_req_ids): + self.sending_KV_req_ids.add(req_id) + self.waiting_to_send_KV_req_ids.remove(req_id) + new_KV_req_ids_to_send.append(req_id) # First, schedule the RUNNING requests. req_index = 0 @@ -328,6 +323,19 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + # TODO(rob): we should do this after we allocate the blocks if + # we want to write directly into the BlockTable (like Dynamo). + # TODO(rob): this logic is incorrect if the req was preempted. 
+ if request.do_remote_decode: + assert self.connector is not None + if not self.connector.is_request_done_receiving( + request.request_id): + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + self.receiving_KV_req_ids.add(request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request and ( @@ -506,18 +514,13 @@ def schedule(self) -> SchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) - scheduler_output.kv_connector_metadata = meta - # TODO: encapsulate these in the KV connector metadata scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids - new_KV_to_send_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) - for req in new_KV_requests_to_send - ] - scheduler_output.new_KV_requests_to_send = new_KV_to_send_reqs_data + scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send + + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta # Advance the number of computed tokens for the request AFTER # the request is scheduled. @@ -719,7 +722,6 @@ def update_from_output( # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. - # TODO: What if we detect we're done here when doing P/D disagg? stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) @@ -742,6 +744,22 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: + # Stop request after the first token if doing a remote_decode. + # NOTE(rob): req is not freed (or preempted) in the EngineCore + # until the xfer is done to ensure we do not free the KV blocks. + kv_transfer_params = None + if request.do_remote_decode and not stopped: + stopped = True + request.status = RequestStatus.FINISHED_REMOTE_DECODE + self.waiting_to_send_KV_req_ids.add(req_id) + assert self.connector is not None + # TODO(rob): do this on a per-Connector basis. + # NOTE(rob): this KVTransferParams will be sent to the + # DWorker. From the POV of the DWorker, it should be a + # remote Prefill. + kv_transfer_params = KVTransferParams( + do_remote_prefill=True) + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -751,18 +769,14 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. 
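Putting these pieces together, a rough sketch of the intended prefill-to-decode handoff; the two-engine orchestration below is hypothetical and not part of this patch, and it assumes both engines are already configured with a KV connector:

from vllm import LLM
from vllm.sampling_params import KVTransferParams, SamplingParams

def disaggregated_generate(prefill_llm: LLM, decode_llm: LLM, prompt: str,
                           max_tokens: int) -> str:
    # 1) Prefill instance: the request finishes with
    #    FinishReason.REMOTE_DECODE after its first token, and its output
    #    carries kv_transfer_params for the decode side.
    prefill_sp = SamplingParams(
        max_tokens=1,
        kv_transfer_params=KVTransferParams.from_optional(
            do_remote_decode=True, do_remote_prefill=False))
    prefill_out = prefill_llm.generate([prompt], prefill_sp)[0]
    assert prefill_out.kv_transfer_params is not None

    # 2) Decode instance: from its point of view this is a remote prefill,
    #    so it waits for the KV blocks instead of recomputing the prompt.
    decode_sp = SamplingParams(
        max_tokens=max_tokens,
        kv_transfer_params=prefill_out.kv_transfer_params)
    return decode_llm.generate([prompt], decode_sp)[0].outputs[0].text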
assert not prompt_logprobs_tensors - if self.connector is not None and request.do_remote_decode: - stopped = True - self.waiting_to_send_KV_req_ids.add(req_id) - # TODO: Add ZMQ request - #self.connector.send_remote_decode_request( - # self.kv_cache_manager.req_to_blocks[req_id]) - self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index af4122a51077..ac6228edfc56 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -10,13 +10,13 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import KVTransferParams, SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors # These are possible values of RequestOutput.finish_reason, # so form part of the external API. -FINISH_REASON_STRINGS = ("stop", "length", "abort") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "remote_decode") class FinishReason(enum.IntEnum): @@ -28,11 +28,13 @@ class FinishReason(enum.IntEnum): stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached abort - aborted for another reason + remote_decode - request will be processed as a remote_decode """ STOP = 0 LENGTH = 1 ABORT = 2 + REMOTE_DECODE = 3 def __str__(self): return FINISH_REASON_STRINGS[self.value] @@ -102,6 +104,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + kv_transfer_params: Optional[KVTransferParams] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 21e2a1aee4e2..1de8e8994a86 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -6,7 +6,7 @@ from typing import Optional, Union from vllm.outputs import CompletionOutput, RequestOutput -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason @@ -148,6 +148,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: KVTransferParams, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -169,13 +170,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: KVTransferParams, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -191,6 +194,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, ) def _new_completion_output( @@ -337,6 +341,7 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason + kv_transfer_params = 
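For orientation, the worker-side hooks wired in above slot into one execute_model() step in the following order (a schematic sketch; run_layer stands in for the real attention layer call):

def one_step_with_connector(connector, forward_context, layer_names, run_layer):
    # Before the forward pass: kick off async KV loads into the paged cache.
    connector.start_load_kv(forward_context)
    for name in layer_names:
        # In each attention layer: make sure this layer's KV has arrived ...
        connector.wait_for_layer_load(name)
        kv_layer, attn_metadata = run_layer(name)
        # ... and start saving the newly written KV for this layer.
        connector.save_kv_layer(name, kv_layer, attn_metadata)
    # After the forward pass: block until all saves finish so the paged
    # buffer is not reused before the transfer is complete.
    connector.wait_for_save()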
engine_core_output.kv_transfer_params req_state.is_prefilling = False @@ -352,7 +357,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 60b4ee739fec..11722c4ccc9a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -61,8 +61,10 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # P/D disagg related - self.do_remote_decode = False + # Disaggregated serving related + self.do_remote_decode = ( + False if sampling_params.kv_transfer_params is None else + sampling_params.kv_transfer_params.do_remote_decode) # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) @@ -153,6 +155,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() SENDING_KV = enum.auto() PREEMPTED = enum.auto() @@ -162,6 +165,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_REMOTE_DECODE = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: @@ -182,4 +186,5 @@ def get_finished_reason( RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, + RequestStatus.FINISHED_REMOTE_DECODE: FinishReason.REMOTE_DECODE } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c780bbe40934..e6c61ecb0254 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1016,6 +1016,11 @@ def maybe_setup_kv_connector(): if get_forward_context().attn_metadata is not None: kv_connector.start_load_kv(get_forward_context()) + def maybe_wait_for_save(): + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + kv_connector.wait_for_save() + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOutput if there's no work to do. @@ -1094,6 +1099,7 @@ def maybe_setup_kv_connector(): intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + maybe_wait_for_save() if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
return hidden_states From 5c4fc6f2d77ac189e7829740c68edc4afe0db374 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:50:47 -0400 Subject: [PATCH 007/119] Rs branch (#5) Signed-off-by: rshaw@neuralmagic.com --- .../openai_completion_client.py | 2 +- .../openai_completion_client2.py | 39 ------------------- .../kv_transfer/kv_connector/v1/base.py | 2 +- .../kv_connector/v1/lmcache_connector.py | 4 +- .../v1/shared_storage_connector.py | 2 +- vllm/sampling_params.py | 2 +- vllm/v1/core/sched/scheduler.py | 21 +++++----- vllm/v1/request.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 4 ++ 9 files changed, 23 insertions(+), 57 deletions(-) delete mode 100644 examples/online_serving/openai_completion_client2.py diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index f1e6175268de..7917ac4797b5 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -18,7 +18,7 @@ def main(): # model = models.data[0].id # Completion API - stream = False + stream = True completion = client.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", prompt= diff --git a/examples/online_serving/openai_completion_client2.py b/examples/online_serving/openai_completion_client2.py deleted file mode 100644 index fb6d0d120b5e..000000000000 --- a/examples/online_serving/openai_completion_client2.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from openai import OpenAI - -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8100/v1" - - -def main(): - client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, - ) - - # models = client.models.list() - # model = models.data[0].id - - # Completion API - stream = False - completion = client.completions.create( - model="meta-llama/Llama-3.1-8B-Instruct", - prompt="The quick brown job jumped", - echo=False, - stream=stream) - - print("-" * 50) - print("Completion results:") - if stream: - for c in completion: - print(c) - else: - print(completion) - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index fc67d118070e..810f3b001b1a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -214,5 +214,5 @@ def build_connector_meta( def is_request_done_sending(self, req_id: str) -> bool: raise NotImplementedError - def is_request_done_receiving(self, req_id: str) -> bool: + def is_request_done_receiving(self, request: "Request") -> bool: raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 3b64c14361a4..89d7ffe9ba58 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -134,5 +134,5 @@ def build_connector_meta( def is_request_done_sending(self, req_id: str) -> bool: return True - def is_request_done_receiving(self, req_id: str) -> bool: - return True + def is_request_done_receiving(self, request: "Request") -> bool: + return self._lmcache_engine.is_request_done_receiving(request) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 9e4ce253b618..01037fda285d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -333,7 +333,7 @@ def build_connector_meta( def is_request_done_sending(self, req_id: str) -> bool: return True - def is_request_done_receiving(self, req_id: str) -> bool: + def is_request_done_receiving(self, request: "Request") -> bool: return True # ============================== diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 2504b0367b04..38b84427b05d 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -41,7 +41,7 @@ class KVTransferParams( @staticmethod def from_optional(do_remote_decode: bool, do_remote_prefill: bool) -> Optional["KVTransferParams"]: - if do_remote_prefill and do_remote_prefill: + if do_remote_decode and do_remote_prefill: raise ValueError( "Cannot do both remote prefill and remote decode.") elif do_remote_decode or do_remote_prefill: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 000375e6a533..b9b8d7e45431 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -102,7 +102,6 @@ def __init__( # Requests in states for tracking KV transfers for P/D disagg self.waiting_to_send_KV_req_ids: set[str] = set() self.sending_KV_req_ids: set[str] = set() - self.receiving_KV_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. @@ -180,10 +179,6 @@ def schedule(self) -> SchedulerOutput: if self.connector.is_request_done_sending(req_id): self.sending_KV_req_ids.remove(req_id) self.finished_req_ids.add(req_id) - for req_id in list(self.receiving_KV_req_ids): - if self.connector.is_request_done_receiving(req_id): - self.receiving_KV_req_ids.remove(req_id) - self.waiting.append(self.requests[req_id]) for req_id in list(self.waiting_to_send_KV_req_ids): self.sending_KV_req_ids.add(req_id) self.waiting_to_send_KV_req_ids.remove(req_id) @@ -326,15 +321,19 @@ def schedule(self) -> SchedulerOutput: # TODO(rob): we should do this after we allocate the blocks if # we want to write directly into the BlockTable (like Dynamo). # TODO(rob): this logic is incorrect if the req was preempted. - if request.do_remote_decode: + if request.do_remote_prefill: assert self.connector is not None - if not self.connector.is_request_done_receiving( - request.request_id): + # NOTE(rob): this returning false causes busy waiting + # if there is no other work to do. This is "functional" + # but not ideal. Also, if the transfer fails for any + # reason we will spin in this state. + if not self.connector.is_request_done_receiving(request): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS - self.receiving_KV_req_ids.add(request.request_id) self.waiting.popleft() skipped_waiting_requests.appendleft(request) continue + elif request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + request.status = RequestStatus.WAITING # Check that adding the request still respects the max_loras # constraint. 
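On the busy-waiting note above: the contract is only that is_request_done_receiving() returns immediately. One hypothetical connector-side implementation (illustrative only, not part of this patch) keeps a set that the worker side populates once a transfer lands:

class _PollingRecvTracker:
    def __init__(self) -> None:
        self._done_recving: set[str] = set()

    def mark_recv_done(self, req_id: str) -> None:
        # Called from the worker side once the KV blocks have landed.
        self._done_recving.add(req_id)

    def is_request_done_receiving(self, request) -> bool:
        # Non-blocking: the scheduler polls this every step.
        return request.request_id in self._done_recving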
@@ -516,7 +515,6 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: # TODO: encapsulate these in the KV connector metadata scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids - scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send meta = self.connector.build_connector_meta(scheduler_output) @@ -839,7 +837,8 @@ def _free_request(self, request: Request) -> None: self.finished_req_ids.add(request.request_id) def get_num_unfinished_requests(self) -> int: - return len(self.waiting) + len(self.running) + return len(self.waiting) + len(self.running) + len( + self.waiting_to_send_KV_req_ids) def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 11722c4ccc9a..dc70dea3d65f 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -65,6 +65,9 @@ def __init__( self.do_remote_decode = ( False if sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_decode) + self.do_remote_prefill = ( + False if sampling_params.kv_transfer_params is None else + sampling_params.kv_transfer_params.do_remote_prefill) # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) @@ -157,7 +160,6 @@ class RequestStatus(enum.IntEnum): WAITING_FOR_FSM = enum.auto() WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() - SENDING_KV = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered # as a finished status. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e6c61ecb0254..a0ba2e7a483c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1023,6 +1023,10 @@ def maybe_wait_for_save(): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + maybe_setup_kv_connector() + maybe_wait_for_save() # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT From 1800689648680408b81c11fac5c1d061b4daca64 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 21 Apr 2025 09:11:58 -0400 Subject: [PATCH 008/119] Remove Unneeded Arguments (#7) * updated Signed-off-by: rshaw@neuralmagic.com * stash Signed-off-by: rshaw@neuralmagic.com * cleanup Signed-off-by: rshaw@neuralmagic.com --------- Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 2 -- vllm/v1/core/sched/output.py | 2 -- vllm/v1/core/sched/scheduler.py | 5 +---- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 810f3b001b1a..583cf0595581 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -205,8 +205,6 @@ def build_connector_meta( Args: scheduler_output (SchedulerOutput): the scheduler output object. - sending_KV_req_ids (set[str]): Request IDs to send - waiting_KV_req_ids (set[str]): Request IDs to receive """ pass diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 297a2d2a1355..1daee4a0418c 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -126,6 +126,4 @@ class SchedulerOutput: # KV Cache Connector metadata. 
kv_connector_metadata: Optional[KVConnectorMetadata] = None - sending_KV_req_ids: set[str] = field(default_factory=set) - receiving_KV_req_ids: set[str] = field(default_factory=set) new_KV_req_ids_to_send: list[str] = field(default_factory=list) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b9b8d7e45431..eee43107740a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -506,6 +506,7 @@ def schedule(self) -> SchedulerOutput: free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, + new_KV_req_ids_to_send=new_KV_req_ids_to_send, ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -513,10 +514,6 @@ def schedule(self) -> SchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - # TODO: encapsulate these in the KV connector metadata - scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids - scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send - meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta From 7a1f25f714525c867fbf88401b81bbdd93b75d16 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 21 Apr 2025 13:57:29 -0400 Subject: [PATCH 009/119] Improve disagg-example.sh (#8) - fix spelling - CUDA_VISIBLE_DEVICES should be set externally Signed-off-by: Tyler Michael Smith --- examples/other/LMCache/disagg-example.sh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh index 89f1c753e887..8e52396c5eec 100644 --- a/examples/other/LMCache/disagg-example.sh +++ b/examples/other/LMCache/disagg-example.sh @@ -3,7 +3,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" if [[ $# -lt 1 ]]; then - echo "Usage: $0 [model]" + echo "Usage: $0 [model]" exit 1 fi @@ -16,7 +16,7 @@ else fi -if [[ $1 == "prefiller" ]]; then +if [[ $1 == "prefill" ]]; then # Prefiller listens on port 8100 prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml @@ -25,7 +25,6 @@ if [[ $1 == "prefiller" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=6 \ vllm serve $MODEL \ --port 8100 \ --enforce-eager \ @@ -36,7 +35,7 @@ if [[ $1 == "prefiller" ]]; then # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG # --enforce-eager -- Enforce eager mode -elif [[ $1 == "decoder" ]]; then +elif [[ $1 == "decode" ]]; then # Decoder listens on port 8200 decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml @@ -45,7 +44,6 @@ elif [[ $1 == "decoder" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=7 \ vllm serve $MODEL \ --port 8200 \ --enforce-eager \ From 2385d8e75066c751ae77ed93be49f266cf2f6787 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 11:59:22 +0000 Subject: [PATCH 010/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 9 +- vllm/v1/core/sched/scheduler.py | 149 ++++++++++++++++------------- vllm/v1/outputs.py | 22 +++-- vllm/v1/worker/gpu_model_runner.py | 17 +++- 4 files changed, 117 insertions(+), 80 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py 
b/vllm/v1/core/kv_cache_manager.py index c3c83baf5129..2514d231135a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -166,12 +166,16 @@ def get_computed_blocks( num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens + def cache_blocks(self, request: Request): + pass + def allocate_slots( self, request: Request, num_tokens: int, new_computed_blocks: Optional[list[KVCacheBlock]] = None, num_lookahead_tokens: int = 0, + skip_cache_blocks: bool = False, ) -> Optional[list[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -185,6 +189,9 @@ def allocate_slots( num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + skip_cache_blocks: Whether to skip cachings the blocks. This is + used by P/D when allocating blocks that used in KV transfer + which will complete in a future step. Blocks layout: ----------------------------------------------------------------------- @@ -275,7 +282,7 @@ def allocate_slots( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if not self.enable_caching: + if not self.enable_caching or skip_cache_blocks: return new_blocks # Use `new_computed_blocks` for a new request, and `num_cached_block` diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index eee43107740a..5e0b641788c1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -100,8 +100,8 @@ def __init__( self.finished_req_ids: set[str] = set() # Requests in states for tracking KV transfers for P/D disagg - self.waiting_to_send_KV_req_ids: set[str] = set() self.sending_KV_req_ids: set[str] = set() + self.recving_KV_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. @@ -172,22 +172,14 @@ def schedule(self) -> SchedulerOutput: # For logging. scheduled_timestamp = time.monotonic() - # Check for new remote decode requests for P/D - new_KV_req_ids_to_send: list[str] = [] - if self.connector is not None: - for req_id in list(self.sending_KV_req_ids): - if self.connector.is_request_done_sending(req_id): - self.sending_KV_req_ids.remove(req_id) - self.finished_req_ids.add(req_id) - for req_id in list(self.waiting_to_send_KV_req_ids): - self.sending_KV_req_ids.add(req_id) - self.waiting_to_send_KV_req_ids.remove(req_id) - new_KV_req_ids_to_send.append(req_id) - # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] + if request.request_id in self.recving_KV_req_ids: + # P/D: This request is still waiting for KVs. + req_index += 1 + continue if request.request_id in self.scheduled_req_ids: # This request has already been scheduled. req_index += 1 @@ -230,6 +222,11 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. # Preempt the lowest-priority request. preempted_req = self.running.pop() + # NOTE(rob): we cannot free these blocks once in flight. + # TODO(rob): understand full implications of this. 
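One possible guard for the in-flight-transfer case flagged above (a sketch of an alternative, not what this patch does): pick the lowest-priority victim whose blocks are not currently being written by a KV transfer.

def pick_preemption_victim(running, recving_KV_req_ids):
    # Walk from the lowest-priority end; skip requests whose KV blocks are
    # still the target of an in-flight transfer and must not be freed.
    for idx in range(len(running) - 1, -1, -1):
        if running[idx].request_id not in recving_KV_req_ids:
            return running.pop(idx)
    return None  # every candidate has an in-flight transfer; cannot preempt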
+ if preempted_req.request_id in self.recving_KV_req_ids: + pass + self.kv_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 @@ -318,23 +315,6 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue - # TODO(rob): we should do this after we allocate the blocks if - # we want to write directly into the BlockTable (like Dynamo). - # TODO(rob): this logic is incorrect if the req was preempted. - if request.do_remote_prefill: - assert self.connector is not None - # NOTE(rob): this returning false causes busy waiting - # if there is no other work to do. This is "functional" - # but not ideal. Also, if the transfer fails for any - # reason we will spin in this state. - if not self.connector.is_request_done_receiving(request): - request.status = RequestStatus.WAITING_FOR_REMOTE_KVS - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) - continue - elif request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - request.status = RequestStatus.WAITING - # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request and ( @@ -351,6 +331,7 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) # Get externally-cached tokens if using a KVConnector. + # NOTE(rob): this returns the full prompt length for nixl num_external_tokens = ( 0 if self.connector is None else self.connector.get_num_new_matched_tokens( @@ -359,46 +340,67 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens += num_external_tokens - # Number of tokens to be scheduled. - # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed requests, - # which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) - num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 - - # Schedule encoder inputs. - if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled. + # TODO: how can we make this code clean? + if request.do_remote_prefill: + # TODO: handle preempted state. + assert request.status != RequestStatus.PREEMPTED + assert self.connector is not None + + # Schedule 0 tokens until the recv is done. + num_new_tokens = 0 + + # Allocate slots for the external tokens, but skip + # caching until after the KV transfer is done. + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_external_tokens, + computed_blocks, + skip_cache_blocks=True) + if new_blocks is None: + # Request cannot be scheduled. break + self.recving_KV_req_ids.add(request.request_id) + else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed reqs, + # which have output tokens. 
+ num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + else: + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget - new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens + num_external_tokens, - computed_blocks) - if new_blocks is None: - # The request cannot be scheduled. - break + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens + num_external_tokens, + computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. if self.connector is not None: self.connector.update_state_after_alloc( - request, - num_external_tokens, - ) + request, num_external_tokens) self.waiting.popleft() if request.use_structured_output: @@ -506,7 +508,6 @@ def schedule(self) -> SchedulerOutput: free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, - new_KV_req_ids_to_send=new_KV_req_ids_to_send, ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -740,18 +741,17 @@ def update_from_output( prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: # Stop request after the first token if doing a remote_decode. + # TODO(rob): check if it is okay to send a finished request to + # AsyncLLM w/o adding it to eco.finished_requests # NOTE(rob): req is not freed (or preempted) in the EngineCore # until the xfer is done to ensure we do not free the KV blocks. kv_transfer_params = None if request.do_remote_decode and not stopped: stopped = True request.status = RequestStatus.FINISHED_REMOTE_DECODE - self.waiting_to_send_KV_req_ids.add(req_id) - assert self.connector is not None + self.sending_KV_req_ids.add(req_id) # TODO(rob): do this on a per-Connector basis. - # NOTE(rob): this KVTransferParams will be sent to the - # DWorker. From the POV of the DWorker, it should be a - # remote Prefill. + # From POV of DWorker, this is a remote prefill. kv_transfer_params = KVTransferParams( do_remote_prefill=True) @@ -776,6 +776,17 @@ def update_from_output( if not stopped: new_running.append(request) + # P/D: update recv and send status from last step. + for req_id in list(model_runner_output.finished_recving): + # TODO(rob): Implement this method. + # Cache blocks for APC after KVs have been recv'ed. 
+ self.kv_cache_manager.cache_blocks(req_id) + self.recving_KV_req_ids.remove(req_id) + self.scheduled_req_ids.remove(req_id) + for req_id in list(model_runner_output.finished_sending): + self.sending_KV_req_ids.remove(req_id) + self._free_request(self.requests[req_id]) + self.running = new_running engine_core_outputs = EngineCoreOutputs( outputs=outputs, @@ -831,11 +842,11 @@ def _free_request(self, request: Request) -> None: self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] + self.sending_KV_req_ids.remove(request.request_id) self.finished_req_ids.add(request.request_id) def get_num_unfinished_requests(self) -> int: - return len(self.waiting) + len(self.running) + len( - self.waiting_to_send_KV_req_ids) + return len(self.waiting) + len(self.running) def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..d1eae6a8ba7c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -100,12 +100,16 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( - req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, -) + # [req_ids] + finished_sending: set[str] + finished_recving: set[str] + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=set(), + finished_recving=set()) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0ba2e7a483c..3585d5b36c31 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1021,14 +1021,27 @@ def maybe_wait_for_save(): kv_connector = get_kv_transfer_group() kv_connector.wait_for_save() + def maybe_get_finished() -> tuple[set[str], set[str]]: + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + return kv_connector.get_finished() + else: + # TODO: make this optional instead. + return set(), set() + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # KV send/recv even if no work to do. with set_forward_context(None, self.vllm_config): maybe_setup_kv_connector() maybe_wait_for_save() + finished_sending, finished_recving = maybe_get_finished() # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + output = EMPTY_MODEL_RUNNER_OUTPUT + if len(finished_sending) > 0 or len(finished_sending) > 0: + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1104,6 +1117,8 @@ def maybe_wait_for_save(): inputs_embeds=inputs_embeds, ) maybe_wait_for_save() + finished_sending, finished_recving = maybe_get_finished() + if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
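A defensive variant of the no-work early return above, sketched here because ModelRunnerOutput is a dataclass and EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level instance: copying it before attaching the transfer sets keeps later steps from inheriting them, and the guard checks both sets.

from dataclasses import replace

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput

def empty_output_with_transfers(
        finished_sending: set[str],
        finished_recving: set[str]) -> ModelRunnerOutput:
    if finished_sending or finished_recving:
        # Return a copy so the shared empty output stays empty.
        return replace(EMPTY_MODEL_RUNNER_OUTPUT,
                       finished_sending=finished_sending,
                       finished_recving=finished_recving)
    return EMPTY_MODEL_RUNNER_OUTPUT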
return hidden_states From 6eeb47c640275475339e2dc12a61a66de5ad3c40 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:16:48 +0000 Subject: [PATCH 011/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/worker/gpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3585d5b36c31..3bf329549c49 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1294,6 +1294,8 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) def generate_draft_token_ids( From 266fceed9fd474fbaab001fa1d2e4ecd2b427a15 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:20:53 +0000 Subject: [PATCH 012/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/base.py | 27 ++-- .../kv_connector/v1/lmcache_connector.py | 138 ------------------ 2 files changed, 10 insertions(+), 155 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 583cf0595581..95967d2ca919 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -6,7 +6,7 @@ The class provides the following primitives: Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. - get_num_new_matched_tokens() - get number of new tokens + get_num_new_matched_tokens() - get number of new tokens that exist in the remote KV cache update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. @@ -70,7 +70,7 @@ def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. - This function should be called by the model runner every time + This function should be called by the model runner every time before the model execution. The metadata will be used for runtime KV cache loading and saving. @@ -82,7 +82,7 @@ def bind_connector_metadata( def clear_connector_metadata(self) -> None: """Clear the connector metadata. - This function should be called by the model runner every time + This function should be called by the model runner every time after the model execution. """ self._connector_metadata = KVConnectorMetadata() @@ -114,9 +114,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -126,7 +126,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. 
Args: @@ -138,13 +138,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: """ - Start saving a layer of KV cache from vLLM's paged buffer + Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -174,14 +174,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ pass @@ -207,10 +207,3 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ pass - - # These return true for now since they are not async - def is_request_done_sending(self, req_id: str) -> bool: - raise NotImplementedError - - def is_request_done_receiving(self, request: "Request") -> bool: - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py deleted file mode 100644 index 89d7ffe9ba58..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING - -import torch -from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionMetadata - from vllm.forward_context import ForwardContext - from vllm.v1.request import Request - -logger = init_logger(__name__) - - -class LMCacheConnectorV1(KVConnectorBase_V1): - - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) - - # ============================== - # Worker-side methods - # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - """ - Start loading the KV cache from the connector to vLLM's paged - KV buffer. This is called from the forward context before the - forward pass to enable async loading during model execution. - - Args: - forward_context (ForwardContext): the forward context. - **kwargs: additional arguments for the load operation - - Note: - The number of elements in kv_caches and layer_names should be - the same. - - """ - self._lmcache_engine.start_load_kv(forward_context, **kwargs) - - def wait_for_layer_load(self, layer_name: str) -> None: - """ - Block until the KV for a specific layer is loaded into vLLM's - paged buffer. 
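With the synchronous polling helpers dropped from the base class, the surface a new connector has to implement is compact; a minimal no-op skeleton (a sketch against the trimmed KVConnectorBase_V1 interface above) looks like this:

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata)

class NoOpConnector(KVConnectorBase_V1):
    # Worker-side hooks.
    def start_load_kv(self, forward_context, **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata,
                      **kwargs) -> None:
        pass

    def wait_for_save(self) -> None:
        pass

    # Scheduler-side hooks.
    def get_num_new_matched_tokens(self, request,
                                   num_computed_tokens: int) -> int:
        return 0

    def update_state_after_alloc(self, request,
                                 num_external_tokens: int) -> None:
        pass

    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        return KVConnectorMetadata()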
This is called from within attention layer to ensure - async copying from start_load_kv is complete. - - This interface will be useful for layer-by-layer pipelining. - - Args: - layer_name: the name of that layer - """ - self._lmcache_engine.wait_for_layer_load(layer_name) - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """ - Start saving the a layer of KV cache from vLLM's paged buffer - to the connector. This is called from within attention layer to - enable async copying during execution. - - Args: - layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current - layer in vLLM. - attn_metadata (AttentionMetadata): the attention metadata. - **kwargs: additional arguments for the save operation. - """ - self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, - **kwargs) - - def wait_for_save(self): - """ - Block until all the save operations is done. This is called - as the forward context exits to ensure that the async saving - from save_kv_layer is complete before finishing the forward. - - This prevents overwrites of paged KV buffer before saving done. - """ - self._lmcache_engine.wait_for_save() - - # ============================== - # Scheduler-side methods - # ============================== - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> int: - """ - Get number of new tokens that can be loaded from the - external KV cache beyond the num_computed_tokens. - - Args: - request (Request): the request object. - num_computed_tokens (int): the number of locally - computed tokens for this request - - Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. - """ - return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - num_external_tokens: int): - """ - Update KVConnector state after block allocation. - """ - self._lmcache_engine.update_state_after_alloc(request, - num_external_tokens) - - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: - """ - Build the connector metadata for this step. - - This function should NOT modify fields in the scheduler_output. - Also, calling this function will reset the state of the connector. - - Args: - scheduler_output (SchedulerOutput): the scheduler output object. 
- """ - return self._lmcache_engine.build_connector_meta(scheduler_output) - - # These return true for now since they are not async - def is_request_done_sending(self, req_id: str) -> bool: - return True - - def is_request_done_receiving(self, request: "Request") -> bool: - return self._lmcache_engine.is_request_done_receiving(request) From f7e16f130f048ba2f7a5d93f7b554efaefb21012 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:21:09 +0000 Subject: [PATCH 013/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../v1/shared_storage_connector.py | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 01037fda285d..1d2040784e6c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -85,7 +85,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: - """Start loading the KV cache from the connector buffer to vLLM's + """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. Args: @@ -93,7 +93,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. """ attn_metadata = forward_context.attn_metadata @@ -106,13 +106,13 @@ def inject_kv_into_layer( """Inject the KV cache into the layer. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not using MLA, [num_pages, page_size, xxx] otherwise. src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape + slot_mapping (torch.Tensor): the slot mapping. In shape [num_tokens]. """ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape @@ -168,8 +168,8 @@ def inject_kv_into_layer( def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's - paged buffer. - + paged buffer. + This interface will be useful for layer-by-layer pipelining. Args: @@ -179,12 +179,12 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: - """Start saving the KV cache of the layer from vLLM's paged buffer + """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -229,14 +229,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. 
- + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ @@ -271,7 +271,9 @@ def update_state_after_alloc(self, request: "Request", self._requests_need_load[request.request_id] = request def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: """Build the connector metadata for this step. This function should NOT modify any fields in the scheduler_output. @@ -329,13 +331,6 @@ def build_connector_meta( self._requests_need_load.clear() return meta - # These return true for now since they are not async - def is_request_done_sending(self, req_id: str) -> bool: - return True - - def is_request_done_receiving(self, request: "Request") -> bool: - return True - # ============================== # Helper functions # ============================== @@ -358,7 +353,7 @@ def _generate_foldername_debug( input_ids: torch.Tensor, create_folder=False, ) -> str: - """Generate a folder name based on the hash of the bytes of the input + """Generate a folder name based on the hash of the bytes of the input ids. """ input_ids_bytes = input_ids.numpy().tobytes() @@ -373,7 +368,7 @@ def _generate_filename_debug( layer_name: str, input_ids: torch.Tensor, ) -> str: - """Generate a file name based on the layer name and the hash + """Generate a file name based on the layer name and the hash of the bytes of the input ids. """ foldername = self._generate_foldername_debug(input_ids, From f591b8ef4a710b7f36b253fcec4fbb43aa046096 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:22:00 +0000 Subject: [PATCH 014/119] added connector Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/kv_rearrange.py | 119 ++++++++ .../kv_connector/v1/nixl_connector.py | 286 ++++++++++++++++++ 2 files changed, 405 insertions(+) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py b/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py new file mode 100644 index 000000000000..c55e4de8f7d9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py @@ -0,0 +1,119 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def rearrange_kernel_read( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // + d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) + + +@triton.jit +def rearrange_kernel_write( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = 
block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // + d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) + + +def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, + direction: str): + N, B, H, C = t1.shape + + assert t2.shape == (N, B, H, + C), "Destination tensor must have same shape as source" + assert H % d == 0, "H must be divisible by d" + + block_size = B * H * C + token_size = H * C + tensor_size = N * block_size + tensor_subset_size = tensor_size // d + + BLOCK_SIZE = 1024 + grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + + if direction == "read": + rearrange_kernel_read[grid](t1, + t2, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE) + elif direction == "write": + rearrange_kernel_write[grid](t1, + t2, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE) + else: + raise ValueError(f"Invalid direction: {direction}") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 000000000000..b1b3f6d75c38 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +import uuid +from collections import defaultdict +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1) +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.v1.request import Request + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): + self.vllm_config = vllm_config + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + + self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer + + self.num_layers = None + self.num_blocks = None + self.num_heads = None + self.block_len = None + self.kv_caches = None + self.kv_caches_base_addr = {} + self.kv_cache_shape = {} + + self._registered_descs = [] + self._remote_agents = {} + self.engine_id = engine_id + self.rank = rank + self.src_xfer_side_handles = {} + self.dst_xfer_side_handles = defaultdict(dict) + self.dst_num_blocks = {} + + # [req_id -> list[handle]] + self._recving_transfers = defaultdict(list) + + def get_agent_metadata(self): + return self.nixl_wrapper.get_agent_metadata() + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + For remote prefill, we allocate for all tokens. + """ + # Allocate space for external tokens. 
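As a cross-check for the Triton kernels in kv_rearrange.py above, here is a plain-PyTorch restatement of the same index mapping, derived from the kernel arithmetic and therefore an assumption to validate rather than part of the patch: "write" regroups the H heads of an (N, B, H, C) cache into d contiguous tensor-parallel shards laid out back to back, and "read" is the inverse.

import torch

def rearrange_reference(src: torch.Tensor, d: int,
                        direction: str) -> torch.Tensor:
    N, B, H, C = src.shape
    assert H % d == 0, "H must be divisible by d"
    if direction == "write":
        # Destination flat layout is (d, N, B, H // d, C): one contiguous
        # chunk per tensor-parallel shard.
        return (src.reshape(N, B, d, H // d, C)
                   .permute(2, 0, 1, 3, 4)
                   .contiguous()
                   .reshape(N, B, H, C))
    if direction == "read":
        # Inverse: reinterpret the (d, N, B, H // d, C) layout and regroup
        # the shards back into interleaved heads.
        return (src.reshape(d, N, B, H // d, C)
                   .permute(1, 2, 0, 3, 4)
                   .contiguous()
                   .reshape(N, B, H, C))
    raise ValueError(f"Invalid direction: {direction}")

# Round-trip check, e.g.:
# x = torch.randn(4, 16, 8, 128)
# torch.testing.assert_close(
#     rearrange_reference(rearrange_reference(x, 2, "write"), 2, "read"), x)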
+ if request.do_remote_prefill: + return len(request.prompt_token_ids) - num_computed_tokens + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished requests and the requests that are still in progress.""" + done_sending = self._get_new_notifs() + done_recving = self._update_transfers(self._recving_transfers) + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """ + Get set of req_id that got a new notification. + """ + notified_req_ids: set[str] = set() + # TODO: handle the TP case (N notifies for TP=N). + # vllm/worker/worker_base.py L476 in DynamoPR. + for req_ids in self.nixl_wrapper.get_new_notifs().values(): + for req_id in req_ids: + assert req_id not in notified_req_ids + notified_req_ids.add(req_id) + return notified_req_ids + + def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: + """ + Update the list of transfers that are running by checking + the nixl_xfer_state and removing those in "DONE" state. + + Args: + transfers: dictionary of req_id -> list[running_xfer] + + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: str[str] = set() + for req_id, handles in transfers.items(): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # TODO ptarasiewicz: why abort is throwing errors? + # self.nixl_wrapper.release_xfer_handle(handle) + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + if len(running_reqs) == 0: + done_req_ids.add(req_id) + else: + transfers[req_id] = running_reqs + return done_req_ids + + def read_blocks( + self, + local_block_ids: list[int], + staging_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, then + # we will need to have the staging blocks on the remote side. + # NOTE(rob): we could potentially do the rearranging during the load_kv! + + assert len(local_block_ids) == len(staging_block_ids) == len( + remote_block_ids) + if len(local_block_ids) == 0: + return + + # TODO(rob): understand ranges code. + local_ranges = self._get_ranges(local_block_ids) + staging_ranges = self._get_ranges(staging_block_ids) + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges( + local_ranges, staging_ranges) + + # TODO(rob): understand tp multiplier. + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[ + self.engine_id] + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + + # Read the data from the remote. + for i in range(tp_multiplier): + staging_block_descs_ids = self._get_block_descs_ids( + self.engine_id, + "all", + staging_block_ids, + i=i, + tp_multiplier=tp_multiplier, + staging_ranges=staging_rearranging_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[ + dst_engine_id][i] + + # NOTE(rob): we use the request_id as the notify msg, so we + # must use the aem request_id in both the p and d workers. 
+ handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + staging_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=request_id, + ) + # NOTE(rob): we will check this is done in the next forward pass. + self._recving_transfers[request_id].append(handle) + + # NOTE(rob): this is actually pretty serious problem. + # We need to figure out if we can put the staging blocks on the P worker side. + # The staging blocks need to be on the side that sends. + + # for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + # logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", + # self.kv_caches[0].shape, local_range, staging_range) + # for kv_cache in self.kv_caches: + # for cache in kv_cache: + # rearrange_tensors(cache[local_range[0]:local_range[1] + 1], + # cache[staging_range[0]:staging_range[1] + 1], + # tp_multiplier, "read") + + def shutdown(self): + for descs_list in self._registered_descs: + self.nixl_wrapper.deregister_memory(descs_list) + for agent_names in self._remote_agents.values(): + for agent_name in agent_names: + self.nixl_wrapper.remove_remote_agent(agent_name) + for src_xfer_side_handle in self.src_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) + + def register_kv_caches(self, kv_caches: list[torch.Tensor]): + _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape + self.block_len = block_size * num_heads * head_dim * kv_caches[ + 0].element_size() + logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + self.num_layers = len(kv_caches) + self.num_blocks = num_blocks + self.num_heads = num_heads + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + for key_cache, value_cache in kv_caches: + base_addr = key_cache.data_ptr() + region_len = 2 * num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append( + (key_cache.data_ptr(), value_cache.data_ptr())) + + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + self._registered_descs.append(descs) + + def _get_ranges(self, block_ids): + # This function should return a list of ranges of block ids that are contiguous + # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] + # The ranges are sorted by the starting block id + # The function should also make sure that the block ids are contiguous + # If the block ids are not contiguous, the function should raise an error + ranges = [] + for i in range(len(block_ids)): + if i == 0 or block_ids[i] != block_ids[i - 1] + 1: + ranges.append([block_ids[i], block_ids[i]]) + else: + ranges[-1][1] = block_ids[i] + return ranges + + def _get_block_descs_ids(self, + engine_id, + layer_ids, + block_ids, + i=None, + tp_multiplier=1, + staging_ranges=None): + + if layer_ids == "all": + layer_ids = list(range(self.num_layers)) + if block_ids == "all": + block_ids = list(range(self.num_blocks)) + + descs_ids = [] + + if i is not None: + num_blocks = self.num_blocks + for layer_id in layer_ids: + for is_value in [0, 1]: + staging_range_idx 
= 0 + for block_id in block_ids: + if block_id > staging_ranges[staging_range_idx][ + 1] or block_id < staging_ranges[ + staging_range_idx][0]: + staging_range_idx += 1 + start_offset = staging_ranges[staging_range_idx][0] + i_offset = i * (staging_ranges[staging_range_idx][-1] - + start_offset + 1) + descs_ids.append( + layer_id * 2 * num_blocks * tp_multiplier + + is_value * num_blocks * tp_multiplier + + start_offset * tp_multiplier + i_offset + + (block_id - start_offset)) + else: + num_blocks = self.dst_num_blocks[engine_id] + for layer_id in layer_ids: + for is_value in [0, 1]: + for block_id in block_ids: + descs_ids.append(layer_id * 2 * num_blocks + + is_value * num_blocks + block_id) + return descs_ids From 184d0b60dc581fba723148b9cd37e99641088b49 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:22:53 +0000 Subject: [PATCH 015/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..ad8a931daf57 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -102,6 +102,6 @@ def create_connector_v1( "SharedStorageConnector") KVConnectorFactory.register_connector( - "LMCacheConnectorV1", - "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", - "LMCacheConnectorV1") + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") From d4a9e5b407a15754a91821d0f4298f4a7702ff69 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 12:28:55 +0000 Subject: [PATCH 016/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 3 ++- vllm/v1/core/sched/scheduler.py | 4 +--- vllm/v1/request.py | 1 - 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2514d231135a..33bb825a11a7 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -167,8 +167,9 @@ def get_computed_blocks( return computed_blocks, num_computed_tokens def cache_blocks(self, request: Request): + # TODO: implement this. pass - + def allocate_slots( self, request: Request, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5e0b641788c1..2eb490416eaa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -331,7 +331,6 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) # Get externally-cached tokens if using a KVConnector. 
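For the homogeneous-TP branch of _get_block_descs_ids above, the descriptor list is laid out layer-major, then K/V, then block. A worked example with a small invented num_blocks:

    num_blocks = 4   # invented for illustration

    def desc_id(layer_id: int, is_value: int, block_id: int) -> int:
        # Mirrors the else-branch above: [layer][K or V][block], flattened.
        return layer_id * 2 * num_blocks + is_value * num_blocks + block_id

    assert desc_id(0, 0, 0) == 0   # layer 0, K, block 0
    assert desc_id(0, 1, 3) == 7   # layer 0, V, block 3
    assert desc_id(1, 0, 0) == 8   # layer 1 starts after 2 * num_blocks entries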
- # NOTE(rob): this returns the full prompt length for nixl num_external_tokens = ( 0 if self.connector is None else self.connector.get_num_new_matched_tokens( @@ -784,7 +783,6 @@ def update_from_output( self.recving_KV_req_ids.remove(req_id) self.scheduled_req_ids.remove(req_id) for req_id in list(model_runner_output.finished_sending): - self.sending_KV_req_ids.remove(req_id) self._free_request(self.requests[req_id]) self.running = new_running @@ -842,7 +840,7 @@ def _free_request(self, request: Request) -> None: self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] - self.sending_KV_req_ids.remove(request.request_id) + self.sending_KV_req_ids.discard(request.request_id) self.finished_req_ids.add(request.request_id) def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index dc70dea3d65f..e720c4ee21fb 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -158,7 +158,6 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() - WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered From 4b0d1dc1fdd047a27157637e2e7a66476de004b9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 13:41:41 +0000 Subject: [PATCH 017/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 60 ++++++++++++++++++- vllm/v1/core/sched/output.py | 1 - 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b1b3f6d75c38..caa16379ffaf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -7,11 +7,13 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: from vllm.v1.request import Request + logger = init_logger(__name__) @@ -24,6 +26,18 @@ NixlWrapper = None +class NixlConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.block_ids: dict[str, list[int]] = {} + + def add_new_req( + self, + req_id: str, + block_ids: list[int], + ): + self.block_ids[req_id] = block_ids + + class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): @@ -52,6 +66,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): self.dst_xfer_side_handles = defaultdict(dict) self.dst_num_blocks = {} + # req_ids that need to start loading. 
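At this point in the series the connector metadata is just a req_id -> block_ids map: the scheduler fills it in build_connector_meta for newly scheduled requests that need a load, and later patches in the series extend it with the remote block ids and remote engine id. A tiny illustration (request and block ids are invented):

    meta = NixlConnectorMetadata()
    meta.add_new_req(req_id="req-0", block_ids=[0, 1, 2])
    assert meta.block_ids == {"req-0": [0, 1, 2]}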
+ self._req_ids_to_load: set[str] = set() + # [req_id -> list[handle]] self._recving_transfers = defaultdict(list) @@ -70,12 +87,46 @@ def get_num_new_matched_tokens( if request.do_remote_prefill: return len(request.prompt_token_ids) - num_computed_tokens + def update_state_after_alloc( + self, + request: "Request", + num_external_tokens: int + ): + if request.do_remote_decode: + pass + if request.do_remote_prefill and num_external_tokens > 0: + self._req_ids_to_load.add(request.request_id) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + + meta = NixlConnectorMetadata() + + for req in scheduler_output.scheduled_new_reqs: + if req.req_id in self._req_ids_to_load: + meta.add_new_req(req.req_id, req.block_ids) + + return meta + + def get_finished(self) -> tuple[set[str], set[str]]: """Get the finished requests and the requests that are still in progress.""" done_sending = self._get_new_notifs() done_recving = self._update_transfers(self._recving_transfers) return done_sending, done_recving + def _get_new_notifs(self) -> set[str]: """ Get set of req_id that got a new notification. @@ -89,6 +140,7 @@ def _get_new_notifs(self) -> set[str]: notified_req_ids.add(req_id) return notified_req_ids + def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: """ Update the list of transfers that are running by checking @@ -120,6 +172,7 @@ def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: transfers[req_id] = running_reqs return done_req_ids + def read_blocks( self, local_block_ids: list[int], @@ -192,6 +245,7 @@ def read_blocks( # cache[staging_range[0]:staging_range[1] + 1], # tp_multiplier, "read") + def shutdown(self): for descs_list in self._registered_descs: self.nixl_wrapper.deregister_memory(descs_list) @@ -204,6 +258,7 @@ def shutdown(self): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) + def register_kv_caches(self, kv_caches: list[torch.Tensor]): _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape self.block_len = block_size * num_heads * head_dim * kv_caches[ @@ -229,6 +284,7 @@ def register_kv_caches(self, kv_caches: list[torch.Tensor]): self.nixl_wrapper.register_memory(descs) self._registered_descs.append(descs) + def _get_ranges(self, block_ids): # This function should return a list of ranges of block ids that are contiguous # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] @@ -243,6 +299,7 @@ def _get_ranges(self, block_ids): ranges[-1][1] = block_ids[i] return ranges + def _get_block_descs_ids(self, engine_id, layer_ids, @@ -284,3 +341,4 @@ def _get_block_descs_ids(self, descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) return descs_ids + diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 1daee4a0418c..66a07e7bcd83 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -126,4 +126,3 @@ class SchedulerOutput: # KV Cache Connector metadata. 
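The _get_ranges helper used by this connector groups a sorted list of block ids into inclusive, contiguous [start, end] ranges. A standalone copy of the same logic, with the example from its comment as a check:

    def get_ranges(block_ids: list[int]) -> list[list[int]]:
        # [0, 1, 2, 4, 5, 6] -> [[0, 2], [4, 6]]
        ranges: list[list[int]] = []
        for i, bid in enumerate(block_ids):
            if i == 0 or bid != block_ids[i - 1] + 1:
                ranges.append([bid, bid])
            else:
                ranges[-1][1] = bid
        return ranges

    assert get_ranges([0, 1, 2, 4, 5, 6]) == [[0, 2], [4, 6]]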
kv_connector_metadata: Optional[KVConnectorMetadata] = None - new_KV_req_ids_to_send: list[str] = field(default_factory=list) From bfef039dde3c8837f6957b8705a2a73c70262e71 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 13:58:31 +0000 Subject: [PATCH 018/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 91 +++++++++++++++---- vllm/sampling_params.py | 4 +- vllm/v1/request.py | 1 + 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index caa16379ffaf..dc11c5f21ef4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -7,12 +7,15 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorMetadata) from vllm.logger import init_logger +from vllm.sampling_params import KVTransferParams from vllm.v1.core.sched.output import SchedulerOutput + if TYPE_CHECKING: from vllm.v1.request import Request + from vllm.attention.backends.abstract import AttentionMetadata logger = init_logger(__name__) @@ -25,18 +28,33 @@ logger.warning("NIXL is not available") NixlWrapper = None +class ReqMeta: + def __init__( + self, + block_ids: list[int], + remote_block_ids: list[int], + remote_engine_id: list[int], + ): + self.block_ids = block_ids + self.remote_block_ids = remote_block_ids + self.remote_engine_id = remote_engine_id class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.block_ids: dict[str, list[int]] = {} + self.requests: dict[str, ReqMeta] = {} def add_new_req( self, req_id: str, block_ids: list[int], + kv_transfer_params: KVTransferParams, ): - self.block_ids[req_id] = block_ids - + assert req_id not in self.requests + self.requests[req_id] = ReqMeta( + block_ids, + remote_block_ids=kv_transfer_params.remote_block_ids, + remote_engine_id=kv_transfer_params.remote_engine_id + ) class NixlConnector(KVConnectorBase_V1): @@ -67,7 +85,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): self.dst_num_blocks = {} # req_ids that need to start loading. 
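With the ReqMeta change above, each request's metadata now carries where to read from: the remote engine id and remote block ids taken from KVTransferParams (whose fields are added to sampling_params later in this same patch). A sketch of the resulting plumbing, with all ids and block numbers invented:

    params = KVTransferParams(do_remote_prefill=True,
                              remote_engine_id="prefill-engine-0",
                              remote_block_ids=[7, 8, 9])
    meta = NixlConnectorMetadata()
    meta.add_new_req(req_id="req-0",
                     block_ids=[0, 1, 2],
                     kv_transfer_params=params)
    assert meta.requests["req-0"].remote_engine_id == "prefill-engine-0"
    assert meta.requests["req-0"].remote_block_ids == [7, 8, 9]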
- self._req_ids_to_load: set[str] = set() + self._reqs_to_load: dict[str, "Request"] = {} # [req_id -> list[handle]] self._recving_transfers = defaultdict(list) @@ -95,7 +113,7 @@ def update_state_after_alloc( if request.do_remote_decode: pass if request.do_remote_prefill and num_external_tokens > 0: - self._req_ids_to_load.add(request.request_id) + self._reqs_to_load[request.request_id] = request def build_connector_meta( self, @@ -113,20 +131,23 @@ def build_connector_meta( meta = NixlConnectorMetadata() - for req in scheduler_output.scheduled_new_reqs: - if req.req_id in self._req_ids_to_load: - meta.add_new_req(req.req_id, req.block_ids) + for new_req in scheduler_output.scheduled_new_reqs: + req = self._reqs_to_load.pop(new_req.req_id, None) + if req is not None: + meta.add_new_req( + new_req.req_id, + new_req.block_ids, + req.kv_transfer_params, + ) return meta - def get_finished(self) -> tuple[set[str], set[str]]: """Get the finished requests and the requests that are still in progress.""" done_sending = self._get_new_notifs() done_recving = self._update_transfers(self._recving_transfers) return done_sending, done_recving - def _get_new_notifs(self) -> set[str]: """ Get set of req_id that got a new notification. @@ -172,7 +193,6 @@ def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: transfers[req_id] = running_reqs return done_req_ids - def read_blocks( self, local_block_ids: list[int], @@ -245,7 +265,6 @@ def read_blocks( # cache[staging_range[0]:staging_range[1] + 1], # tp_multiplier, "read") - def shutdown(self): for descs_list in self._registered_descs: self.nixl_wrapper.deregister_memory(descs_list) @@ -258,7 +277,6 @@ def shutdown(self): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) - def register_kv_caches(self, kv_caches: list[torch.Tensor]): _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape self.block_len = block_size * num_heads * head_dim * kv_caches[ @@ -284,7 +302,6 @@ def register_kv_caches(self, kv_caches: list[torch.Tensor]): self.nixl_wrapper.register_memory(descs) self._registered_descs.append(descs) - def _get_ranges(self, block_ids): # This function should return a list of ranges of block ids that are contiguous # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] @@ -299,7 +316,6 @@ def _get_ranges(self, block_ids): ranges[-1][1] = block_ids[i] return ranges - def _get_block_descs_ids(self, engine_id, layer_ids, @@ -342,3 +358,46 @@ def _get_block_descs_ids(self, is_value * num_blocks + block_id) return descs_ids + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. 
+ """ + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, NixlConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None") + return + + for req_id in metadata.block_ids: + local_block_ids = metadata.block_ids[req_id] + # TODO: actually do staging blocks once we support different TP + staging_block_ids = metadata.block_ids[req_id] + remote_block_ids = metadata.remote_block_ids[req_id] + + + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + return + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + return diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 38b84427b05d..4ecebb808c29 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -33,8 +33,8 @@ class KVTransferParams( # required for @cached_property. dict=True): # TODO(rob): we can handle xPyD and direct KV block Xfer - # remote_instance_id: Optional[str] = None - # remote_block_ids: Optional[list[int]] = None + remote_engine_id: Optional[str] = None + remote_block_ids: Optional[list[int]] = None do_remote_decode: bool = False do_remote_prefill: bool = False diff --git a/vllm/v1/request.py b/vllm/v1/request.py index e720c4ee21fb..60b004dc0b2d 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -68,6 +68,7 @@ def __init__( self.do_remote_prefill = ( False if sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_prefill) + self.kv_transfer_params = sampling_params.kv_transfer_params # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) From 54f4a43c16b61521c67a2083944d15f1059ff9b2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:00:58 +0000 Subject: [PATCH 019/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index dc11c5f21ef4..84f93c85688c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -12,11 +12,10 @@ from vllm.sampling_params import KVTransferParams from vllm.v1.core.sched.output import SchedulerOutput - if TYPE_CHECKING: - from vllm.v1.request import Request from vllm.attention.backends.abstract import AttentionMetadata - + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request logger = init_logger(__name__) @@ -28,7 +27,9 @@ logger.warning("NIXL is not available") NixlWrapper = None + class ReqMeta: + def __init__( self, block_ids: list[int], @@ -39,10 +40,12 @@ def __init__( self.remote_block_ids = remote_block_ids self.remote_engine_id = remote_engine_id + class NixlConnectorMetadata(KVConnectorMetadata): + def __init__(self): self.requests: dict[str, ReqMeta] = {} - + def add_new_req( self, req_id: str, @@ -53,8 +56,8 @@ def add_new_req( self.requests[req_id] = ReqMeta( block_ids, remote_block_ids=kv_transfer_params.remote_block_ids, - remote_engine_id=kv_transfer_params.remote_engine_id - ) + 
remote_engine_id=kv_transfer_params.remote_engine_id) + class NixlConnector(KVConnectorBase_V1): @@ -85,8 +88,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): self.dst_num_blocks = {} # req_ids that need to start loading. - self._reqs_to_load: dict[str, "Request"] = {} - + self._reqs_to_load: dict[str, Request] = {} + # [req_id -> list[handle]] self._recving_transfers = defaultdict(list) @@ -105,11 +108,8 @@ def get_num_new_matched_tokens( if request.do_remote_prefill: return len(request.prompt_token_ids) - num_computed_tokens - def update_state_after_alloc( - self, - request: "Request", - num_external_tokens: int - ): + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): if request.do_remote_decode: pass if request.do_remote_prefill and num_external_tokens > 0: @@ -161,7 +161,6 @@ def _get_new_notifs(self) -> set[str]: notified_req_ids.add(req_id) return notified_req_ids - def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: """ Update the list of transfers that are running by checking @@ -378,16 +377,20 @@ def start_load_kv(self, forward_context: "ForwardContext", if metadata is None: logger.warning( - "In connector.start_load_kv, but the connector metadata is None") + "In connector.start_load_kv, but the connector metadata is None" + ) return - - for req_id in metadata.block_ids: - local_block_ids = metadata.block_ids[req_id] - # TODO: actually do staging blocks once we support different TP - staging_block_ids = metadata.block_ids[req_id] - remote_block_ids = metadata.remote_block_ids[req_id] - - + + for req_id, meta in metadata.requests.items(): + # this is non-blocking + self.read_blocks( + local_block_ids=meta.block_ids, + # TODO: support staging once we do heterogenous TP + staging_block_ids=meta.block_ids, + remote_block_ids=meta.remote_block_ids, + dst_engine_id=meta.remote_engine_id, + request_id=req_id, + ) def wait_for_layer_load(self, layer_name: str) -> None: """NixlConnector does not do layerwise saving.""" From e604b09ddd2c5311e7b353aa329b682b98393c97 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:06:21 +0000 Subject: [PATCH 020/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 84f93c85688c..cffdb88527e3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -205,6 +205,11 @@ def read_blocks( # after we detect the txn is complete (which means we cannot make the # read trxn async easily). If we want to make "READ" happen cleanly, then # we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogenous TP sizes. We should remove the staging + # blocks until we are ready. + # NOTE(rob): we could potentially do the rearranging during the load_kv! 
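For the heterogeneous-TP notes above: the quantity driving the staging logic is tp_multiplier, the ratio of the destination engine's TP size to the local one as computed in read_blocks, and each block is carved into that many equal byte ranges. A toy calculation with invented sizes, not tied to any real deployment:

    remote_tp, local_tp = 4, 2
    tp_multiplier = remote_tp // local_tp        # 2 remote ranks per local rank
    block_len = 32768                            # bytes of one K (or V) block locally
    dst_block_len = block_len // tp_multiplier   # each remote rank holds 16384 bytes
    offsets = [i * dst_block_len for i in range(tp_multiplier)]
    assert offsets == [0, 16384]                 # per-rank offsets inside one block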
assert len(local_block_ids) == len(staging_block_ids) == len( @@ -382,7 +387,7 @@ def start_load_kv(self, forward_context: "ForwardContext", return for req_id, meta in metadata.requests.items(): - # this is non-blocking + # NOTE: this is non-blocking self.read_blocks( local_block_ids=meta.block_ids, # TODO: support staging once we do heterogenous TP From 2fc00ad38fde00257080532cf0a962931331827f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:26:15 +0000 Subject: [PATCH 021/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2eb490416eaa..56eef389603b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -745,14 +745,20 @@ def update_from_output( # NOTE(rob): req is not freed (or preempted) in the EngineCore # until the xfer is done to ensure we do not free the KV blocks. kv_transfer_params = None + # TODO(rob): edge case where we get a stop for stop_strings + # inside AsyncLLM. if request.do_remote_decode and not stopped: - stopped = True request.status = RequestStatus.FINISHED_REMOTE_DECODE self.sending_KV_req_ids.add(req_id) # TODO(rob): do this on a per-Connector basis. # From POV of DWorker, this is a remote prefill. kv_transfer_params = KVTransferParams( - do_remote_prefill=True) + do_remote_prefill=True, + # put the remote block ids here + remote_block_ids=[1,2,3], + # put the enigne id here + remote_engine_id="abcdefg", + ) # Add EngineCoreOutput for this Request. outputs.append( @@ -781,7 +787,6 @@ def update_from_output( # Cache blocks for APC after KVs have been recv'ed. self.kv_cache_manager.cache_blocks(req_id) self.recving_KV_req_ids.remove(req_id) - self.scheduled_req_ids.remove(req_id) for req_id in list(model_runner_output.finished_sending): self._free_request(self.requests[req_id]) From e5967b65c85b66feed5c00f3f267a0e4d2dff2f2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:30:26 +0000 Subject: [PATCH 022/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 56eef389603b..112621af0f45 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -176,8 +176,9 @@ def schedule(self) -> SchedulerOutput: req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if request.request_id in self.recving_KV_req_ids: - # P/D: This request is still waiting for KVs. + if (request.request_id in self.recving_KV_req_ids + or request.request_id in self.sending_KV_req_ids): + # P/D: This request is still recv/sending KVs. 
req_index += 1 continue if request.request_id in self.scheduled_req_ids: @@ -755,7 +756,7 @@ def update_from_output( kv_transfer_params = KVTransferParams( do_remote_prefill=True, # put the remote block ids here - remote_block_ids=[1,2,3], + remote_block_ids=[1, 2, 3], # put the enigne id here remote_engine_id="abcdefg", ) From f1bc0f74b00d149f4ec74c8ac3ead65d8552b356 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:36:50 +0000 Subject: [PATCH 023/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 66a07e7bcd83..1d3f1f41f8fb 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: From 1cea2bb12a40be5f9b9cc56c4aec23ce29184cc0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:38:03 +0000 Subject: [PATCH 024/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 41 ++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 112621af0f45..d28f2e49988e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -341,27 +341,7 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens += num_external_tokens # TODO: how can we make this code clean? - if request.do_remote_prefill: - # TODO: handle preempted state. - assert request.status != RequestStatus.PREEMPTED - assert self.connector is not None - - # Schedule 0 tokens until the recv is done. - num_new_tokens = 0 - - # Allocate slots for the external tokens, but skip - # caching until after the KV transfer is done. - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_external_tokens, - computed_blocks, - skip_cache_blocks=True) - if new_blocks is None: - # Request cannot be scheduled. - break - self.recving_KV_req_ids.add(request.request_id) - - else: + if not request.do_remote_prefill: # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed reqs, @@ -394,6 +374,25 @@ def schedule(self) -> SchedulerOutput: if new_blocks is None: # The request cannot be scheduled. break + else: + # TODO: handle preempted state. + assert request.status != RequestStatus.PREEMPTED + assert self.connector is not None + + # Schedule 0 tokens until the recv is done. + num_new_tokens = 0 + + # Allocate slots for the external tokens, but skip + # caching until after the KV transfer is done. + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_external_tokens, + computed_blocks, + skip_cache_blocks=True) + if new_blocks is None: + # Request cannot be scheduled. + break + self.recving_KV_req_ids.add(request.request_id) # KVConnector: update internal state after allocation. 
# This information is used to determine if a load is From 489e4c0c9d0b4be20a074a52d16b995f30f03549 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:38:28 +0000 Subject: [PATCH 025/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d28f2e49988e..d9ec48cbbb15 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -342,6 +342,7 @@ def schedule(self) -> SchedulerOutput: # TODO: how can we make this code clean? if not request.do_remote_prefill: + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed reqs, From 437ac91f5142c1ae3fb4280e26c7b59cd496d3b6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 14:39:32 +0000 Subject: [PATCH 026/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d9ec48cbbb15..995bc7512e22 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -400,7 +400,9 @@ def schedule(self) -> SchedulerOutput: # needed for this request. if self.connector is not None: self.connector.update_state_after_alloc( - request, num_external_tokens) + request, + num_external_tokens, + ) self.waiting.popleft() if request.use_structured_output: From ea47af78b192a1de26a6aeeda7f9fe83273a6dbe Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 18:38:56 +0000 Subject: [PATCH 027/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/config.py | 2 +- .../kv_transfer/kv_connector/v1/base.py | 7 + .../kv_connector/v1/nixl_connector.py | 321 ++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 3 + vllm/v1/worker/gpu_worker.py | 5 +- 5 files changed, 188 insertions(+), 150 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5b5ac40f6aa2..2bf8a18250ce 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3204,7 +3204,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..6d1d7b912f7a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -61,6 +61,13 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role + + def init_with_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). 
+ """ + pass @property def role(self) -> KVConnectorRole: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index cffdb88527e3..59964d8f74e8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -7,7 +7,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.logger import init_logger from vllm.sampling_params import KVTransferParams from vllm.v1.core.sched.output import SchedulerOutput @@ -58,116 +58,85 @@ def add_new_req( remote_block_ids=kv_transfer_params.remote_block_ids, remote_engine_id=kv_transfer_params.remote_engine_id) +class NixlConnectorWorker: -class NixlConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): - self.vllm_config = vllm_config + def __init__(self): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") logger.info("Initializing NIXL wrapper") - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) - self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer - - self.num_layers = None - self.num_blocks = None - self.num_heads = None - self.block_len = None - self.kv_caches = None - self.kv_caches_base_addr = {} - self.kv_cache_shape = {} - - self._registered_descs = [] - self._remote_agents = {} - self.engine_id = engine_id - self.rank = rank - self.src_xfer_side_handles = {} - self.dst_xfer_side_handles = defaultdict(dict) - self.dst_num_blocks = {} - - # req_ids that need to start loading. - self._reqs_to_load: dict[str, Request] = {} + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Metadata. + self.engine_id = "123" + self.rank = 1 + + # KV Caches and nixl tracking data. + self.num_layers: int = 0 + self.num_layers: int = 0 + self.num_heads: int = 0 + self.kv_caches: tuple[torch.Tensor, torch.Tensor] = ( + torch.empty(), torch.empty()) + self.kv_caches_base_addr: dict[str, int] = {} + self._registered_descs: list[any] = [] + + # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers = defaultdict(list) - - def get_agent_metadata(self): - return self.nixl_wrapper.get_agent_metadata() - - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> int: - """ - For remote prefill, we allocate for all tokens. - """ - # Allocate space for external tokens. - if request.do_remote_prefill: - return len(request.prompt_token_ids) - num_computed_tokens - - def update_state_after_alloc(self, request: "Request", - num_external_tokens: int): - if request.do_remote_decode: - pass - if request.do_remote_prefill and num_external_tokens > 0: - self._reqs_to_load[request.request_id] = request - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: + self._recving_transfers = defaultdict(list[any]) + + def register_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): """ - Build the connector metadata for this step. - - This function should NOT modify fields in the scheduler_output. - Also, calling this function will reset the state of the connector. - - Args: - scheduler_output (SchedulerOutput): the scheduler output object. + Register the KV Cache data in nixl. 
""" + _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape + self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() + logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + self.num_layers = len(kv_caches) + self.num_blocks = num_blocks + self.num_heads = num_heads + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + for key_cache, value_cache in kv_caches: + base_addr = key_cache.data_ptr() + region_len = 2 * num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) - meta = NixlConnectorMetadata() - - for new_req in scheduler_output.scheduled_new_reqs: - req = self._reqs_to_load.pop(new_req.req_id, None) - if req is not None: - meta.add_new_req( - new_req.req_id, - new_req.block_ids, - req.kv_transfer_params, - ) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr - return meta + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + self._registered_descs.append(descs) + self.nixl_wrapper.register_kv_caches(kv_caches) def get_finished(self) -> tuple[set[str], set[str]]: - """Get the finished requests and the requests that are still in progress.""" + """Get requests that are done sending or recving.""" done_sending = self._get_new_notifs() - done_recving = self._update_transfers(self._recving_transfers) + done_recving = self._pop_done_transfers(self._recving_transfers) return done_sending, done_recving def _get_new_notifs(self) -> set[str]: - """ - Get set of req_id that got a new notification. - """ + """Get req_ids which got a remote xfer message.""" + notified_req_ids: set[str] = set() # TODO: handle the TP case (N notifies for TP=N). - # vllm/worker/worker_base.py L476 in DynamoPR. + # See: vllm/worker/worker_base.py L476 in DynamoPR. for req_ids in self.nixl_wrapper.get_new_notifs().values(): for req_id in req_ids: assert req_id not in notified_req_ids notified_req_ids.add(req_id) return notified_req_ids - def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: + def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: """ - Update the list of transfers that are running by checking - the nixl_xfer_state and removing those in "DONE" state. + Pop completed xfers by checking for DONE state. Args: - transfers: dictionary of req_id -> list[running_xfer] + transfers: dict of req_id -> list[running_xfer] Returns: set of req_ids that have all done xfers @@ -192,6 +161,22 @@ def _update_transfers(self, transfers: dict[str, list[str]]) -> set[str]: transfers[req_id] = running_reqs return done_req_ids + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). 
+ """ + for req_id, meta in metadata.requests.items(): + # NOTE: this is non-blocking + self.read_blocks( + local_block_ids=meta.block_ids, + # TODO: support staging once we do heterogenous TP + staging_block_ids=meta.block_ids, + remote_block_ids=meta.remote_block_ids, + dst_engine_id=meta.remote_engine_id, + request_id=req_id, + ) + def read_blocks( self, local_block_ids: list[int], @@ -269,42 +254,112 @@ def read_blocks( # cache[staging_range[0]:staging_range[1] + 1], # tp_multiplier, "read") - def shutdown(self): - for descs_list in self._registered_descs: - self.nixl_wrapper.deregister_memory(descs_list) - for agent_names in self._remote_agents.values(): - for agent_name in agent_names: - self.nixl_wrapper.remove_remote_agent(agent_name) - for src_xfer_side_handle in self.src_xfer_side_handles.values(): - self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) - for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): - for dst_xfer_side_handle in dst_xfer_side_handles.values(): - self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) - - def register_kv_caches(self, kv_caches: list[torch.Tensor]): - _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape - self.block_len = block_size * num_heads * head_dim * kv_caches[ - 0].element_size() - logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) - self.num_layers = len(kv_caches) - self.num_blocks = num_blocks - self.num_heads = num_heads - self.kv_caches = kv_caches - kv_caches_base_addr = [] - caches_data = [] - for key_cache, value_cache in kv_caches: - base_addr = key_cache.data_ptr() - region_len = 2 * num_blocks * self.block_len - caches_data.append((base_addr, region_len, self.rank, "")) - kv_caches_base_addr.append( - (key_cache.data_ptr(), value_cache.data_ptr())) +class NixlConnectorScheduler: - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config - descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") - logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) - self._registered_descs.append(descs) + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, Request] = {} + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int) -> int: + """ + For remote prefill, we allocate for all tokens. + """ + # Allocate space for external tokens. + if request.do_remote_prefill: + return len(request.prompt_token_ids) - num_computed_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + if request.do_remote_decode: + pass + if request.do_remote_prefill and num_external_tokens > 0: + self._reqs_need_recv[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for new_req in scheduler_output.scheduled_new_reqs: + req = self._reqs_need_recv.pop(new_req.req_id, None) + if req is not None: + meta.add_new_req( + request_id=new_req.req_id, + local_block_ids=new_req.block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Invariant: only new requests should need load + # and we should get all new requests each step(). 
+ assert len(self._reqs_need_recv) == 0 + return meta + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = NixlConnectorScheduler(vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker() + + ############################################################ + # Scheduler Side Methods + ############################################################ + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int) -> int: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta( + scheduler_output) + + ############################################################ + # Worker Side Methods + ############################################################ + + def register_kv_caches(self, kv_caches: torch.Tensor): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished requests and the requests that are still in progress.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + # def shutdown(self): + # for descs_list in self._registered_descs: + # self.nixl_wrapper.deregister_memory(descs_list) + # for agent_names in self._remote_agents.values(): + # for agent_name in agent_names: + # self.nixl_wrapper.remove_remote_agent(agent_name) + # for src_xfer_side_handle in self.src_xfer_side_handles.values(): + # self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + # for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + # for dst_xfer_side_handle in dst_xfer_side_handles.values(): + # self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) def _get_ranges(self, block_ids): # This function should return a list of ranges of block ids that are contiguous @@ -364,38 +419,8 @@ def _get_block_descs_ids(self, def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: - """Start loading the KV cache from the connector buffer to vLLM's - paged KV buffer. - - Args: - forward_context (ForwardContext): the forward context. - **kwargs: additional arguments for the load operation - - Note: - The number of elements in kv_caches and layer_names should be - the same. 
- """ - # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() - assert isinstance(metadata, NixlConnectorMetadata) - - if metadata is None: - logger.warning( - "In connector.start_load_kv, but the connector metadata is None" - ) - return - - for req_id, meta in metadata.requests.items(): - # NOTE: this is non-blocking - self.read_blocks( - local_block_ids=meta.block_ids, - # TODO: support staging once we do heterogenous TP - staging_block_ids=meta.block_ids, - remote_block_ids=meta.remote_block_ids, - dst_engine_id=meta.remote_engine_id, - request_id=req_id, - ) + assert self.connector_worker is not None + self.connector_worker.start_load_kv() def wait_for_layer_load(self, layer_name: str) -> None: """NixlConnector does not do layerwise saving.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3bf329549c49..c79ecfdfab5d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1730,6 +1730,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 3a29f8d0deef..deb801370f08 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,9 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group, + get_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -200,6 +202,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: with context: self.model_runner.initialize_kv_cache(kv_cache_config) + def compile_or_warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, From 554b27d3e920ffbcce000946681f8c8e38ef64ac Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 19:59:01 +0000 Subject: [PATCH 028/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 396 +++++++++++------- 1 file changed, 244 insertions(+), 152 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 59964d8f74e8..fa4931f989aa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -3,6 +3,7 @@ from collections import defaultdict from typing import TYPE_CHECKING +import msgspec import torch from vllm.config import VllmConfig @@ -28,6 +29,18 @@ NixlWrapper = None +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True): + engine_id: str + agent_metadata: list[bytes] + # Base addr for each rank, each layer for KVs + kv_caches_base_addr: list[list[tuple[int, int]]] + num_blocks: int + + class ReqMeta: def __init__( @@ -58,7 +71,131 @@ def add_new_req( remote_block_ids=kv_transfer_params.remote_block_ids, remote_engine_id=kv_transfer_params.remote_engine_id) + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = NixlConnectorScheduler(vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker() + + ############################################################ + # Scheduler Side Methods + ############################################################ + def get_num_new_matched_tokens(self, request: "Request", + num_computed_tokens: int) -> int: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + ############################################################ + # Worker Side Methods + ############################################################ + + def register_kv_caches(self, kv_caches: torch.Tensor): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + # def shutdown(self): + # for descs_list in self._registered_descs: + # self.nixl_wrapper.deregister_memory(descs_list) + # for agent_names in self._remote_agents.values(): + # for agent_name in agent_names: + # self.nixl_wrapper.remove_remote_agent(agent_name) + # for src_xfer_side_handle in self.src_xfer_side_handles.values(): + # self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + # for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + # for dst_xfer_side_handle in dst_xfer_side_handles.values(): + # self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + self.connector_worker.start_load_kv() + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + return + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + return + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. 
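The NixlAgentMetadata struct introduced above is a msgspec.Struct, so the metadata that add_remote_agent (defined further down in the series) consumes can be serialized compactly, presumably for exchange between engines. A sketch of round-tripping it, with all field values invented:

    import msgspec

    meta = NixlAgentMetadata(
        engine_id="engine-0",
        agent_metadata=[b"nixl-agent-bytes"],
        kv_caches_base_addr=[[(0x7f0000000000, 0x7f0000100000)]],
        num_blocks=1024,
    )
    buf = msgspec.msgpack.encode(meta)
    decoded = msgspec.msgpack.decode(buf, type=NixlAgentMetadata)
    assert decoded.engine_id == "engine-0"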
+ self._reqs_need_recv: dict[str, Request] = {} + + def get_num_new_matched_tokens(self, request: "Request", + num_computed_tokens: int) -> int: + """For remote prefill, allocate for all tokens.""" + if request.do_remote_prefill: + return len(request.prompt_token_ids) - num_computed_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + if request.do_remote_decode: + pass + if request.do_remote_prefill and num_external_tokens > 0: + self._reqs_need_recv[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for new_req in scheduler_output.scheduled_new_reqs: + req = self._reqs_need_recv.pop(new_req.req_id, None) + if req is not None: + meta.add_new_req( + request_id=new_req.req_id, + local_block_ids=new_req.block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Invariant: only new requests should need load + # and we should get all new requests each step(). + assert len(self._reqs_need_recv) == 0 + return meta + + class NixlConnectorWorker: + """Implementation of Worker side methods""" def __init__(self): if NixlWrapper is None: @@ -68,30 +205,46 @@ def __init__(self): # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> list[agent_names] (1 per rank). + self._remote_agents: dict[str, list[str]] = {} # Metadata. self.engine_id = "123" self.rank = 1 - + # KV Caches and nixl tracking data. self.num_layers: int = 0 self.num_layers: int = 0 self.num_heads: int = 0 - self.kv_caches: tuple[torch.Tensor, torch.Tensor] = ( - torch.empty(), torch.empty()) - self.kv_caches_base_addr: dict[str, int] = {} + self.kv_caches: tuple[torch.Tensor, + torch.Tensor] = (torch.empty(), torch.empty()) + + # Map of engine_id -> kv_caches_base_addr + # For Local: base addr for *this* rank, each layer for K,V + # For Remote: base addr for *each* rank, each layer for K,V + # KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]], + # list[list[tuple[int, int]]]] + self.kv_caches_base_addr: dict[str, any] = {} + + # Map of tp_mult -> nixl_prepped_dlist_handle (int). + self.src_xfer_side_handles: dict[int, int] = {} + # Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: defaultdict[str, + dict[int, + int]] = defaultdict(dict) + # Map of engine_id -> num_blocks. + self.dst_num_blocks: dict[str, int] = {} self._registered_descs: list[any] = [] # In progress transfers. # [req_id -> list[handle]] self._recving_transfers = defaultdict(list[any]) - + def register_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): - """ - Register the KV Cache data in nixl. 
- """ + """Register the KV Cache data in nixl.""" _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape - self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() + self.block_len = block_size * num_heads * head_dim * kv_caches[ + 0].element_size() logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) self.num_layers = len(kv_caches) self.num_blocks = num_blocks @@ -103,7 +256,8 @@ def register_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): base_addr = key_cache.data_ptr() region_len = 2 * num_blocks * self.block_len caches_data.append((base_addr, region_len, self.rank, "")) - kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) + kv_caches_base_addr.append( + (key_cache.data_ptr(), value_cache.data_ptr())) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr @@ -113,6 +267,71 @@ def register_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): self._registered_descs.append(descs) self.nixl_wrapper.register_kv_caches(kv_caches) + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + engine_id = nixl_agent_meta.engine_id + num_blocks = nixl_agent_meta.num_blocks + + agent_names = [] + for agent_meta in nixl_agent_meta.agent_metadata: + agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) + agent_names.append(agent_name) + + self._remote_agents[engine_id] = agent_names + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + + # NOTE: once we support heterogeneous TP, we will need maintain the src + # for each TP multiplier. + # NOTE(rob): Dynamo only supports D TP size > P TP size. + # https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501 + # Create descs and xfer side handles. + tp_multiplier = 1 + dst_block_len = self.block_len // tp_multiplier + if tp_multiplier not in self.src_xfer_side_handles: + # Create descs and xfer side handles. + blocks_data = [] + for layer_id in range(self.num_layers): + # Both K and V. + for base_addr in self.kv_caches_base_addr[ + self.engine_id][layer_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for i in range(tp_multiplier): + tp_multiplier_offset = i * dst_block_len + blocks_data.append((base_addr + block_offset + + tp_multiplier_offset, + dst_block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, + self.rank * tp_multiplier + i) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handles[tp_multiplier] = ( + self.nixl_wrapper.prep_xfer_dlist("", descs)) + + # create dst xfer side handles + self.dst_num_blocks[engine_id] = num_blocks + for i in range(tp_multiplier): + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[engine_id][ + self.rank * tp_multiplier + i][layer_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append( + (base_addr + block_offset, dst_block_len, + self.rank * tp_multiplier + i)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, + self.rank * tp_multiplier + i) + # Register with NIXL. 
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id][i] = ( + self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][self.rank * tp_multiplier + + i], descs)) + def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" done_sending = self._get_new_notifs() @@ -134,10 +353,8 @@ def _get_new_notifs(self) -> set[str]: def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: """ Pop completed xfers by checking for DONE state. - Args: transfers: dict of req_id -> list[running_xfer] - Returns: set of req_ids that have all done xfers """ @@ -168,16 +385,16 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.requests.items(): # NOTE: this is non-blocking - self.read_blocks( + self._read_blocks( local_block_ids=meta.block_ids, - # TODO: support staging once we do heterogenous TP + # TODO: support staging once we do heterogeneous TP staging_block_ids=meta.block_ids, remote_block_ids=meta.remote_block_ids, dst_engine_id=meta.remote_engine_id, request_id=req_id, ) - def read_blocks( + def _read_blocks( self, local_block_ids: list[int], staging_block_ids: list[int], @@ -188,11 +405,11 @@ def read_blocks( # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the - # read trxn async easily). If we want to make "READ" happen cleanly, then - # we will need to have the staging blocks on the remote side. + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. # NOTE(rob): according to nvidia the staging blocks are used to - # saturate IB with heterogenous TP sizes. We should remove the staging + # saturate IB with heterogeneous TP sizes. We should remove the staging # blocks until we are ready. # NOTE(rob): we could potentially do the rearranging during the load_kv! @@ -205,7 +422,7 @@ def read_blocks( # TODO(rob): understand ranges code. local_ranges = self._get_ranges(local_block_ids) staging_ranges = self._get_ranges(staging_block_ids) - local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges( + _, staging_rearranging_ranges = self._get_same_length_ranges( local_ranges, staging_ranges) # TODO(rob): understand tp multiplier. @@ -242,131 +459,24 @@ def read_blocks( self._recving_transfers[request_id].append(handle) # NOTE(rob): this is actually pretty serious problem. - # We need to figure out if we can put the staging blocks on the P worker side. + # We need to figure out if we can put the staging blocks on the P worker side. # noqa: E501 # The staging blocks need to be on the side that sends. 
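# A minimal standalone sketch of the completion tracking described above,
# assuming simplified types: each in-flight READ handle is stored under its
# request_id, and a request is only reported finished once every one of its
# transfers reaches the DONE state. `check_state` stands in for the NIXL
# wrapper's per-handle state query.
from collections import defaultdict
from typing import Callable


def pop_done_transfers(
        transfers: defaultdict[str, list[str]],
        check_state: Callable[[str], str]) -> set[str]:
    done_req_ids: set[str] = set()
    for req_id, handles in list(transfers.items()):
        if handles and all(check_state(h) == "DONE" for h in handles):
            done_req_ids.add(req_id)
            del transfers[req_id]
    return done_req_ids


# Example: req-a has both transfers done, req-b is still in flight.
_states = {"h1": "DONE", "h2": "DONE", "h3": "PROC"}
_recving: defaultdict[str, list[str]] = defaultdict(list)
_recving["req-a"].extend(["h1", "h2"])
_recving["req-b"].append("h3")
assert pop_done_transfers(_recving, _states.__getitem__) == {"req-a"}
assert "req-b" in _recving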
- # for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): - # logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", + # for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): # noqa: E501 + # logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", # noqa: E501 # self.kv_caches[0].shape, local_range, staging_range) # for kv_cache in self.kv_caches: # for cache in kv_cache: - # rearrange_tensors(cache[local_range[0]:local_range[1] + 1], - # cache[staging_range[0]:staging_range[1] + 1], + # rearrange_tensors(cache[local_range[0]:local_range[1] + 1], # noqa: E501 + # cache[staging_range[0]:staging_range[1] + 1], # noqa: E501 # tp_multiplier, "read") -class NixlConnectorScheduler: - - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - - # Requests that need to start recv. - # New requests are added by update_state_after_alloc in - # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[str, Request] = {} - - def get_num_new_matched_tokens( - self, request: "Request", num_computed_tokens: int) -> int: - """ - For remote prefill, we allocate for all tokens. - """ - # Allocate space for external tokens. - if request.do_remote_prefill: - return len(request.prompt_token_ids) - num_computed_tokens - - def update_state_after_alloc(self, request: "Request", - num_external_tokens: int): - if request.do_remote_decode: - pass - if request.do_remote_prefill and num_external_tokens > 0: - self._reqs_need_recv[request.request_id] = request - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - meta = NixlConnectorMetadata() - - # Loop through scheduled reqs and convert to ReqMeta. - for new_req in scheduler_output.scheduled_new_reqs: - req = self._reqs_need_recv.pop(new_req.req_id, None) - if req is not None: - meta.add_new_req( - request_id=new_req.req_id, - local_block_ids=new_req.block_ids, - kv_transfer_params=req.kv_transfer_params, - ) - - # Invariant: only new requests should need load - # and we should get all new requests each step(). 
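# A minimal standalone sketch of the stash-and-drain pattern shown above,
# with plain dicts standing in for the vLLM request types: requests flagged
# for remote prefill are stashed by update_state_after_alloc, and the very
# next build_connector_meta call must drain all of them, which is what the
# closing assert checks.
reqs_need_recv: dict[str, dict] = {}


def update_state_after_alloc(req_id: str, num_external_tokens: int) -> None:
    # Stash the request so the next metadata build can tell the worker
    # which blocks to pull from the remote engine.
    if num_external_tokens > 0:
        reqs_need_recv[req_id] = {"req_id": req_id}


def build_connector_meta(scheduled_new_req_ids: list[str]) -> list[dict]:
    meta = [
        reqs_need_recv.pop(req_id) for req_id in scheduled_new_req_ids
        if req_id in reqs_need_recv
    ]
    # Invariant: only new requests need a load, and all of them were
    # scheduled this step, so nothing should be left in the stash.
    assert not reqs_need_recv
    return meta


update_state_after_alloc("req-0", num_external_tokens=8)
assert build_connector_meta(["req-0", "req-1"]) == [{"req_id": "req-0"}]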
- assert len(self._reqs_need_recv) == 0 - return meta - - -class NixlConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - - if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = NixlConnectorScheduler(vllm_config) - self.connector_worker = None - elif role == KVConnectorRole.WORKER: - self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker() - - ############################################################ - # Scheduler Side Methods - ############################################################ - def get_num_new_matched_tokens( - self, request: "Request", num_computed_tokens: int) -> int: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - num_external_tokens: int): - assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta( - scheduler_output) - - ############################################################ - # Worker Side Methods - ############################################################ - - def register_kv_caches(self, kv_caches: torch.Tensor): - assert self.connector_worker is not None - self.connector_worker.register_kv_caches(kv_caches) - - def get_finished(self) -> tuple[set[str], set[str]]: - """Get the finished requests and the requests that are still in progress.""" - assert self.connector_worker is not None - return self.connector_worker.get_finished() - - # def shutdown(self): - # for descs_list in self._registered_descs: - # self.nixl_wrapper.deregister_memory(descs_list) - # for agent_names in self._remote_agents.values(): - # for agent_name in agent_names: - # self.nixl_wrapper.remove_remote_agent(agent_name) - # for src_xfer_side_handle in self.src_xfer_side_handles.values(): - # self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) - # for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): - # for dst_xfer_side_handle in dst_xfer_side_handles.values(): - # self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) - def _get_ranges(self, block_ids): - # This function should return a list of ranges of block ids that are contiguous - # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] + # This function should return a list of ranges of block ids that are contiguous # noqa: E501 + # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] # noqa: E501 # The ranges are sorted by the starting block id # The function should also make sure that the block ids are contiguous - # If the block ids are not contiguous, the function should raise an error + # If the block ids are not contiguous, the function should raise an error # noqa: E501 ranges = [] for i in range(len(block_ids)): if i == 0 or block_ids[i] != block_ids[i - 1] + 1: @@ -416,21 +526,3 @@ def _get_block_descs_ids(self, descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) return descs_ids - - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - assert self.connector_worker is not None - self.connector_worker.start_load_kv() - - def wait_for_layer_load(self, 
layer_name: str) -> None: - """NixlConnector does not do layerwise saving.""" - return - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """NixlConnector does not save explicitly.""" - return - - def wait_for_save(self): - """NixlConnector does not save explicitly.""" - return From 1aea5ba45911d9b84589550087118cb40b8465ea Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 20:27:36 +0000 Subject: [PATCH 029/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 48 +++++++++++-------- vllm/v1/worker/gpu_worker.py | 5 +- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fa4931f989aa..3103cecd4fe3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -76,12 +76,15 @@ class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + self.engine_id = uuid.uuid4() + if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = NixlConnectorScheduler(vllm_config) + self.connector_scheduler = NixlConnectorScheduler( + vllm_config, self.engine_id) self.connector_worker = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker() + self.connector_worker = NixlConnectorWorker(self.engine_id) ############################################################ # Scheduler Side Methods @@ -118,17 +121,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished() - # def shutdown(self): - # for descs_list in self._registered_descs: - # self.nixl_wrapper.deregister_memory(descs_list) - # for agent_names in self._remote_agents.values(): - # for agent_name in agent_names: - # self.nixl_wrapper.remove_remote_agent(agent_name) - # for src_xfer_side_handle in self.src_xfer_side_handles.values(): - # self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) - # for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): - # for dst_xfer_side_handle in dst_xfer_side_handles.values(): - # self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None @@ -151,8 +143,9 @@ def wait_for_save(self): class NixlConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + self.engine_id = engine_id # Requests that need to start recv. # New requests are added by update_state_after_alloc in @@ -197,7 +190,7 @@ def build_connector_meta( class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self): + def __init__(self, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -209,7 +202,7 @@ def __init__(self): self._remote_agents: dict[str, list[str]] = {} # Metadata. - self.engine_id = "123" + self.engine_id = engine_id self.rank = 1 # KV Caches and nixl tracking data. 
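# A standalone check of the flat descriptor indexing used by
# _get_block_descs_ids above. It assumes descriptors are registered layer by
# layer, each layer contributing its K blocks followed by its V blocks, which
# matches the order blocks_data is built in add_remote_agent.
def block_desc_id(layer_id: int, is_value: int, block_id: int,
                  num_blocks: int) -> int:
    # Flat id of one block in one layer's K (is_value=0) or V (is_value=1)
    # region.
    return layer_id * 2 * num_blocks + is_value * num_blocks + block_id


assert block_desc_id(layer_id=0, is_value=0, block_id=3, num_blocks=10) == 3
assert block_desc_id(layer_id=0, is_value=1, block_id=3, num_blocks=10) == 13
assert block_desc_id(layer_id=1, is_value=0, block_id=0, num_blocks=10) == 20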
@@ -240,7 +233,8 @@ def __init__(self): # [req_id -> list[handle]] self._recving_transfers = defaultdict(list[any]) - def register_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): + def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, + torch.Tensor]]): """Register the KV Cache data in nixl.""" _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape self.block_len = block_size * num_heads * head_dim * kv_caches[ @@ -425,9 +419,8 @@ def _read_blocks( _, staging_rearranging_ranges = self._get_same_length_ranges( local_ranges, staging_ranges) - # TODO(rob): understand tp multiplier. - tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[ - self.engine_id] + # TODO: support TP multipliers. + tp_multiplier = 1 remote_block_descs_ids = self._get_block_descs_ids( dst_engine_id, "all", remote_block_ids) local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] @@ -446,7 +439,7 @@ def _read_blocks( dst_engine_id][i] # NOTE(rob): we use the request_id as the notify msg, so we - # must use the aem request_id in both the p and d workers. + # must use the same request_id in both the p and d workers. handle = self.nixl_wrapper.make_prepped_xfer( "READ", local_xfer_side_handle, @@ -526,3 +519,16 @@ def _get_block_descs_ids(self, descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) return descs_ids + + def _shutdown(self): + """Shutdown all the NIXL related items.""" + for descs_list in self._registered_descs: + self.nixl_wrapper.deregister_memory(descs_list) + for agent_names in self._remote_agents.values(): + for agent_name in agent_names: + self.nixl_wrapper.remove_remote_agent(agent_name) + for src_xfer_side_handle in self.src_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index deb801370f08..3a29f8d0deef 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,9 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group, - get_kv_transfer_group) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -202,7 +200,6 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: with context: self.model_runner.initialize_kv_cache(kv_cache_config) - def compile_or_warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, From e0c112b45cbcda929bae16dc32bd09769490cf42 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 20:28:36 +0000 Subject: [PATCH 030/119] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 6d1d7b912f7a..a6ec69c25d74 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -62,7 +62,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._vllm_config = vllm_config self._role = role - def init_with_kv_caches(self, kv_caches: tuple[torch.Tensor, torch.Tensor]): + def register_kv_caches( + self, + kv_caches: list[tuple[torch.Tensor, torch.Tensor]] + ): """ Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). From c7717c1172eb1813042e7c90bb6b641ea344cd1c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 22:09:11 +0000 Subject: [PATCH 031/119] update Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3103cecd4fe3..4760d248bf22 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -519,16 +519,3 @@ def _get_block_descs_ids(self, descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) return descs_ids - - def _shutdown(self): - """Shutdown all the NIXL related items.""" - for descs_list in self._registered_descs: - self.nixl_wrapper.deregister_memory(descs_list) - for agent_names in self._remote_agents.values(): - for agent_name in agent_names: - self.nixl_wrapper.remove_remote_agent(agent_name) - for src_xfer_side_handle in self.src_xfer_side_handles.values(): - self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) - for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): - for dst_xfer_side_handle in dst_xfer_side_handles.values(): - self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) From e0af1db8562eb2f5cde95b78eaebf823cab3fd7a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 22:09:55 +0000 Subject: [PATCH 032/119] remove Signed-off-by: rshaw@neuralmagic.com --- .../configs/lmcache-decoder-config.yaml | 13 -- .../configs/lmcache-prefiller-config.yaml | 13 -- examples/other/LMCache/disagg-example.sh | 71 ------- examples/other/LMCache/disagg_proxy_server.py | 193 ------------------ 4 files changed, 290 deletions(-) delete mode 100644 examples/other/LMCache/configs/lmcache-decoder-config.yaml delete mode 100644 examples/other/LMCache/configs/lmcache-prefiller-config.yaml delete mode 100644 examples/other/LMCache/disagg-example.sh delete mode 100644 examples/other/LMCache/disagg_proxy_server.py diff --git a/examples/other/LMCache/configs/lmcache-decoder-config.yaml b/examples/other/LMCache/configs/lmcache-decoder-config.yaml deleted file mode 100644 index c3f5a0ae69c0..000000000000 --- a/examples/other/LMCache/configs/lmcache-decoder-config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -local_cpu: False -max_local_cpu_size: 0 -#local_disk: -max_local_disk_size: 0 -remote_serde: NULL - -enable_nixl: True -nixl_role: "receiver" -nixl_peer_host: "localhost" -nixl_peer_port: 55555 -nixl_buffer_size: 1073741824 # 1GB -nixl_buffer_device: "cuda" -nixl_enable_gc: True diff --git a/examples/other/LMCache/configs/lmcache-prefiller-config.yaml b/examples/other/LMCache/configs/lmcache-prefiller-config.yaml deleted file mode 100644 index 8b0e82958a64..000000000000 --- a/examples/other/LMCache/configs/lmcache-prefiller-config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -local_cpu: False -max_local_cpu_size: 0 -#local_disk: 
-max_local_disk_size: 0 -remote_serde: NULL - -enable_nixl: True -nixl_role: "sender" -nixl_peer_host: "localhost" -nixl_peer_port: 55555 -nixl_buffer_size: 1073741824 # 1GB -nixl_buffer_device: "cuda" -nixl_enable_gc: True diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh deleted file mode 100644 index 8e52396c5eec..000000000000 --- a/examples/other/LMCache/disagg-example.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -if [[ $# -lt 1 ]]; then - echo "Usage: $0 [model]" - exit 1 -fi - -if [[ $# -eq 1 ]]; then - echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" - MODEL="meta-llama/Llama-3.1-8B-Instruct" -else - echo "Using model: $2" - MODEL=$2 -fi - - -if [[ $1 == "prefill" ]]; then - # Prefiller listens on port 8100 - prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml - - UCX_TLS=cuda_ipc,cuda_copy,tcp \ - LMCACHE_CONFIG_FILE=$prefill_config_file \ - LMCACHE_USE_EXPERIMENTAL=True \ - VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - vllm serve $MODEL \ - --port 8100 \ - --enforce-eager \ - --kv-transfer-config \ - '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' - - # Potential Env vars and cmdline options - # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG - # --enforce-eager -- Enforce eager mode - -elif [[ $1 == "decode" ]]; then - # Decoder listens on port 8200 - decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml - - UCX_TLS=cuda_ipc,cuda_copy,tcp \ - LMCACHE_CONFIG_FILE=$decode_config_file \ - LMCACHE_USE_EXPERIMENTAL=True \ - VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - vllm serve $MODEL \ - --port 8200 \ - --enforce-eager \ - --kv-transfer-config \ - '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' - - # Potential Env vars and cmdline options - # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG - # --enforce-eager -- Enforce eager mode - -elif [[ $1 == "proxy" ]]; then - # Proxy listens on port 9000 - python3 $SCRIPT_DIR/disagg_proxy_server.py \ - --host localhost \ - --port 9000 \ - --prefiller-host localhost \ - --prefiller-port 8100 \ - --decoder-host localhost \ - --decoder-port 8200 - -else - echo "Invalid role: $1" - echo "Should be either prefill, decode, or proxy" - exit 1 -fi diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py deleted file mode 100644 index 2639409a1522..000000000000 --- a/examples/other/LMCache/disagg_proxy_server.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import os -import time -from contextlib import asynccontextmanager - -import httpx -import numpy as np -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """ - Lifespan context manager to handle startup and shutdown events. 
- """ - # Startup: Initialize clients - prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' - decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' - - app.state.prefill_client = httpx.AsyncClient(timeout=None, - base_url=prefiller_base_url) - app.state.decode_client = httpx.AsyncClient(timeout=None, - base_url=decoder_base_url) - - yield - - # Shutdown: Close clients - await app.state.prefill_client.aclose() - await app.state.decode_client.aclose() - - -# Update FastAPI app initialization to use lifespan -app = FastAPI(lifespan=lifespan) - - -class StatsCalculator: - - def __init__(self): - self._stats = [] - self._last_log_time = time.time() - - def add(self, value): - self._stats.append(value) - if time.time() - self._last_log_time > 5: - self._log_stats() - self._last_log_time = time.time() - - def _log_stats(self): - # Print average, median, and 99th percentile - np_arr = np.array(self._stats) - output_str = f"\nNum requests: {len(self._stats)}" + \ - "\nPrefill node TTFT stats:" + \ - f"\n - Average (ms): {np.mean(np_arr)}" + \ - f"\n - Median (ms): {np.median(np_arr)}" + \ - f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" - print("===============================", output_str, - "===============================") - - -stats_calculator = StatsCalculator() -counter = 0 - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--prefiller-host", type=str, default="localhost") - parser.add_argument("--prefiller-port", type=int, default=8100) - parser.add_argument("--decoder-host", type=str, default="localhost") - parser.add_argument("--decoder-port", type=int, default=8200) - args = parser.parse_args() - return args - - -# Initialize variables to hold the persistent clients -app.state.prefill_client = None -app.state.decode_client = None - - -async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict): - """ - Send a request to a service using a persistent client. - """ - req_data = req_data.copy() - req_data['do_remote_decode'] = True - req_data["stream"] = False - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - response = await client.post(endpoint, json=req_data, headers=headers) - response.raise_for_status() - - return response - - -async def stream_service_response(client: httpx.AsyncClient, endpoint: str, - req_data: dict): - """ - Asynchronously stream the response from a service using a persistent client. 
- """ - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - req_data['do_remote_prefill'] = True - async with client.stream("POST", endpoint, json=req_data, - headers=headers) as response: - response.raise_for_status() - async for chunk in response.aiter_bytes(): - yield chunk - - -@app.post("/v1/completions") -async def handle_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - - # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, "/completions", - req_data) - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/completions", - req_data): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - " - completions endpoint") - print(e) - print("".join(traceback.format_exception(*exc_info))) - raise - - -@app.post("/v1/chat/completions") -async def handle_chat_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - - # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, - "/chat/completions", req_data) - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/chat/completions", - req_data): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server " - " - chat completions endpoint") - print(e) - print("".join(traceback.format_exception(*exc_info))) - raise - - -if __name__ == '__main__': - global global_args - global_args = parse_args() - - import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) From 95334718feb5711eb540ba217ad84aef38260376 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 22:12:04 +0000 Subject: [PATCH 033/119] updated Signed-off-by: rshaw@neuralmagic.com --- examples/disagg_proxy_server.py | 193 ++++++++++++++++++++++++++++++++ examples/proxy_example.sh | 71 ++++++++++++ 2 files changed, 264 insertions(+) create mode 100644 examples/disagg_proxy_server.py create mode 100644 examples/proxy_example.sh diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py new file mode 100644 index 000000000000..2639409a1522 --- /dev/null +++ b/examples/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['do_remote_decode'] = True + req_data["stream"] = False + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + req_data['do_remote_prefill'] = True + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/proxy_example.sh b/examples/proxy_example.sh new file mode 100644 index 000000000000..8e52396c5eec --- /dev/null +++ b/examples/proxy_example.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefill" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + vllm serve $MODEL \ + --port 8100 \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + # Potential Env vars and cmdline options + # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG + # --enforce-eager -- Enforce eager mode + +elif [[ $1 == 
"decode" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + vllm serve $MODEL \ + --port 8200 \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + # Potential Env vars and cmdline options + # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG + # --enforce-eager -- Enforce eager mode + +elif [[ $1 == "proxy" ]]; then + # Proxy listens on port 9000 + python3 $SCRIPT_DIR/disagg_proxy_server.py \ + --host localhost \ + --port 9000 \ + --prefiller-host localhost \ + --prefiller-port 8100 \ + --decoder-host localhost \ + --decoder-port 8200 + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode, or proxy" + exit 1 +fi From 2eb068e46edd18c43a068cad12320cebc1712422 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 22:47:11 +0000 Subject: [PATCH 034/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 38 +++++++++++++++++-- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4760d248bf22..780a4ef0e07b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -5,6 +5,7 @@ import msgspec import torch +import zmq from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -75,7 +76,6 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - self.engine_id = uuid.uuid4() if role == KVConnectorRole.SCHEDULER: @@ -231,7 +231,7 @@ def __init__(self, engine_id: str): # In progress transfers. 
# [req_id -> list[handle]] - self._recving_transfers = defaultdict(list[any]) + self._recving_transfers = defaultdict(list[any]) def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, torch.Tensor]]): @@ -261,8 +261,38 @@ def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, self._registered_descs.append(descs) self.nixl_wrapper.register_kv_caches(kv_caches) + # THIS IS FOR DEBUG and INSECURE + import os + _ctx = zmq.Context() # type: ignore + _side_channel = _ctx.socket(zmq.PAIR) # type: ignore + NIXL_ROLE = os.getenv("NIXL_ROLE") + if NIXL_ROLE == "SENDER": + _side_channel.bind("tcp://localhost:5555") + _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore + metadata = NixlAgentMetadata( + self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.v_ + ) + encoder = msgspec.msgpack.Encoder() + _side_channel.send(encoder.encode(metadata)) + + elif NIXL_ROLE == "RECVER": + _side_channel.bind("tcp://localhost:5555") + _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata_bytes = _side_channel.recv() + metadata = decoder.decode(metadata_bytes) + self.add_remote_agent(metadata) + + else: + raise Exception("SET NIXL_ROLE to SENDER OR RECVER") + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): engine_id = nixl_agent_meta.engine_id + if engine_id in self._remote_agents: + return + num_blocks = nixl_agent_meta.num_blocks agent_names = [] @@ -274,8 +304,8 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - # NOTE: once we support heterogeneous TP, we will need maintain the src - # for each TP multiplier. + # NOTE: once we support heterogeneous TP, we will need maintain the + # src for each TP multiplier. # NOTE(rob): Dynamo only supports D TP size > P TP size. # https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501 # Create descs and xfer side handles. From 0f2b7e358de15bc8c02eebe77d39e393416e4158 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 23:21:46 +0000 Subject: [PATCH 035/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/base.py | 10 +++---- .../kv_connector/v1/nixl_connector.py | 29 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index a6ec69c25d74..e8ef2ddea802 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -61,14 +61,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role - - def register_kv_caches( - self, - kv_caches: list[tuple[torch.Tensor, torch.Tensor]] - ): + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). 
+ + Args: kv_caches: + dictionary of layer names, kv cache """ pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 780a4ef0e07b..c79796128126 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -203,14 +203,13 @@ def __init__(self, engine_id: str): # Metadata. self.engine_id = engine_id - self.rank = 1 + self.rank = 0 # KV Caches and nixl tracking data. self.num_layers: int = 0 self.num_layers: int = 0 self.num_heads: int = 0 - self.kv_caches: tuple[torch.Tensor, - torch.Tensor] = (torch.empty(), torch.empty()) + self.kv_caches: tuple[torch.Tensor, torch.Tensor] = None # Map of engine_id -> kv_caches_base_addr # For Local: base addr for *this* rank, each layer for K,V @@ -231,22 +230,28 @@ def __init__(self, engine_id: str): # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers = defaultdict(list[any]) + self._recving_transfers = defaultdict(list[any]) - def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, - torch.Tensor]]): + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape - self.block_len = block_size * num_heads * head_dim * kv_caches[ + + first_layer_name = next(iter(kv_caches)) + first_kv_cache = kv_caches[first_layer_name] + + # [2 (k and v), num_blocks, ...] + _, num_blocks, block_size, num_heads, head_dim = first_kv_cache.shape + self.block_len = block_size * num_heads * head_dim * first_kv_cache[ 0].element_size() - logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) self.num_layers = len(kv_caches) self.num_blocks = num_blocks self.num_heads = num_heads self.kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] - for key_cache, value_cache in kv_caches: + for layer_name in kv_caches: + kv_cache = kv_caches[layer_name] + key_cache, value_cache = kv_cache[0], kv_cache[1] base_addr = key_cache.data_ptr() region_len = 2 * num_blocks * self.block_len caches_data.append((base_addr, region_len, self.rank, "")) @@ -259,7 +264,6 @@ def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) self._registered_descs.append(descs) - self.nixl_wrapper.register_kv_caches(kv_caches) # THIS IS FOR DEBUG and INSECURE import os @@ -272,8 +276,7 @@ def register_kv_caches(self, kv_caches: list[tuple[torch.Tensor, metadata = NixlAgentMetadata( self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.v_ - ) + kv_caches_base_addr=self.v_) encoder = msgspec.msgpack.Encoder() _side_channel.send(encoder.encode(metadata)) From 6127cb82f1041877043b83967e67d481dd44e869 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 22 Apr 2025 23:23:17 +0000 Subject: [PATCH 036/119] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c79796128126..a5293937eaac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -274,9 +274,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _side_channel.bind("tcp://localhost:5555") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore metadata = NixlAgentMetadata( - self.engine_id, + engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.v_) + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + ) encoder = msgspec.msgpack.Encoder() _side_channel.send(encoder.encode(metadata)) From 568249e695fe185c27071add93db735b5099d174 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 23 Apr 2025 11:39:11 +0000 Subject: [PATCH 037/119] updated Signed-off-by: rshaw@neuralmagic.com --- examples/proxy_example.sh | 10 ++- .../kv_connector/v1/nixl_connector.py | 62 +++++++++++-------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/examples/proxy_example.sh b/examples/proxy_example.sh index 8e52396c5eec..029291ecaaee 100644 --- a/examples/proxy_example.sh +++ b/examples/proxy_example.sh @@ -17,8 +17,13 @@ fi if [[ $1 == "prefill" ]]; then - # Prefiller listens on port 8100 - prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + vllm serve Qwen/Qwen2.5-1.5B-Instruct \ + --port 8100 \ + --enforce-eager \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' UCX_TLS=cuda_ipc,cuda_copy,tcp \ LMCACHE_CONFIG_FILE=$prefill_config_file \ @@ -30,7 +35,6 @@ if [[ $1 == "prefill" ]]; then --enforce-eager \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' - # Potential Env vars and cmdline options # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG # --enforce-eager -- Enforce eager mode diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a5293937eaac..e905bc537789 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -36,9 +36,10 @@ class NixlAgentMetadata( # required for @cached_property. 
dict=True): engine_id: str - agent_metadata: list[bytes] - # Base addr for each rank, each layer for KVs - kv_caches_base_addr: list[list[tuple[int, int]]] + agent_metadata: bytes + # Base addr for each layer for KVs + # NOTE: we will need another list for TP>1 + kv_caches_base_addr: list[tuple[int, int]] num_blocks: int @@ -259,6 +260,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): (key_cache.data_ptr(), value_cache.data_ptr())) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") + print(f"{self.kv_caches_base_addr[self.engine_id][0]=}") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -271,7 +274,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _side_channel = _ctx.socket(zmq.PAIR) # type: ignore NIXL_ROLE = os.getenv("NIXL_ROLE") if NIXL_ROLE == "SENDER": - _side_channel.bind("tcp://localhost:5555") + _side_channel.connect("tcp://localhost:5555") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore metadata = NixlAgentMetadata( engine_id=self.engine_id, @@ -282,6 +285,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): encoder = msgspec.msgpack.Encoder() _side_channel.send(encoder.encode(metadata)) + logger.debug("WAITING ON RECV") + ack = _side_channel.recv() + logger.debug("GOT ACK %s", ack) + elif NIXL_ROLE == "RECVER": _side_channel.bind("tcp://localhost:5555") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore @@ -289,6 +296,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): metadata_bytes = _side_channel.recv() metadata = decoder.decode(metadata_bytes) self.add_remote_agent(metadata) + print("SENDING ACK") + _side_channel.send(b"ack") else: raise Exception("SET NIXL_ROLE to SENDER OR RECVER") @@ -301,9 +310,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): num_blocks = nixl_agent_meta.num_blocks agent_names = [] - for agent_meta in nixl_agent_meta.agent_metadata: - agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) - agent_names.append(agent_name) + agent_name = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + agent_names.append(agent_name) self._remote_agents[engine_id] = agent_names self.kv_caches_base_addr[ @@ -321,6 +330,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): blocks_data = [] for layer_id in range(self.num_layers): # Both K and V. + print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") + print(f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}") + print(f"{self.kv_caches_base_addr[self.engine_id][layer_id]=}") for base_addr in self.kv_caches_base_addr[ self.engine_id][layer_id]: for block_id in range(self.num_blocks): @@ -341,25 +353,23 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): # create dst xfer side handles self.dst_num_blocks[engine_id] = num_blocks - for i in range(tp_multiplier): - blocks_data = [] - for layer_id in range(self.num_layers): - for base_addr in self.kv_caches_base_addr[engine_id][ - self.rank * tp_multiplier + i][layer_id]: - for block_id in range(num_blocks): - block_offset = block_id * dst_block_len - blocks_data.append( - (base_addr + block_offset, dst_block_len, - self.rank * tp_multiplier + i)) - logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, - self.rank * tp_multiplier + i) - # Register with NIXL. 
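# A small sketch of the sizing arithmetic used by register_kv_caches above,
# with made-up example shapes: each layer's cache is unpacked as
# [2, num_blocks, block_size, num_heads, head_dim], block_len is the byte size
# of one K (or V) block for a single layer, and the registered region per
# layer spans both the K and the V blocks.
import torch

kv_caches = {
    "model.layers.0.self_attn.attn":
    torch.zeros(2, 4, 16, 8, 64, dtype=torch.float16),
}
first_kv_cache = next(iter(kv_caches.values()))
_, num_blocks, block_size, num_heads, head_dim = first_kv_cache.shape
block_len = block_size * num_heads * head_dim * first_kv_cache.element_size()
region_len = 2 * num_blocks * block_len
assert block_len == 16 * 8 * 64 * 2  # fp16 is 2 bytes per element
assert region_len == 2 * 4 * block_len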
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[engine_id][i] = ( - self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][self.rank * tp_multiplier + - i], descs)) + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[engine_id][layer_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append( + (base_addr + block_offset, dst_block_len, + self.rank * tp_multiplier)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, + self.rank * tp_multiplier + i) + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id][i] = ( + self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][self.rank * tp_multiplier + + i], descs)) def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" From ccb44eaee0c45a29e522c28d73940821d316a613 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 23 Apr 2025 12:45:39 +0000 Subject: [PATCH 038/119] seems to load properly Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index e8ef2ddea802..3fd8c3344e2e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -61,14 +61,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + + def register_kv_caches( + self, + kv_caches: dict[str, torch.Tensor] + ): """ Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). Args: kv_caches: - dictionary of layer names, kv cache + dictionary of layer names, kv cache """ pass From 3785905c02e667f137af4545e2d40729a4886f5b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 23 Apr 2025 22:34:13 +0000 Subject: [PATCH 039/119] updated Signed-off-by: rshaw@neuralmagic.com --- examples/offline_inference/basic/basic.py | 13 ++- tests/v1/core/test_scheduler.py | 118 +--------------------- 2 files changed, 9 insertions(+), 122 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb4834..60148bfd62c8 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -4,10 +4,10 @@ # Sample prompts. prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + "Hello, my name is Robert and I work for Red Hat software", + "The president of the United States is Joe Biden who is ", + "The capital of France is different from the capital of USA because", + "The future of AI is open source because there is a race to the bottom", ] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) @@ -15,7 +15,10 @@ def main(): # Create an LLM. 
- llm = LLM(model="facebook/opt-125m") + llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True, + max_num_batched_tokens=16, + max_num_seqs=8) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 691ca59b062c..c76e90d3e3b0 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -16,123 +16,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager - -EOS_TOKEN_ID = 50256 - - -def create_scheduler( - model: str = "facebook/opt-125m", - max_num_seqs: int = 16, - max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, - long_prefill_token_threshold: int = 0, - disable_chunked_mm_input: bool = False, - use_kv_connector: bool = False, - num_blocks: int = 10000, - block_size: int = 16, -) -> Scheduler: - '''Create scheduler under test. - - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - :class:`Scheduler` instance - ''' - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, - long_prefill_token_threshold=long_prefill_token_threshold, - disable_chunked_mm_input=disable_chunked_mm_input, - ) - model_config = ModelConfig( - model=model, - task="auto", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=42, - ) - # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) - cache_config = CacheConfig( - block_size=block_size, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - **kwargs_cache, - ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None - - vllm_config = VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) - ], - ) - cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - ) - - -def create_requests(num_requests: int, - num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) - requests = [] - for i in range(num_requests): - if mm_positions is not None: - mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None - request = 
Request( - request_id=f"{i}", - prompt=None, - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - arrival_time=0, - ) - requests.append(request) - return requests - +from vllm.tests.v1.utils import (create_scheduler, create_requests, EOS_TOKEN_ID) def test_add_requests(): scheduler = create_scheduler() From 8a94b2ea0dc9431ba3d8748e078881b80e6e754d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:02:06 +0000 Subject: [PATCH 040/119] updated Signed-off-by: rshaw@neuralmagic.com --- requirements/test.txt | 21 +- tests/v1/core/__init__.py | 0 tests/v1/core/test_scheduler.py | 15 +- tests/v1/kv_connector/__init__.py | 0 tests/v1/kv_connector/test_scheduler.py | 198 ++++++++++++++++++ tests/v1/utils.py | 130 ++++++++++++ .../kv_connector/v1/nixl_connector.py | 44 ++-- vllm/v1/core/sched/scheduler.py | 12 +- vllm/v1/outputs.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 5 +- 10 files changed, 391 insertions(+), 42 deletions(-) create mode 100644 tests/v1/core/__init__.py create mode 100644 tests/v1/kv_connector/__init__.py create mode 100644 tests/v1/kv_connector/test_scheduler.py create mode 100644 tests/v1/utils.py diff --git a/requirements/test.txt b/requirements/test.txt index 6dcd4ff01460..a4c8f1b129ff 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,6 +27,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -132,6 +136,11 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -633,7 +642,6 @@ setuptools==75.8.0 # via # mamba-ssm # pytablewriter - # torch shellingham==1.5.4 # via typer six==1.16.0 @@ -692,8 +700,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.6.0 @@ -765,12 +778,16 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/v1/core/__init__.py b/tests/v1/core/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index c76e90d3e3b0..8a46a474fd96 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,20 +3,15 @@ from unittest.mock import Mock import pytest -import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import Request, RequestStatus -from vllm.v1.structured_output import StructuredOutputManager -from vllm.tests.v1.utils import 
(create_scheduler, create_requests, EOS_TOKEN_ID) +from vllm.v1.request import RequestStatus + +from ..utils import EOS_TOKEN_ID, create_requests, create_scheduler + def test_add_requests(): scheduler = create_scheduler() diff --git a/tests/v1/kv_connector/__init__.py b/tests/v1/kv_connector/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py new file mode 100644 index 000000000000..ace50b97101b --- /dev/null +++ b/tests/v1/kv_connector/test_scheduler.py @@ -0,0 +1,198 @@ +import copy + +import torch + +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, VllmConfig) +from vllm.sampling_params import KVTransferParams, SamplingParams +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +# SPDX-License-Identifier: Apache-2.0 + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + num_blocks: int = 10000, + block_size: int = 8, +) -> Scheduler: + '''Create scheduler under test. + + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=False, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests( + num_requests: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, +) -> list[Request]: + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = KVTransferParams(do_remote_prefill=True, ) + elif do_remote_prefill: + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_engine_id="abc", + remote_block_ids=[1, 2, 3], + ) + else: + kv_transfer_params = None + + sampling_params = SamplingParams( + max_tokens=max_tokens, + kv_transfer_params=kv_transfer_params, + ) + requests = [] + for i in range(num_requests): + request = 
Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests + + +def test_remote_prefill_lifecycle(): + scheduler = create_scheduler() + + NUM_TOKENS = 16 + request = create_requests(num_requests=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True)[0] + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # Remote Prefill: req should be scheduled with 0 tokens + # but have the entire prompt "computed" from the POV of + # the scheduler + persistent batch (since the KVConnector + # will write directly into allocated blocks). + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.recving_KV_req_ids) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + scheduled_req = scheduler_output.scheduled_new_reqs[0] + # We compute + assert scheduled_req.num_computed_tokens == NUM_TOKENS - 1 + assert scheduler_output.num_scheduled_tokens[scheduled_req.req_id] == 0 + + engine_core_outputs = scheduler.update_from_output( + scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + # Request should still be in the running and recving state. + assert len(scheduler.running) == 1 + assert len(scheduler.recving_KV_req_ids) == 1 + assert len(engine_core_outputs.outputs) == 0 + + # STEP (2): + # Remote Prefill: req should be running. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.recving_KV_req_ids) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving.append(request_id) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Request should be out of the recving state. + assert len(scheduler.running) == 1 + assert len(scheduler.recving_KV_req_ids) == 0 + assert len(engine_core_outputs.outputs) == 0 + + # STEP (3): + # Remote Prefill: the request should now have scheduled tokens. + scheduler_output = scheduler.schedule() + assert (len(scheduler_output.scheduled_cached_reqs)) == 1 + + # req_to_index = { + # request.request_id: i + # for i, request in enumerate(requests) + # } + # model_runner_output = ModelRunnerOutput( + # req_ids=[request.request_id for request in requests], + # req_id_to_index=req_to_index, + # # Only the first request has a sampled token id because + # # the rest requests are still being prefilled. 
+ # sampled_token_ids=[[0], [], []], + # spec_token_ids=None, + # logprobs=None, + # prompt_logprobs_dict={}, + # ) diff --git a/tests/v1/utils.py b/tests/v1/utils.py new file mode 100644 index 000000000000..853df5396eff --- /dev/null +++ b/tests/v1/utils.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, VllmConfig) +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, +) -> Scheduler: + '''Create scheduler under test. + + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + 
mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e905bc537789..c7c1fab2b7cf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -47,11 +47,11 @@ class ReqMeta: def __init__( self, - block_ids: list[int], + local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: list[int], ): - self.block_ids = block_ids + self.local_block_ids = local_block_ids self.remote_block_ids = remote_block_ids self.remote_engine_id = remote_engine_id @@ -63,13 +63,13 @@ def __init__(self): def add_new_req( self, - req_id: str, - block_ids: list[int], + request_id: str, + local_block_ids: list[int], kv_transfer_params: KVTransferParams, ): - assert req_id not in self.requests - self.requests[req_id] = ReqMeta( - block_ids, + assert request_id not in self.requests + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params.remote_block_ids, remote_engine_id=kv_transfer_params.remote_engine_id) @@ -157,7 +157,10 @@ def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> int: """For remote prefill, allocate for all tokens.""" if request.do_remote_prefill: - return len(request.prompt_token_ids) - num_computed_tokens + # Subtract 1 since we do not compute the last prompt + # token so that we can sample the first token here. + num_external_tokens = len(request.prompt_token_ids) - 1 + return num_external_tokens - num_computed_tokens def update_state_after_alloc(self, request: "Request", num_external_tokens: int): @@ -331,7 +334,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): for layer_id in range(self.num_layers): # Both K and V. print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") - print(f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}") + print( + f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}" + ) print(f"{self.kv_caches_base_addr[self.engine_id][layer_id]=}") for base_addr in self.kv_caches_base_addr[ self.engine_id][layer_id]: @@ -360,16 +365,16 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): block_offset = block_id * dst_block_len blocks_data.append( (base_addr + block_offset, dst_block_len, - self.rank * tp_multiplier)) + self.rank * tp_multiplier)) logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, - self.rank * tp_multiplier + i) + len(blocks_data), engine_id, + self.rank * tp_multiplier + i) # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") self.dst_xfer_side_handles[engine_id][i] = ( self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][self.rank * tp_multiplier + - i], descs)) + self._remote_agents[engine_id][self.rank * tp_multiplier + i], + descs)) def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" @@ -377,16 +382,15 @@ def get_finished(self) -> tuple[set[str], set[str]]: done_recving = self._pop_done_transfers(self._recving_transfers) return done_sending, done_recving - def _get_new_notifs(self) -> set[str]: + def _get_new_notifs(self) -> list[str]: """Get req_ids which got a remote xfer message.""" - notified_req_ids: set[str] = set() + notified_req_ids: list[str] = [] # TODO: handle the TP case (N notifies for TP=N). # See: vllm/worker/worker_base.py L476 in DynamoPR. for req_ids in self.nixl_wrapper.get_new_notifs().values(): for req_id in req_ids: - assert req_id not in notified_req_ids - notified_req_ids.add(req_id) + notified_req_ids.append(req_id) return notified_req_ids def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: @@ -397,7 +401,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: Returns: set of req_ids that have all done xfers """ - done_req_ids: str[str] = set() + done_req_ids: list[str] = [] for req_id, handles in transfers.items(): running_reqs = [] for handle in handles: @@ -412,7 +416,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: raise RuntimeError("Transfer failed with state %s", xfer_state) if len(running_reqs) == 0: - done_req_ids.add(req_id) + done_req_ids.append(req_id) else: transfers[req_id] = running_reqs return done_req_ids diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 995bc7512e22..c81260dfc15f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -395,6 +395,10 @@ def schedule(self) -> SchedulerOutput: break self.recving_KV_req_ids.add(request.request_id) + # TODO: clean up code + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. @@ -785,12 +789,14 @@ def update_from_output( new_running.append(request) # P/D: update recv and send status from last step. - for req_id in list(model_runner_output.finished_recving): + for req_id in (model_runner_output.finished_recving or []): # TODO(rob): Implement this method. # Cache blocks for APC after KVs have been recv'ed. 
- self.kv_cache_manager.cache_blocks(req_id) + # self.kv_cache_manager.cache_blocks(req_id) + self.scheduled_req_ids.remove(req_id) self.recving_KV_req_ids.remove(req_id) - for req_id in list(model_runner_output.finished_sending): + print(f"{self.requests[req_id].num_computed_tokens=}") + for req_id in (model_runner_output.finished_sending or []): self._free_request(self.requests[req_id]) self.running = new_running diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index d1eae6a8ba7c..24052e01f006 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -101,8 +101,8 @@ class ModelRunnerOutput: prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] # [req_ids] - finished_sending: set[str] - finished_recving: set[str] + finished_sending: Optional[list[str]] = None + finished_recving: Optional[list[str]] = None EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], @@ -111,5 +111,5 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - finished_sending=set(), - finished_recving=set()) + finished_sending=[], + finished_recving=[]) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c79ecfdfab5d..d442aff3b03e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1021,13 +1021,12 @@ def maybe_wait_for_save(): kv_connector = get_kv_transfer_group() kv_connector.wait_for_save() - def maybe_get_finished() -> tuple[set[str], set[str]]: + def maybe_get_finished() -> tuple[list[str], list[str]]: if has_kv_transfer_group(): kv_connector = get_kv_transfer_group() return kv_connector.get_finished() else: - # TODO: make this optional instead. - return set(), set() + return [], [] self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: From ac1943755d066781f7911bef6e564541d0e4ffb3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:12:09 +0000 Subject: [PATCH 041/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index ace50b97101b..685e7ce12d84 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -147,11 +147,16 @@ def test_remote_prefill_lifecycle(): assert len(scheduler.running) == 1 assert len(scheduler.recving_KV_req_ids) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 + scheduled_req = scheduler_output.scheduled_new_reqs[0] - # We compute assert scheduled_req.num_computed_tokens == NUM_TOKENS - 1 assert scheduler_output.num_scheduled_tokens[scheduled_req.req_id] == 0 + # We should not cache blocks until the kvs are recved. + cache = scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + assert len(cache) == 0 + assert request_id not in scheduler.kv_cache_manager.num_cached_block + engine_core_outputs = scheduler.update_from_output( scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) # Request should still be in the running and recving state. 
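
The remote-prefill lifecycle exercised by test_remote_prefill_lifecycle above reduces to three observable steps: the request is admitted with zero scheduled tokens while all but the last prompt token are marked as computed (the KV connector writes those KVs directly into the allocated blocks), it stays in the receiving set until the model runner reports it in finished_recving, and only then does the scheduler give it real tokens. The standalone sketch below mirrors that state machine with deliberately hypothetical names (ToyRequest and ToyScheduler are not vLLM classes); it illustrates what the test asserts, not vLLM's actual implementation.

# Illustrative sketch only; ToyRequest/ToyScheduler are invented names.
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    request_id: str
    num_prompt_tokens: int
    num_computed_tokens: int = 0


@dataclass
class ToyScheduler:
    running: list = field(default_factory=list)
    recving_kv_req_ids: set = field(default_factory=set)

    def add_remote_prefill(self, req: ToyRequest) -> int:
        # Step 1: admit the request; all but the last prompt token count as
        # "computed" because the KV connector fills those blocks directly,
        # so zero tokens are scheduled this step.
        self.running.append(req)
        self.recving_kv_req_ids.add(req.request_id)
        req.num_computed_tokens = req.num_prompt_tokens - 1
        return 0  # tokens scheduled

    def mark_finished_recving(self, req_ids: list) -> None:
        # Step 2: the model runner reports which transfers completed.
        for req_id in req_ids:
            self.recving_kv_req_ids.discard(req_id)

    def schedule(self, req: ToyRequest) -> int:
        # Step 3: once the KVs have landed, schedule a real token.
        return 0 if req.request_id in self.recving_kv_req_ids else 1


req = ToyRequest("0", num_prompt_tokens=16)
sched = ToyScheduler()
assert sched.add_remote_prefill(req) == 0   # step 1: nothing to run yet
assert sched.schedule(req) == 0             # still waiting on the transfer
sched.mark_finished_recving(["0"])          # step 2: KVs have arrived
assert sched.schedule(req) == 1             # step 3: normal scheduling
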
From 6391ec9563cb1c1675d5cf84b56e8c47e133a6e2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:17:04 +0000 Subject: [PATCH 042/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/core/__init__.py | 0 tests/v1/core/test_scheduler.py | 3 +-- 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 100644 tests/v1/core/__init__.py diff --git a/tests/v1/core/__init__.py b/tests/v1/core/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8a46a474fd96..367a07876774 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -5,13 +5,12 @@ import pytest from vllm.multimodal.inputs import PlaceholderRange +from vllm.tests.v1.utils import EOS_TOKEN_ID, create_requests, create_scheduler from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import RequestStatus -from ..utils import EOS_TOKEN_ID, create_requests, create_scheduler - def test_add_requests(): scheduler = create_scheduler() From 7dd764bfc9900d04c189e2dc84d3aad18be5a0c7 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:17:23 +0000 Subject: [PATCH 043/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/utils.py | 130 ---------------------------------------------- 1 file changed, 130 deletions(-) delete mode 100644 tests/v1/utils.py diff --git a/tests/v1/utils.py b/tests/v1/utils.py deleted file mode 100644 index 853df5396eff..000000000000 --- a/tests/v1/utils.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional - -import torch - -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange -from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) -from vllm.v1.request import Request -from vllm.v1.structured_output import StructuredOutputManager - -EOS_TOKEN_ID = 50256 - - -def create_scheduler( - model: str = "facebook/opt-125m", - max_num_seqs: int = 16, - max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, - long_prefill_token_threshold: int = 0, - disable_chunked_mm_input: bool = False, - use_kv_connector: bool = False, - num_blocks: int = 10000, - block_size: int = 16, -) -> Scheduler: - '''Create scheduler under test. 
- - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - :class:`Scheduler` instance - ''' - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, - long_prefill_token_threshold=long_prefill_token_threshold, - disable_chunked_mm_input=disable_chunked_mm_input, - ) - model_config = ModelConfig( - model=model, - task="auto", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=42, - ) - # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) - cache_config = CacheConfig( - block_size=block_size, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - **kwargs_cache, - ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None - - vllm_config = VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) - ], - ) - cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - ) - - -def create_requests(num_requests: int, - num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) - requests = [] - for i in range(num_requests): - if mm_positions is not None: - mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None - request = Request( - request_id=f"{i}", - prompt=None, - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - arrival_time=0, - ) - requests.append(request) - return requests From 97316d9df800639b2ef7db77c0a5fbf76d50c747 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:18:09 +0000 Subject: [PATCH 044/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/core/test_scheduler.py | 145 +++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 11 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 367a07876774..560a60a81446 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,13 +3,139 @@ from unittest.mock import Mock import pytest +import torch -from vllm.multimodal.inputs import PlaceholderRange -from vllm.tests.v1.utils import EOS_TOKEN_ID, create_requests, create_scheduler +from vllm.config import (CacheConfig, 
KVTransferConfig, ModelConfig, + SchedulerConfig, VllmConfig) +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import RequestStatus +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, +) -> Scheduler: + '''Create scheduler under test. + + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + 
mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests def test_add_requests(): @@ -174,6 +300,7 @@ def test_no_mm_input_chunking(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, + max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests(num_requests=1, @@ -677,20 +804,17 @@ def _assert_right_kv_cache_manager( """Check whether KVCacheManager is correct after allocate.""" # Make sure the request stats are right. - EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size - EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS + - scheduler.kv_cache_manager.num_preallocate_blocks) + EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req_id in req_ids: blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] assert (scheduler.kv_cache_manager.num_cached_block[req_id] == - EXPECTED_ACTUAL_BLOCKS) + EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS - assert len(hashes) == EXPECTED_ACTUAL_BLOCKS + assert len(hashes) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. - BLOCKS_PER_REQ = (num_tokens / block_size + - scheduler.kv_cache_manager.num_preallocate_blocks) + BLOCKS_PER_REQ = num_tokens / block_size assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == num_total_blocks - num_requests * BLOCKS_PER_REQ) @@ -925,7 +1049,6 @@ def test_kv_connector_handles_preemption(): block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, ) - scheduler.kv_cache_manager.num_preallocate_blocks = 0 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") From 2771353a6d1d9c87f45609097e9dda0b8dfecb3e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:19:09 +0000 Subject: [PATCH 045/119] Revert "updated" This reverts commit 97316d9df800639b2ef7db77c0a5fbf76d50c747. 
--- tests/v1/core/test_scheduler.py | 145 +++----------------------------- 1 file changed, 11 insertions(+), 134 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 560a60a81446..367a07876774 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,139 +3,13 @@ from unittest.mock import Mock import pytest -import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.multimodal.inputs import PlaceholderRange +from vllm.tests.v1.utils import EOS_TOKEN_ID, create_requests, create_scheduler from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import Request, RequestStatus -from vllm.v1.structured_output import StructuredOutputManager - -EOS_TOKEN_ID = 50256 - - -def create_scheduler( - model: str = "facebook/opt-125m", - max_num_seqs: int = 16, - max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, - long_prefill_token_threshold: int = 0, - disable_chunked_mm_input: bool = False, - use_kv_connector: bool = False, - num_blocks: int = 10000, - block_size: int = 16, - max_model_len: Optional[int] = None, -) -> Scheduler: - '''Create scheduler under test. - - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - :class:`Scheduler` instance - ''' - if max_model_len is None: - max_model_len = max_num_batched_tokens - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_model_len, - long_prefill_token_threshold=long_prefill_token_threshold, - disable_chunked_mm_input=disable_chunked_mm_input, - enable_chunked_prefill=True, - ) - model_config = ModelConfig( - model=model, - task="auto", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=42, - ) - # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) - cache_config = CacheConfig( - block_size=block_size, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - **kwargs_cache, - ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None - - vllm_config = VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) - ], - ) - cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - ) - - -def create_requests(num_requests: int, - 
num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) - requests = [] - for i in range(num_requests): - if mm_positions is not None: - mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None - request = Request( - request_id=f"{i}", - prompt=None, - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - arrival_time=0, - ) - requests.append(request) - return requests +from vllm.v1.request import RequestStatus def test_add_requests(): @@ -300,7 +174,6 @@ def test_no_mm_input_chunking(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, - max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests(num_requests=1, @@ -804,17 +677,20 @@ def _assert_right_kv_cache_manager( """Check whether KVCacheManager is correct after allocate.""" # Make sure the request stats are right. - EXPECTED_TOTAL_BLOCKS = num_tokens // block_size + EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size + EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS + + scheduler.kv_cache_manager.num_preallocate_blocks) for req_id in req_ids: blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] assert (scheduler.kv_cache_manager.num_cached_block[req_id] == - EXPECTED_TOTAL_BLOCKS) + EXPECTED_ACTUAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS - assert len(hashes) == EXPECTED_TOTAL_BLOCKS + assert len(hashes) == EXPECTED_ACTUAL_BLOCKS # Make sure we actually touched all the blocks. 
- BLOCKS_PER_REQ = num_tokens / block_size + BLOCKS_PER_REQ = (num_tokens / block_size + + scheduler.kv_cache_manager.num_preallocate_blocks) assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == num_total_blocks - num_requests * BLOCKS_PER_REQ) @@ -1049,6 +925,7 @@ def test_kv_connector_handles_preemption(): block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, ) + scheduler.kv_cache_manager.num_preallocate_blocks = 0 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") From baed1bff432e58290d60503d84b6dc6dd7d3b7c0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:30:08 +0000 Subject: [PATCH 046/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_scheduler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index 685e7ce12d84..b09224431ccf 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -127,9 +127,10 @@ def create_requests( return requests -def test_remote_prefill_lifecycle(): +def test_basic_remote_prefill(): scheduler = create_scheduler() - + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) NUM_TOKENS = 16 request = create_requests(num_requests=1, num_tokens=NUM_TOKENS, @@ -152,9 +153,12 @@ def test_remote_prefill_lifecycle(): assert scheduled_req.num_computed_tokens == NUM_TOKENS - 1 assert scheduler_output.num_scheduled_tokens[scheduled_req.req_id] == 0 - # We should not cache blocks until the kvs are recved. - cache = scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block - assert len(cache) == 0 + # Blocks should not be cached until the KVs are recv, + # but they should be touched so that they are not preempted. + block_pool = scheduler.kv_cache_manager.block_pool + assert len(block_pool.cached_block_hash_to_block) == 0 + assert (block_pool.free_block_queue.num_free_blocks + < START_FREE_BLOCK_QUEUE_SIZE) assert request_id not in scheduler.kv_cache_manager.num_cached_block engine_core_outputs = scheduler.update_from_output( From d0ad6d949bea4c79cfbe149d9b7a979aadc1bc87 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:30:40 +0000 Subject: [PATCH 047/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index b09224431ccf..981663c77611 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -186,6 +186,9 @@ def test_basic_remote_prefill(): assert len(scheduler.recving_KV_req_ids) == 0 assert len(engine_core_outputs.outputs) == 0 + # TODO(rob): once we support caching, we should check that the + # blocks are cached here. + # STEP (3): # Remote Prefill: the request should now have scheduled tokens. 
scheduler_output = scheduler.schedule() From 055885eba8a24d7d617350dc208eff739ef4309b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:30:56 +0000 Subject: [PATCH 048/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index 981663c77611..662d20372f5a 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -193,6 +193,7 @@ def test_basic_remote_prefill(): # Remote Prefill: the request should now have scheduled tokens. scheduler_output = scheduler.schedule() assert (len(scheduler_output.scheduled_cached_reqs)) == 1 + print(f"{scheduler_output=}") # req_to_index = { # request.request_id: i From 5ed38060464c29d78ed6495892d13abc3c634da9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 01:37:09 +0000 Subject: [PATCH 049/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_scheduler.py | 45 ++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index 662d20372f5a..4aed8abfc207 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -17,26 +17,12 @@ EOS_TOKEN_ID = 50256 -def create_scheduler( +def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 64, - num_blocks: int = 10000, block_size: int = 8, -) -> Scheduler: - '''Create scheduler under test. - - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - :class:`Scheduler` instance - ''' +) -> VllmConfig: scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, @@ -63,12 +49,32 @@ def create_scheduler( kv_connector="NixlConnector", kv_role="kv_both", ) - vllm_config = VllmConfig( + return VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, ) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + '''Create scheduler under test. 
+ + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests tensors={}, @@ -78,7 +84,7 @@ def create_scheduler( False)) ], ) - cache_config.num_gpu_blocks = num_blocks + vllm_config.cache_config.num_gpu_blocks = num_blocks return Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, @@ -128,7 +134,8 @@ def create_requests( def test_basic_remote_prefill(): - scheduler = create_scheduler() + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) NUM_TOKENS = 16 From 58266b53f932a9cbedc99ed9e54870deff3c83bc Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 02:37:31 +0000 Subject: [PATCH 050/119] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_model_runner.py | 0 tests/v1/kv_connector/test_scheduler.py | 133 +-------------------- 2 files changed, 4 insertions(+), 129 deletions(-) create mode 100644 tests/v1/kv_connector/test_model_runner.py diff --git a/tests/v1/kv_connector/test_model_runner.py b/tests/v1/kv_connector/test_model_runner.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index 4aed8abfc207..ff345d260d1a 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -1,141 +1,16 @@ import copy -import torch - -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, VllmConfig) -from vllm.sampling_params import KVTransferParams, SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import Request -from vllm.v1.structured_output import StructuredOutputManager -# SPDX-License-Identifier: Apache-2.0 +from .utils import create_requests, create_scheduler, create_vllm_config -EOS_TOKEN_ID = 50256 - - -def create_vllm_config( - model: str = "facebook/opt-125m", - max_num_seqs: int = 16, - max_num_batched_tokens: int = 64, - block_size: int = 8, -) -> VllmConfig: - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, - ) - model_config = ModelConfig( - model=model, - task="auto", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=42, - ) - # Cache config, optionally force APC - cache_config = CacheConfig( - block_size=block_size, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - enable_prefix_caching=False, - ) - kv_transfer_config = KVTransferConfig( - kv_connector="NixlConnector", - kv_role="kv_both", - ) - return VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - ) - - -def create_scheduler( - vllm_config: VllmConfig, - num_blocks: int = 10000, -) -> Scheduler: - '''Create scheduler under test. 
- - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - :class:`Scheduler` instance - ''' - block_size = vllm_config.cache_config.block_size - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) - ], - ) - vllm_config.cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - ) - - -def create_requests( - num_requests: int, - num_tokens: int = 10, - max_tokens: int = 16, - do_remote_decode: bool = False, - do_remote_prefill: bool = False, -) -> list[Request]: - if do_remote_decode: - assert not do_remote_prefill - kv_transfer_params = KVTransferParams(do_remote_prefill=True, ) - elif do_remote_prefill: - kv_transfer_params = KVTransferParams( - do_remote_prefill=True, - remote_engine_id="abc", - remote_block_ids=[1, 2, 3], - ) - else: - kv_transfer_params = None - - sampling_params = SamplingParams( - max_tokens=max_tokens, - kv_transfer_params=kv_transfer_params, - ) - requests = [] - for i in range(num_requests): - request = Request( - request_id=f"{i}", - prompt=None, - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_inputs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - arrival_time=0, - ) - requests.append(request) - return requests +# SPDX-License-Identifier: Apache-2.0 def test_basic_remote_prefill(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) + START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) NUM_TOKENS = 16 @@ -200,7 +75,7 @@ def test_basic_remote_prefill(): # Remote Prefill: the request should now have scheduled tokens. 
scheduler_output = scheduler.schedule() assert (len(scheduler_output.scheduled_cached_reqs)) == 1 - print(f"{scheduler_output=}") + assert (scheduler_output.num_scheduled_tokens[request_id]) == 1 # req_to_index = { # request.request_id: i From 344d9da72eed0caee13db907718f5e557d8d91e0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 03:29:54 +0000 Subject: [PATCH 051/119] stash Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/test_model_runner.py | 55 +++++++++++++++++++ tests/v1/kv_connector/test_scheduler.py | 8 +-- .../kv_connector/v1/nixl_connector.py | 2 + 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/tests/v1/kv_connector/test_model_runner.py b/tests/v1/kv_connector/test_model_runner.py index e69de29bb2d1..9bde8693c4ac 100644 --- a/tests/v1/kv_connector/test_model_runner.py +++ b/tests/v1/kv_connector/test_model_runner.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from .utils import (create_model_runner, create_request, create_scheduler, + create_vllm_config) + + +def test_basic_remote_prefill(): + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + model_runner = create_model_runner(vllm_config=vllm_config, + device=torch.device(type="cuda")) + + NUM_TOKENS = 16 + + normal_request = create_request(request_id=0, num_tokens=NUM_TOKENS) + + remote_request = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) + + scheduler.add_request(normal_request) + scheduler.add_request(remote_request) + + scheduler_output = scheduler.schedule() + + # Both should be running, but only the normal request + # should have scheduled tokens. + assert len(scheduler.running) == 2 + assert scheduler_output.num_scheduled_tokens[ + normal_request.request_id] == NUM_TOKENS + assert scheduler_output.num_scheduled_tokens[ + remote_request.request_id] == 0 + + for scheduled_new_req in scheduler_output.scheduled_new_reqs: + # Remote request has all tokens computed externally. + if scheduled_new_req.req_id == remote_request.request_id: + assert scheduled_new_req.num_computed_tokens == NUM_TOKENS - 1 + # Normal request has no tokens computed externally. 
+ if scheduled_new_req.req_id == normal_request.request_id: + assert scheduled_new_req.num_computed_tokens == 0 + + # model_runner.execute_model does: + # * _update_states + # * returns if no tokens scheduled + # * _prepare_inputs + model_runner._update_states(scheduler_output) + attn_metadata, logits_indices, spec_decode_metadata = ( + model_runner._prepare_inputs(scheduler_output)) + + print(f"{attn_metadata=}") + print(f"{logits_indices=}") + print(f"{spec_decode_metadata=}") diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index ff345d260d1a..dd5727c46717 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -2,7 +2,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from .utils import create_requests, create_scheduler, create_vllm_config +from .utils import create_request, create_scheduler, create_vllm_config # SPDX-License-Identifier: Apache-2.0 @@ -14,9 +14,9 @@ def test_basic_remote_prefill(): START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) NUM_TOKENS = 16 - request = create_requests(num_requests=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True)[0] + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) scheduler.add_request(request) request_id = request.request_id diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c7c1fab2b7cf..1cb5292bf4ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -161,6 +161,8 @@ def get_num_new_matched_tokens(self, request: "Request", # token so that we can sample the first token here. 
num_external_tokens = len(request.prompt_token_ids) - 1 return num_external_tokens - num_computed_tokens + else: + return 0 def update_state_after_alloc(self, request: "Request", num_external_tokens: int): From 29966382a1f5feea559b296a23188794bd8e87c6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 24 Apr 2025 12:06:18 +0000 Subject: [PATCH 052/119] added Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/utils.py | 123 +++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 tests/v1/kv_connector/utils.py diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py new file mode 100644 index 000000000000..a70a6e8c7d92 --- /dev/null +++ b/tests/v1/kv_connector/utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, VllmConfig) +from vllm.sampling_params import KVTransferParams, SamplingParams +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +EOS_TOKEN_ID = 50256 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 8, +) -> VllmConfig: + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=False, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + ) + + +def create_model_runner( + vllm_config: VllmConfig, + device: torch.device, +) -> GPUModelRunner: + return GPUModelRunner( + vllm_config=vllm_config, + device=device, + ) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, +) -> list[Request]: + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = KVTransferParams(do_remote_prefill=True, ) + elif do_remote_prefill: + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_engine_id="abc", + remote_block_ids=[1, 2, 3], + ) + else: + kv_transfer_params = 
None + + sampling_params = SamplingParams( + max_tokens=max_tokens, + kv_transfer_params=kv_transfer_params, + ) + return Request( + request_id=f"id-{request_id}", + prompt=None, + prompt_token_ids=[request_id] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) From bcc88dcefc5196d0df829b2bfa129b8101704bf0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 24 Apr 2025 08:36:40 -0400 Subject: [PATCH 053/119] diffs for local dev on macos Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_model_runner.py | 2 +- tests/v1/kv_connector/utils.py | 7 ++++--- .../kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++- vllm/platforms/cpu.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 9 +++++---- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/v1/kv_connector/test_model_runner.py b/tests/v1/kv_connector/test_model_runner.py index 9bde8693c4ac..a6124ceeeb81 100644 --- a/tests/v1/kv_connector/test_model_runner.py +++ b/tests/v1/kv_connector/test_model_runner.py @@ -9,7 +9,7 @@ def test_basic_remote_prefill(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) model_runner = create_model_runner(vllm_config=vllm_config, - device=torch.device(type="cuda")) + device=torch.device(type="cpu")) NUM_TOKENS = 16 diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index a70a6e8c7d92..1aa42cfafd89 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, VllmConfig) +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) from vllm.sampling_params import KVTransferParams, SamplingParams from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -18,7 +18,7 @@ def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 64, - block_size: int = 8, + block_size: int = 16, ) -> VllmConfig: scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, @@ -51,6 +51,7 @@ def create_vllm_config( model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu") ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1cb5292bf4ba..7691d2fd1224 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -203,7 +203,8 @@ def __init__(self, engine_id: str): logger.info("Initializing NIXL wrapper") # Agent. - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + self.nixl_wrapper = None # Map of engine_id -> list[agent_names] (1 per rank). 
self._remote_agents: dict[str, list[str]] = {} diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 70553354a060..47a48126ed5c 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -43,7 +43,8 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, logger.info("Using CPU MLA backend.") return "vllm.attention.backends.cpu_mla.CPUMLABackend" logger.info("Using Torch SDPA backend.") - return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + # return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d442aff3b03e..d463e16cc1b1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -38,8 +38,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import RejectionSampler -from vllm.v1.spec_decode.eagle import EagleProposer +# from vllm.v1.sample.rejection_sampler import RejectionSampler +# from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported @@ -199,8 +199,9 @@ def __init__( self.vllm_config.compilation_config.cudagraph_capture_sizes)) # Cache the device properties. - self.device_properties = torch.cuda.get_device_properties(self.device) - self.num_sms = self.device_properties.multi_processor_count + # self.device_properties = torch.cuda.get_device_properties(self.device) + # self.num_sms = self.device_properties.multi_processor_count + self.num_sms = 0 # Persistent buffers for CUDA graphs. self.input_ids = torch.zeros(self.max_num_tokens, From 62205ae43393a688ae3a3804305a85f3d6175369 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 24 Apr 2025 11:12:58 -0400 Subject: [PATCH 054/119] updated Signed-off-by: Robert Shaw --- vllm/v1/core/kv_cache_manager.py | 11 ++- vllm/v1/core/sched/scheduler.py | 134 +++++++++++++++---------------- vllm/v1/request.py | 1 + 3 files changed, 76 insertions(+), 70 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33bb825a11a7..b634c2bc14b6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -283,13 +283,22 @@ def allocate_slots( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if not self.enable_caching or skip_cache_blocks: + if not self.enable_caching: return new_blocks # Use `new_computed_blocks` for a new request, and `num_cached_block` # for a running request. num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) + + # Skip performing the actual caching + # This is useful for P/D such that we do not prematurely cache + # blocks which are being filled over multiple steps. + if skip_cache_blocks: + self.num_cached_block[ + request.request_id] = num_cached_blocks + return new_blocks + # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. 
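The allocate/cache split introduced in the kv_cache_manager hunk above exists so that blocks reserved for an in-flight remote transfer are never offered as prefix-cache hits before their contents arrive. A minimal standalone sketch of that idea, with invented names rather than the real KVCacheManager API:

# Standalone sketch: reserve blocks now, register them with the prefix cache
# only once their contents actually exist (illustrative names only).
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Block:
    block_id: int
    block_hash: Optional[int] = None   # None => invisible to the prefix cache


@dataclass
class SketchCacheManager:
    block_size: int
    free_blocks: list
    cached_hashes: set = field(default_factory=set)
    req_to_blocks: dict = field(default_factory=dict)

    def allocate(self, req_id, num_tokens, skip_cache_blocks=False):
        num_blocks = -(-num_tokens // self.block_size)   # ceil division
        blocks = [self.free_blocks.pop() for _ in range(num_blocks)]
        self.req_to_blocks[req_id] = blocks
        if not skip_cache_blocks:
            self.cache_blocks(req_id, num_computed_tokens=num_tokens)
        return blocks

    def cache_blocks(self, req_id, num_computed_tokens):
        # Only blocks whose tokens are fully present are given a hash.
        num_full = num_computed_tokens // self.block_size
        for block in self.req_to_blocks[req_id][:num_full]:
            block.block_hash = block.block_id   # stand-in for a content hash
            self.cached_hashes.add(block.block_hash)


mgr = SketchCacheManager(block_size=16,
                         free_blocks=[Block(i) for i in range(8)])
# Remote prefill: blocks are reserved, but nothing is a cache hit yet.
mgr.allocate("remote-req", num_tokens=40, skip_cache_blocks=True)
assert not mgr.cached_hashes
# Once the KV transfer has finished, the full blocks become cacheable.
mgr.cache_blocks("remote-req", num_computed_tokens=40)
assert len(mgr.cached_hashes) == 40 // 16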
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c81260dfc15f..80710993a5fd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -100,8 +100,7 @@ def __init__( self.finished_req_ids: set[str] = set() # Requests in states for tracking KV transfers for P/D disagg - self.sending_KV_req_ids: set[str] = set() - self.recving_KV_req_ids: set[str] = set() + self.finished_recving_KV_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. @@ -176,11 +175,6 @@ def schedule(self) -> SchedulerOutput: req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if (request.request_id in self.recving_KV_req_ids - or request.request_id in self.sending_KV_req_ids): - # P/D: This request is still recv/sending KVs. - req_index += 1 - continue if request.request_id in self.scheduled_req_ids: # This request has already been scheduled. req_index += 1 @@ -223,11 +217,6 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. # Preempt the lowest-priority request. preempted_req = self.running.pop() - # NOTE(rob): we cannot free these blocks once in flight. - # TODO(rob): understand full implications of this. - if preempted_req.request_id in self.recving_KV_req_ids: - pass - self.kv_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 @@ -305,6 +294,16 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + if request.request_id in self.finished_recving_req_ids: + # We delayed caching the blocks until after they + # are recved to avoid cache hits from other reqs + # before the KVs are written. + self.kv_cache_manager.cache_blocks(request) + # TODO: how can we do a better job with this? + request.num_computed_tokens = len(request.all_token_ids) - 1 + request.status = RequestStatus.WAITING + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -340,46 +339,8 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens += num_external_tokens - # TODO: how can we make this code clean? - if not request.do_remote_prefill: - - # Number of tokens to be scheduled. - # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed reqs, - # which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) - num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 - - # Schedule encoder inputs. - if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled. - break - else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget - - new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens + num_external_tokens, - computed_blocks) - if new_blocks is None: - # The request cannot be scheduled. - break - else: - # TODO: handle preempted state. 
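The waiting-queue handling in the scheduler hunk above boils down to a small state machine: a remote-prefill request gets its blocks allocated but zero tokens scheduled, parks in WAITING_FOR_REMOTE_KVS, and only becomes schedulable again once the worker reports the receive as finished. A rough single-request sketch of that flow, with invented names and no budget or preemption handling:

import enum


class Status(enum.Enum):
    WAITING = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()


class SketchScheduler:
    def __init__(self) -> None:
        self.finished_recving: set = set()

    def step_waiting(self, req_id: str, status: Status,
                     needs_remote_kvs: bool):
        """Return (new_status, num_scheduled_tokens) for one waiting request."""
        if status is Status.WAITING_FOR_REMOTE_KVS:
            if req_id in self.finished_recving:
                # KVs landed: cache the blocks elsewhere, make it schedulable.
                self.finished_recving.discard(req_id)
                return Status.WAITING, 0
            return Status.WAITING_FOR_REMOTE_KVS, 0   # keep waiting
        if needs_remote_kvs:
            # Blocks get allocated (uncached) elsewhere; schedule zero tokens.
            return Status.WAITING_FOR_REMOTE_KVS, 0
        return Status.RUNNING, 1   # schedule at least one token


sched = SketchScheduler()
status, n = sched.step_waiting("req-1", Status.WAITING, needs_remote_kvs=True)
assert status is Status.WAITING_FOR_REMOTE_KVS and n == 0
sched.finished_recving.add("req-1")   # the worker reports the recv finished
status, n = sched.step_waiting("req-1", status, needs_remote_kvs=True)
assert status is Status.WAITING and n == 0   # schedulable on the next pass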
- assert request.status != RequestStatus.PREEMPTED - assert self.connector is not None - + if (request.do_remote_prefill and + num_external_tokens > 0): # Schedule 0 tokens until the recv is done. num_new_tokens = 0 @@ -391,14 +352,46 @@ def schedule(self) -> SchedulerOutput: computed_blocks, skip_cache_blocks=True) if new_blocks is None: - # Request cannot be scheduled. + # Blocked cannot be allocated. break - self.recving_KV_req_ids.add(request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue - # TODO: clean up code + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed reqs, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + else: encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens + num_external_tokens, + computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break + # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. @@ -756,9 +749,8 @@ def update_from_output( # inside AsyncLLM. if request.do_remote_decode and not stopped: request.status = RequestStatus.FINISHED_REMOTE_DECODE - self.sending_KV_req_ids.add(req_id) + self._free_request(request, skip_free_blocks=True) # TODO(rob): do this on a per-Connector basis. - # From POV of DWorker, this is a remote prefill. kv_transfer_params = KVTransferParams( do_remote_prefill=True, # put the remote block ids here @@ -790,14 +782,9 @@ def update_from_output( # P/D: update recv and send status from last step. for req_id in (model_runner_output.finished_recving or []): - # TODO(rob): Implement this method. - # Cache blocks for APC after KVs have been recv'ed. 
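A simplified sketch of the per-step P/D bookkeeping implied by the update_from_output changes above: a request that finishes with remote decode keeps its blocks pinned until the transfer out completes, while finished receives are queued for the scheduler to pick up. Names are illustrative, not the vLLM Scheduler API:

# Illustrative sketch of per-step P/D bookkeeping (invented names; "blocks"
# stands in for the request's KV-cache block table).
class SketchPDState:
    def __init__(self) -> None:
        self.blocks: dict = {}                 # req_id -> block ids
        self.sending: set = set()              # blocks pinned until sent
        self.finished_recving: set = set()     # ready to be rescheduled

    def finish_for_remote_decode(self, req_id: str) -> None:
        # The request is done locally, but its blocks must stay alive until
        # the decode instance has pulled the KVs.
        self.sending.add(req_id)

    def update_from_output(self, finished_sending, finished_recving) -> None:
        for req_id in finished_recving:
            # The scheduler will cache the blocks and reschedule this request.
            self.finished_recving.add(req_id)
        for req_id in finished_sending:
            # The transfer out completed, so the blocks can be released.
            self.sending.discard(req_id)
            self.blocks.pop(req_id, None)


state = SketchPDState()
state.blocks["prefill-req"] = [0, 1, 2]
state.finish_for_remote_decode("prefill-req")
state.update_from_output(finished_sending=["prefill-req"], finished_recving=[])
assert "prefill-req" not in state.blocks and not state.sending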
- # self.kv_cache_manager.cache_blocks(req_id) - self.scheduled_req_ids.remove(req_id) - self.recving_KV_req_ids.remove(req_id) - print(f"{self.requests[req_id].num_computed_tokens=}") + self.finished_recving_KV_req_ids.add(req_id) for req_id in (model_runner_output.finished_sending or []): - self._free_request(self.requests[req_id]) + self._free_blocks(self.requests[req_id]) self.running = new_running engine_core_outputs = EngineCoreOutputs( @@ -847,16 +834,25 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> None: + def _free_request(self, request: Request, + skip_free_blocks: bool = False) -> None: assert request.is_finished() - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] - self.sending_KV_req_ids.discard(request.request_id) self.finished_req_ids.add(request.request_id) + if not skip_free_blocks: + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 60b004dc0b2d..bd0d57ee4c8e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -159,6 +159,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered From b4609a5d02078b073c2de6ee3f79be59165c20c1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 24 Apr 2025 17:24:42 -0500 Subject: [PATCH 055/119] update Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_scheduler.py | 123 ++++++++++-------- tests/v1/kv_connector/utils.py | 2 +- .../kv_connector/v1/nixl_connector.py | 18 ++- vllm/v1/core/kv_cache_manager.py | 48 ++++--- vllm/v1/core/sched/scheduler.py | 37 +++--- 5 files changed, 135 insertions(+), 93 deletions(-) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index dd5727c46717..f34072e3b678 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -1,19 +1,20 @@ import copy from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import Request, RequestStatus from .utils import create_request, create_scheduler, create_vllm_config # SPDX-License-Identifier: Apache-2.0 -def test_basic_remote_prefill(): +def test_single_remote_prefill(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - NUM_TOKENS = 16 + NUM_TOKENS = 32 request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -22,72 +23,84 @@ def test_basic_remote_prefill(): request_id = request.request_id # STEP (1): - # Remote Prefill: req should be scheduled with 0 tokens - # but have the entire prompt "computed" from the POV of - # the scheduler + persistent batch (since 
the KVConnector - # will write directly into allocated blocks). + # (1a): schedule() scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 1 - assert len(scheduler.recving_KV_req_ids) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 1 - - scheduled_req = scheduler_output.scheduled_new_reqs[0] - assert scheduled_req.num_computed_tokens == NUM_TOKENS - 1 - assert scheduler_output.num_scheduled_tokens[scheduled_req.req_id] == 0 - # Blocks should not be cached until the KVs are recv, - # but they should be touched so that they are not preempted. + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed + # or scheduled tokens. + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. block_pool = scheduler.kv_cache_manager.block_pool - assert len(block_pool.cached_block_hash_to_block) == 0 assert (block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE) - assert request_id not in scheduler.kv_cache_manager.num_cached_block + assert len(block_pool.cached_block_hash_to_block) == 0 + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() engine_core_outputs = scheduler.update_from_output( - scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) - # Request should still be in the running and recving state. - assert len(scheduler.running) == 1 - assert len(scheduler.recving_KV_req_ids) == 1 + scheduler_output, model_runner_output) assert len(engine_core_outputs.outputs) == 0 # STEP (2): - # Remote Prefill: req should be running. + # (2a): schedule(): nothing happens! scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 1 - assert len(scheduler.recving_KV_req_ids) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy( + EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] - model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - model_runner_output.finished_recving.append(request_id) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_KV_req_ids) - # Request should be out of the recving state. + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 - assert len(scheduler.recving_KV_req_ids) == 0 - assert len(engine_core_outputs.outputs) == 0 - # TODO(rob): once we support caching, we should check that the - # blocks are cached here. - # STEP (3): - # Remote Prefill: the request should now have scheduled tokens. 
- scheduler_output = scheduler.schedule() - assert (len(scheduler_output.scheduled_cached_reqs)) == 1 - assert (scheduler_output.num_scheduled_tokens[request_id]) == 1 - - # req_to_index = { - # request.request_id: i - # for i, request in enumerate(requests) - # } - # model_runner_output = ModelRunnerOutput( - # req_ids=[request.request_id for request in requests], - # req_id_to_index=req_to_index, - # # Only the first request has a sampled token id because - # # the rest requests are still being prefilled. - # sampled_token_ids=[[0], [], []], - # spec_token_ids=None, - # logprobs=None, - # prompt_logprobs_dict={}, - # ) + # # Request should be out of the recving state. + # assert len(scheduler.running) == 1 + # assert len(scheduler.recving_KV_req_ids) == 0 + # assert len(engine_core_outputs.outputs) == 0 + + # # TODO(rob): once we support caching, we should check that the + # # blocks are cached here. + + # # STEP (3): + # # Remote Prefill: the request should now have scheduled tokens. + # scheduler_output = scheduler.schedule() + # assert (len(scheduler_output.scheduled_cached_reqs)) == 1 + # assert (scheduler_output.num_scheduled_tokens[request_id]) == 1 + + # # req_to_index = { + # # request.request_id: i + # # for i, request in enumerate(requests) + # # } + # # model_runner_output = ModelRunnerOutput( + # # req_ids=[request.request_id for request in requests], + # # req_id_to_index=req_to_index, + # # # Only the first request has a sampled token id because + # # # the rest requests are still being prefilled. + # # sampled_token_ids=[[0], [], []], + # # spec_token_ids=None, + # # logprobs=None, + # # prompt_logprobs_dict={}, + # # ) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 1aa42cfafd89..f8a82c6bcb23 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -40,7 +40,7 @@ def create_vllm_config( gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", - enable_prefix_caching=False, + enable_prefix_caching=True, ) kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7691d2fd1224..f607c4131fd0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -146,6 +146,7 @@ class NixlConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id # Requests that need to start recv. @@ -156,11 +157,20 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> int: """For remote prefill, allocate for all tokens.""" + + # NOTE: this function is called in the WAITING loop. + # So we should only have full blocks of computed tokens. + assert num_computed_tokens % self.block_size == 0 + if request.do_remote_prefill: - # Subtract 1 since we do not compute the last prompt - # token so that we can sample the first token here. - num_external_tokens = len(request.prompt_token_ids) - 1 - return num_external_tokens - num_computed_tokens + # NOTE: subtract 1 since we compute the last token + # here so that we can sample the first token. + num_prompt_tokens = len(request.prompt_token_ids) - 1 + + # Round down to a full block shape. 
+ num_external_blocks = num_prompt_tokens // self.block_size + rounded_num_prompt_tokens = num_external_blocks * self.block_size + return max(rounded_num_prompt_tokens - num_computed_tokens, 0) else: return 0 diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b634c2bc14b6..ffab0f1c38a9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -166,10 +166,6 @@ def get_computed_blocks( num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens - def cache_blocks(self, request: Request): - # TODO: implement this. - pass - def allocate_slots( self, request: Request, @@ -283,28 +279,47 @@ def allocate_slots( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if not self.enable_caching: + if not self.enable_caching or skip_cache_blocks: + # If self.enable_caching, this is true since can only + # get to this codepath when we have never been scheduled. + assert request.request_id not in self.num_cached_block return new_blocks + self.cache_blocks( + request=request, + num_tokens=num_tokens, + num_computed_tokens=num_computed_tokens, + new_computed_blocks=new_computed_blocks, + ) + return new_blocks + + def cache_blocks( + self, + request: Request, + num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: Optional[list[KVCacheBlock]] = None, + ): + if new_computed_blocks is None: + new_computed_blocks = [] + + req_blocks = self.req_to_blocks[request.request_id] + # Use `new_computed_blocks` for a new request, and `num_cached_block` # for a running request. num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) - - # Skip performing the actual caching - # This is useful for P/D such that we do not prematurely cache - # blocks which are being filled over multiple steps. - if skip_cache_blocks: - self.num_cached_block[ - request.request_id] = num_cached_blocks - return new_blocks - # Speculated tokens might be rejected in the future, so we does + # Speculated tokens might be rejected in the future, so we do # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. - num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size + num_full_blocks_after_append = ( + num_computed_tokens + num_tokens - len(request.spec_token_ids)) // self.block_size + print(f"{req_blocks=}") + print(f"{self.req_to_block_hashes[request.request_id]=}") + print(f"{num_cached_blocks=}") + print(f"{num_full_blocks_after_append=}") self.block_pool.cache_full_blocks( request=request, blocks=req_blocks, @@ -317,7 +332,6 @@ def allocate_slots( self.num_cached_block[ request.request_id] = num_full_blocks_after_append - return new_blocks def free(self, request: Request) -> None: """Free the blocks allocated for the request. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 80710993a5fd..96438deec3c5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -294,15 +294,24 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + # Skip request if the remote KV recv is still waiting + # for the requests to arrive. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - if request.request_id in self.finished_recving_req_ids: - # We delayed caching the blocks until after they - # are recved to avoid cache hits from other reqs - # before the KVs are written. 
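The round-down logic added above is worth sanity-checking in isolation: the last prompt token is left to be computed locally so the first output token can be sampled, and the remainder is floored to whole blocks before subtracting what is already computed. A self-contained restatement of that arithmetic, assuming the same semantics as the diff:

def num_new_matched_tokens(num_prompt_tokens: int, num_computed_tokens: int,
                           block_size: int) -> int:
    # Keep the last prompt token local so the first output token can be
    # sampled here, then round what remains down to whole blocks.
    usable = num_prompt_tokens - 1
    rounded = (usable // block_size) * block_size
    return max(rounded - num_computed_tokens, 0)


# 40-token prompt, 16-token blocks: 39 usable tokens -> 32 loadable remotely.
assert num_new_matched_tokens(40, 0, 16) == 32
# With 16 tokens already cached locally, only one more full block is needed.
assert num_new_matched_tokens(40, 16, 16) == 16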
- self.kv_cache_manager.cache_blocks(request) - # TODO: how can we do a better job with this? - request.num_computed_tokens = len(request.all_token_ids) - 1 + if request.request_id in self.finished_recving_KV_req_ids: + assert self.kv_cache_manager.enable_caching + # Now that the KVs have been recved, we can cache + # them and set num_computed_tokens. + self.kv_cache_manager.cache_blocks( + request, + num_tokens=0, + num_computed_tokens=(len(request.all_token_ids) - 1) + ) request.status = RequestStatus.WAITING + self.kv_cache_manager.free(request) + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue # Skip request if the structured output request is still waiting # for FSM compilation. @@ -335,15 +344,12 @@ def schedule(self) -> SchedulerOutput: 0 if self.connector is None else self.connector.get_num_new_matched_tokens( request, num_computed_tokens)) + print(f"{num_external_tokens=}") # Total computed tokens (local + external). num_computed_tokens += num_external_tokens - if (request.do_remote_prefill and - num_external_tokens > 0): - # Schedule 0 tokens until the recv is done. - num_new_tokens = 0 - + if (request.do_remote_prefill and num_external_tokens > 0): # Allocate slots for the external tokens, but skip # caching until after the KV transfer is done. new_blocks = self.kv_cache_manager.allocate_slots( @@ -352,8 +358,9 @@ def schedule(self) -> SchedulerOutput: computed_blocks, skip_cache_blocks=True) if new_blocks is None: - # Blocked cannot be allocated. + # Requests cannot be scheduled break + self.waiting.popleft() skipped_waiting_requests.appendleft(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS @@ -842,9 +849,7 @@ def _free_request(self, request: Request, self.finished_req_ids.add(request.request_id) if not skip_free_blocks: - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) - del self.requests[request.request_id] + self._free_blocks(request) def _free_blocks(self, request: Request): assert request.is_finished() From 5d78ba6091d83c692573c0936ecf7b902c481658 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 25 Apr 2025 08:14:10 -0500 Subject: [PATCH 056/119] updaed Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_scheduler.py | 20 +++++++++++++++----- tests/v1/kv_connector/utils.py | 2 +- vllm/v1/core/kv_cache_manager.py | 4 ---- vllm/v1/core/sched/scheduler.py | 1 + 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_scheduler.py index f34072e3b678..747e28b6d069 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ b/tests/v1/kv_connector/test_scheduler.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import copy from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT @@ -5,16 +6,16 @@ from .utils import create_request, create_scheduler, create_vllm_config -# SPDX-License-Identifier: Apache-2.0 - - def test_single_remote_prefill(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - + + # 2 and a half full external blocks. 
+ NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(vllm_config.cache_config.block_size * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - NUM_TOKENS = 32 + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -45,6 +46,8 @@ def test_single_remote_prefill(): assert (block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE) assert len(block_pool.cached_block_hash_to_block) == 0 + for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: + assert block._block_hash is None # (1b): forward() model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT @@ -74,6 +77,13 @@ def test_single_remote_prefill(): # (3a): schedule(): this should actually schedule. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. + num_hashed_blocks = 0 + for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS # # Request should be out of the recving state. diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index f8a82c6bcb23..6ccdb7662d74 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -114,7 +114,7 @@ def create_request( return Request( request_id=f"id-{request_id}", prompt=None, - prompt_token_ids=[request_id] * num_tokens, + prompt_token_ids=[i * request_id for i in range(num_tokens)], sampling_params=sampling_params, multi_modal_inputs=None, multi_modal_placeholders=None, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ffab0f1c38a9..4c0c56a3a96f 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -316,10 +316,6 @@ def cache_blocks( num_full_blocks_after_append = ( num_computed_tokens + num_tokens - len(request.spec_token_ids)) // self.block_size - print(f"{req_blocks=}") - print(f"{self.req_to_block_hashes[request.request_id]=}") - print(f"{num_cached_blocks=}") - print(f"{num_full_blocks_after_append=}") self.block_pool.cache_full_blocks( request=request, blocks=req_blocks, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 96438deec3c5..550ecfb9267c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -307,6 +307,7 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens=(len(request.all_token_ids) - 1) ) request.status = RequestStatus.WAITING + print(f"{self.kv_cache_manager.req_to_blocks[request.request_id]=}") self.kv_cache_manager.free(request) else: self.waiting.popleft() From c1f26b96c75d9b7cfc15d3e31a00d74386d6eef0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 25 Apr 2025 08:19:45 -0500 Subject: [PATCH 057/119] updated Signed-off-by: Robert Shaw --- .../{test_scheduler.py => test_nixl_connector.py} | 6 ++---- vllm/v1/core/sched/scheduler.py | 2 -- 2 files changed, 2 insertions(+), 6 deletions(-) rename tests/v1/kv_connector/{test_scheduler.py => test_nixl_connector.py} (97%) diff --git a/tests/v1/kv_connector/test_scheduler.py b/tests/v1/kv_connector/test_nixl_connector.py similarity index 97% rename from tests/v1/kv_connector/test_scheduler.py rename to tests/v1/kv_connector/test_nixl_connector.py index 747e28b6d069..cc07c35aa9e0 100644 --- a/tests/v1/kv_connector/test_scheduler.py +++ 
b/tests/v1/kv_connector/test_nixl_connector.py @@ -85,15 +85,13 @@ def test_single_remote_prefill(): num_hashed_blocks += (1 if block._block_hash is not None else 0) assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS - + # Confirm the rest of the prompt is scheduled. + print(f"{scheduler_output=}") # # Request should be out of the recving state. # assert len(scheduler.running) == 1 # assert len(scheduler.recving_KV_req_ids) == 0 # assert len(engine_core_outputs.outputs) == 0 - # # TODO(rob): once we support caching, we should check that the - # # blocks are cached here. - # # STEP (3): # # Remote Prefill: the request should now have scheduled tokens. # scheduler_output = scheduler.schedule() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 550ecfb9267c..62c0b95c111a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -307,7 +307,6 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens=(len(request.all_token_ids) - 1) ) request.status = RequestStatus.WAITING - print(f"{self.kv_cache_manager.req_to_blocks[request.request_id]=}") self.kv_cache_manager.free(request) else: self.waiting.popleft() @@ -345,7 +344,6 @@ def schedule(self) -> SchedulerOutput: 0 if self.connector is None else self.connector.get_num_new_matched_tokens( request, num_computed_tokens)) - print(f"{num_external_tokens=}") # Total computed tokens (local + external). num_computed_tokens += num_external_tokens From 9b9ef3696cf1e3b70033c4df22a2287e24673718 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 25 Apr 2025 11:58:21 -0500 Subject: [PATCH 058/119] updated Signed-off-by: Robert Shaw --- ...or.py => test_remote_prefill_lifecycle.py} | 119 +++++++++++++----- tests/v1/kv_connector/utils.py | 12 +- 2 files changed, 100 insertions(+), 31 deletions(-) rename tests/v1/kv_connector/{test_nixl_connector.py => test_remote_prefill_lifecycle.py} (50%) diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py similarity index 50% rename from tests/v1/kv_connector/test_nixl_connector.py rename to tests/v1/kv_connector/test_remote_prefill_lifecycle.py index cc07c35aa9e0..fbda09c66a42 100644 --- a/tests/v1/kv_connector/test_nixl_connector.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -2,17 +2,22 @@ import copy from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import RequestStatus from .utils import create_request, create_scheduler, create_vllm_config def test_single_remote_prefill(): + """ + Test that the request lifecycle for a remote prefill + works as expected. + """ vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - # 2 and a half full external blocks. + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 - NUM_TOKENS = int(vllm_config.cache_config.block_size * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) @@ -85,30 +90,84 @@ def test_single_remote_prefill(): num_hashed_blocks += (1 if block._block_hash is not None else 0) assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS - # Confirm the rest of the prompt is scheduled. - print(f"{scheduler_output=}") - # # Request should be out of the recving state. 
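With a prompt sized at two and a half blocks, as in the test above, only the two full blocks can ever be hashed into the prefix cache once the KVs arrive; the trailing partial block stays unhashed. A quick standalone check of that arithmetic (the block size is an assumed value for illustration, and the minus one mirrors the last token being computed locally):

BLOCK_SIZE = 16                      # assumed block size for illustration
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))   # 40 tokens

num_allocated_blocks = -(-NUM_TOKENS // BLOCK_SIZE)    # ceil division -> 3
# The last token is computed locally, so floor the rest to full blocks -> 2.
num_hashable_blocks = (NUM_TOKENS - 1) // BLOCK_SIZE

assert num_allocated_blocks == NUM_EXTERNAL_FULL_BLOCKS + 1
assert num_hashable_blocks == NUM_EXTERNAL_FULL_BLOCKS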
- # assert len(scheduler.running) == 1 - # assert len(scheduler.recving_KV_req_ids) == 0 - # assert len(engine_core_outputs.outputs) == 0 - - # # STEP (3): - # # Remote Prefill: the request should now have scheduled tokens. - # scheduler_output = scheduler.schedule() - # assert (len(scheduler_output.scheduled_cached_reqs)) == 1 - # assert (scheduler_output.num_scheduled_tokens[request_id]) == 1 - - # # req_to_index = { - # # request.request_id: i - # # for i, request in enumerate(requests) - # # } - # # model_runner_output = ModelRunnerOutput( - # # req_ids=[request.request_id for request in requests], - # # req_id_to_index=req_to_index, - # # # Only the first request has a sampled token id because - # # # the rest requests are still being prefilled. - # # sampled_token_ids=[[0], [], []], - # # spec_token_ids=None, - # # logprobs=None, - # # prompt_logprobs_dict={}, - # # ) + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + +def test_remote_prefill_no_prefix_cache_uncomputed_blocks(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + request_local = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + do_remote_prefill=False, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output( + scheduler_output, + EMPTY_MODEL_RUNNER_OUTPUT + ) + assert len(scheduler.waiting) == 1 + + # Schedule the local prefill request. This should + # cause blocks to be cached, but separately from + scheduler.add_request(request_local) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + local_blocks = scheduler.kv_cache_manager.req_to_blocks[request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.req_to_blocks[request_remote.request_id] + + # Local should have cached blocks (but not all due to preallocate). + num_hashed_blocks = 0 + for block in local_blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += ( + 1 if block._block_hash is not None else 0) + assert num_hashed_blocks > 0 + + # Remote blocks should not be cached. 
+ for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_remote_prefill_no_blocks_available(): + """ + letTest whether we properly handle no blocks available + """ + pass \ No newline at end of file diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 6ccdb7662d74..17a24f590392 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -94,6 +94,7 @@ def create_request( max_tokens: int = 16, do_remote_decode: bool = False, do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, ) -> list[Request]: if do_remote_decode: assert not do_remote_prefill @@ -111,10 +112,19 @@ def create_request( max_tokens=max_tokens, kv_transfer_params=kv_transfer_params, ) + + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [ + i * request_id for i in range(num_tokens) + ] + return Request( request_id=f"id-{request_id}", prompt=None, - prompt_token_ids=[i * request_id for i in range(num_tokens)], + prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, multi_modal_inputs=None, multi_modal_placeholders=None, From c60639e60dd3f0d3fc3c9fafe6d807d8ead6b669 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 25 Apr 2025 15:36:05 -0400 Subject: [PATCH 059/119] Checkpoint. Signed-off-by: Tyler Michael Smith --- .../openai_completion_client.py | 2 +- .../kv_connector/v1/nixl_connector.py | 163 +++++++++++++++--- 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 7917ac4797b5..b31ebcccce3b 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -4,7 +4,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:9000/v1" +openai_api_base = "http://localhost:8192/v1" def main(): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e905bc537789..bc3960a58aa6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import time import uuid from collections import defaultdict from typing import TYPE_CHECKING @@ -81,11 +82,11 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): if role == KVConnectorRole.SCHEDULER: self.connector_scheduler = NixlConnectorScheduler( - vllm_config, self.engine_id) + vllm_config, str(self.engine_id)) self.connector_worker = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(self.engine_id) + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) ############################################################ # Scheduler Side Methods @@ -241,8 +242,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # [2 (k and v), num_blocks, ...] 
_, num_blocks, block_size, num_heads, head_dim = first_kv_cache.shape - self.block_len = block_size * num_heads * head_dim * first_kv_cache[ - 0].element_size() + kv_elem_size = first_kv_cache[0].element_size() + self.block_len = block_size * num_heads * head_dim * kv_elem_size logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) self.num_layers = len(kv_caches) self.num_blocks = num_blocks @@ -266,6 +267,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) # THIS IS FOR DEBUG and INSECURE @@ -273,6 +276,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _ctx = zmq.Context() # type: ignore _side_channel = _ctx.socket(zmq.PAIR) # type: ignore NIXL_ROLE = os.getenv("NIXL_ROLE") + + # For debug, SENDER puts some stuff in the KV caches + # so the RECVER can check it + n_blocks_to_send = 4096 + debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1024 / 1024 / 1024 + print(f"gb {debug_xfer_gb} -- block_len {self.block_len}") + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 100.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 200.0 + for b in range(5): + print( + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" + ) + print( + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" + ) + remote_engine_id = None # HACK for debug send + if NIXL_ROLE == "SENDER": _side_channel.connect("tcp://localhost:5555") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore @@ -283,6 +305,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): num_blocks=self.num_blocks, ) encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug( + f"Size of encoded NixlAgentMetadata: {size_in_bytes} bytes") _side_channel.send(encoder.encode(metadata)) logger.debug("WAITING ON RECV") @@ -295,6 +321,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata_bytes = _side_channel.recv() metadata = decoder.decode(metadata_bytes) + + remote_engine_id = metadata.engine_id #HACK + + logger.debug(f"Adding remote {metadata}") self.add_remote_agent(metadata) print("SENDING ACK") _side_channel.send(b"ack") @@ -302,6 +332,73 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): else: raise Exception("SET NIXL_ROLE to SENDER OR RECVER") + # FOR DEBUG: try to send some shit + + if NIXL_ROLE == "RECVER": + logger.debug("Sending blocks") + metadata = NixlConnectorMetadata() + assert remote_engine_id is not None + xfer_params = KVTransferParams( + do_remote_decode=True, + do_remote_prefill=False, + remote_block_ids=list(range(n_blocks_to_send)), + remote_engine_id=remote_engine_id #HACK + ) + + metadata.add_new_req(req_id="tms", + block_ids=list(range(n_blocks_to_send)), + kv_transfer_params=xfer_params) + self.start_load_kv(metadata) + + # Wait for Receive to complete + logger.debug("TMS START RECEIVE XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + # NOTE: Should fix discrepancy between bytes/str finished sets + # Here we have str. For sender we have bytes. 
+ done = "tms" in finished[1] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Received. " + f"Duration: {1e3 * execution_time:.3f} ms " + f"Bandwidth: {debug_xfer_gb / execution_time:.3f} GB/s") + + if NIXL_ROLE == "SENDER": + # Wait for Send to complete + logger.debug("TMS START SEND XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + # NOTE: Should fix discrepancy between bytes/str finished sets + # Here we have bytes. For receiver we have str. + done = b'tms' in finished[0] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Sent. " + f"Duration: {1e3 * execution_time:.3f} ms " + f"Bandwidth: {debug_xfer_gb / execution_time:.3f} GB/s") + + # Put some different stuff in there + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 300.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 400.0 + + for b in range(5): + print( + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" + ) + print( + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" + ) + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): engine_id = nixl_agent_meta.engine_id if engine_id in self._remote_agents: @@ -330,9 +427,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): blocks_data = [] for layer_id in range(self.num_layers): # Both K and V. - print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") - print(f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}") - print(f"{self.kv_caches_base_addr[self.engine_id][layer_id]=}") + # print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") + # print(f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}") + # print(f"{self.kv_caches_base_addr[self.engine_id][layer_id]=}") for base_addr in self.kv_caches_base_addr[ self.engine_id][layer_id]: for block_id in range(self.num_blocks): @@ -360,16 +457,16 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): block_offset = block_id * dst_block_len blocks_data.append( (base_addr + block_offset, dst_block_len, - self.rank * tp_multiplier)) + self.rank * tp_multiplier)) logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, - self.rank * tp_multiplier + i) + len(blocks_data), engine_id, + self.rank * tp_multiplier + i) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") self.dst_xfer_side_handles[engine_id][i] = ( self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][self.rank * tp_multiplier + - i], descs)) + self._remote_agents[engine_id][self.rank * tp_multiplier + i], + descs)) def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" @@ -389,7 +486,7 @@ def _get_new_notifs(self) -> set[str]: notified_req_ids.add(req_id) return notified_req_ids - def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: """ Pop completed xfers by checking for DONE state. 
Args: @@ -398,7 +495,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: set of req_ids that have all done xfers """ done_req_ids: str[str] = set() - for req_id, handles in transfers.items(): + for req_id, handles in list(transfers.items()): running_reqs = [] for handle in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) @@ -413,6 +510,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[str]]) -> set[str]: xfer_state) if len(running_reqs) == 0: done_req_ids.add(req_id) + del transfers[req_id] else: transfers[req_id] = running_reqs return done_req_ids @@ -453,16 +551,19 @@ def _read_blocks( # NOTE(rob): we could potentially do the rearranging during the load_kv! - assert len(local_block_ids) == len(staging_block_ids) == len( - remote_block_ids) + assert len(local_block_ids) == len(remote_block_ids) + assert (staging_block_ids is None + or len(staging_block_ids) == len(remote_block_ids)) if len(local_block_ids) == 0: return # TODO(rob): understand ranges code. local_ranges = self._get_ranges(local_block_ids) - staging_ranges = self._get_ranges(staging_block_ids) - _, staging_rearranging_ranges = self._get_same_length_ranges( - local_ranges, staging_ranges) + + # Note(tms): commenting out staging code + # staging_ranges = self._get_ranges(staging_block_ids) + # _, staging_rearranging_ranges = self._get_same_length_ranges( + # local_ranges, staging_ranges) # TODO: support TP multipliers. tp_multiplier = 1 @@ -472,28 +573,37 @@ def _read_blocks( # Read the data from the remote. for i in range(tp_multiplier): - staging_block_descs_ids = self._get_block_descs_ids( - self.engine_id, + local_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, "all", - staging_block_ids, - i=i, + local_block_ids, + i=None, #TODO: Enable both tp_multiplier and staging_ranges. tp_multiplier=tp_multiplier, - staging_ranges=staging_rearranging_ranges) - assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + staging_ranges=None) + assert len(local_block_descs_ids) == len(remote_block_descs_ids) remote_xfer_side_handle = self.dst_xfer_side_handles[ dst_engine_id][i] # NOTE(rob): we use the request_id as the notify msg, so we # must use the same request_id in both the p and d workers. + logger.debug( + f"Making prepped xfer {local_xfer_side_handle} - {remote_xfer_side_handle}." + ) handle = self.nixl_wrapper.make_prepped_xfer( "READ", local_xfer_side_handle, - staging_block_descs_ids, + local_block_descs_ids, remote_xfer_side_handle, remote_block_descs_ids, notif_msg=request_id, ) # NOTE(rob): we will check this is done in the next forward pass. + + # Note: without this, the request handle's backendHandle won't be set. + # and we will fail during check_xfer_state + self.nixl_wrapper.transfer(handle) + logger.debug(f"Made prepped xfer {request_id} - {handle}.") + self._recving_transfers[request_id].append(handle) # NOTE(rob): this is actually pretty serious problem. 
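The transfer-completion path above repeatedly polls each request's transfer handles and drops the request once every handle reports DONE. A stripped-down version of that pattern, with a stubbed state checker standing in for the NIXL wrapper:

import itertools
from collections import defaultdict

# Stub for the wrapper's check_xfer_state: pretend each poll moves the
# transfers closer to completion.
_poll_counter = itertools.count()


def check_xfer_state(handle: int) -> str:
    return "DONE" if next(_poll_counter) >= 3 else "PROC"


def pop_done_transfers(transfers: dict) -> set:
    """Remove and return req_ids whose transfer handles are all DONE."""
    done_req_ids = set()
    for req_id, handles in list(transfers.items()):
        still_running = [h for h in handles if check_xfer_state(h) != "DONE"]
        if still_running:
            transfers[req_id] = still_running
        else:
            done_req_ids.add(req_id)
            del transfers[req_id]
    return done_req_ids


transfers = defaultdict(list, {"req-1": [101, 102]})
while "req-1" in transfers:        # poll until both handles report DONE
    pop_done_transfers(transfers)
assert not transfers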
@@ -557,6 +667,7 @@ def _get_block_descs_ids(self, start_offset * tp_multiplier + i_offset + (block_id - start_offset)) else: + logger.debug(f"engine_id: {engine_id}") num_blocks = self.dst_num_blocks[engine_id] for layer_id in layer_ids: for is_value in [0, 1]: From c5e023e2652a23b0641cd7e41714f63c0d7c7680 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 25 Apr 2025 16:42:57 -0500 Subject: [PATCH 060/119] updated Signed-off-by: Robert Shaw --- .../test_remote_prefill_lifecycle.py | 122 +++++++++++++++++- .../kv_connector/v1/nixl_connector.py | 2 - 2 files changed, 115 insertions(+), 9 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py index fbda09c66a42..a951871b9146 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 import copy +from typing import Optional -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import RequestStatus +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.request import RequestStatus, Request from .utils import create_request, create_scheduler, create_vllm_config -def test_single_remote_prefill(): - """ - Test that the request lifecycle for a remote prefill - works as expected. - """ +def test_basic_remote_prefill_cycle(): + """Test Remote Prefills Lifecycle.""" + vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -98,6 +97,115 @@ def test_single_remote_prefill(): assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) +def make_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, +) -> ModelRunnerOutput: + req_ids = [req.request_id for req in reqs] + req_id_to_index = { + req_id: idx for idx, req_id in enumerate(req_ids) + } + sampled_token_ids = [[0] for _ in req_ids] + + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + +def test_interleaved_remote_prefill_cycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True + ) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = make_model_runner_output( + [request_local_a]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 2: Add a local and remote request. 
+ scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = make_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = make_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = make_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id] + ) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + def test_remote_prefill_no_prefix_cache_uncomputed_blocks(): """ With P/D, blocks can be allocated but uncomputed for diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f607c4131fd0..0cef64cecd5a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -276,8 +276,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): (key_cache.data_ptr(), value_cache.data_ptr())) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr - print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") - print(f"{self.kv_caches_base_addr[self.engine_id][0]=}") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) From 8b0c93cd9a25285a20126527fde906c13ac2610d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 26 Apr 2025 19:34:34 +0000 Subject: [PATCH 061/119] Cleanup Signed-off-by: Tyler Michael Smith --- .../kv_connector/v1/nixl_connector.py | 207 ++++++------------ 1 file changed, 67 insertions(+), 140 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index bc3960a58aa6..089bd04c734c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2,11 +2,12 @@ import time import uuid from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import msgspec import torch import zmq 
+from typing_extensions import Optional from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -40,7 +41,7 @@ class NixlAgentMetadata( agent_metadata: bytes # Base addr for each layer for KVs # NOTE: we will need another list for TP>1 - kv_caches_base_addr: list[tuple[int, int]] + kv_caches_base_addr: list[int] num_blocks: int @@ -50,7 +51,7 @@ def __init__( self, block_ids: list[int], remote_block_ids: list[int], - remote_engine_id: list[int], + remote_engine_id: str, ): self.block_ids = block_ids self.remote_block_ids = remote_block_ids @@ -81,9 +82,9 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.engine_id = uuid.uuid4() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = NixlConnectorScheduler( - vllm_config, str(self.engine_id)) - self.connector_worker = None + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None self.connector_worker = NixlConnectorWorker(str(self.engine_id)) @@ -159,6 +160,8 @@ def get_num_new_matched_tokens(self, request: "Request", """For remote prefill, allocate for all tokens.""" if request.do_remote_prefill: return len(request.prompt_token_ids) - num_computed_tokens + else: + return 0 def update_state_after_alloc(self, request: "Request", num_external_tokens: int): @@ -178,8 +181,8 @@ def build_connector_meta( req = self._reqs_need_recv.pop(new_req.req_id, None) if req is not None: meta.add_new_req( - request_id=new_req.req_id, - local_block_ids=new_req.block_ids, + req_id=new_req.req_id, + block_ids=new_req.block_ids, kv_transfer_params=req.kv_transfer_params, ) @@ -209,16 +212,15 @@ def __init__(self, engine_id: str): # KV Caches and nixl tracking data. self.num_layers: int = 0 - self.num_layers: int = 0 self.num_heads: int = 0 - self.kv_caches: tuple[torch.Tensor, torch.Tensor] = None + self.kv_caches: dict[str, torch.Tensor] = {} # Map of engine_id -> kv_caches_base_addr # For Local: base addr for *this* rank, each layer for K,V # For Remote: base addr for *each* rank, each layer for K,V # KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]], # list[list[tuple[int, int]]]] - self.kv_caches_base_addr: dict[str, any] = {} + self.kv_caches_base_addr: dict[str, list[int]] = {} # Map of tp_mult -> nixl_prepped_dlist_handle (int). self.src_xfer_side_handles: dict[int, int] = {} @@ -228,11 +230,11 @@ def __init__(self, engine_id: str): int]] = defaultdict(dict) # Map of engine_id -> num_blocks. self.dst_num_blocks: dict[str, int] = {} - self._registered_descs: list[any] = [] + self._registered_descs: list[Any] = [] # In progress transfers. 
# [req_id -> list[handle]] - self._recving_transfers = defaultdict(list[any]) + self._recving_transfers: dict[str, list[Any]] = defaultdict(list[Any]) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -252,17 +254,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr = [] caches_data = [] for layer_name in kv_caches: - kv_cache = kv_caches[layer_name] - key_cache, value_cache = kv_cache[0], kv_cache[1] - base_addr = key_cache.data_ptr() - region_len = 2 * num_blocks * self.block_len - caches_data.append((base_addr, region_len, self.rank, "")) - kv_caches_base_addr.append( - (key_cache.data_ptr(), value_cache.data_ptr())) - + for cache in kv_caches[layer_name]: + base_addr = cache.data_ptr() + region_len = num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr - print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") - print(f"{self.kv_caches_base_addr[self.engine_id][0]=}") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -280,7 +277,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # For debug, SENDER puts some stuff in the KV caches # so the RECVER can check it n_blocks_to_send = 4096 - debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1024 / 1024 / 1024 + debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9 print(f"gb {debug_xfer_gb} -- block_len {self.block_len}") if NIXL_ROLE == "SENDER": for b in range(n_blocks_to_send): @@ -288,15 +285,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 200.0 for b in range(5): print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" #noqa ) print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" #noqa ) remote_engine_id = None # HACK for debug send if NIXL_ROLE == "SENDER": - _side_channel.connect("tcp://localhost:5555") + _side_channel.connect("tcp://localhost:5577") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore metadata = NixlAgentMetadata( engine_id=self.engine_id, @@ -307,8 +304,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug( - f"Size of encoded NixlAgentMetadata: {size_in_bytes} bytes") + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) _side_channel.send(encoder.encode(metadata)) logger.debug("WAITING ON RECV") @@ -316,7 +313,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("GOT ACK %s", ack) elif NIXL_ROLE == "RECVER": - _side_channel.bind("tcp://localhost:5555") + _side_channel.bind("tcp://localhost:5577") _side_channel.setsockopt(zmq.LINGER, 0) # type: ignore decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata_bytes = _side_channel.recv() @@ -324,7 +321,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): remote_engine_id = metadata.engine_id #HACK - logger.debug(f"Adding remote {metadata}") self.add_remote_agent(metadata) print("SENDING ACK") 
_side_channel.send(b"ack") @@ -336,7 +332,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if NIXL_ROLE == "RECVER": logger.debug("Sending blocks") - metadata = NixlConnectorMetadata() + connector_metadata = NixlConnectorMetadata() assert remote_engine_id is not None xfer_params = KVTransferParams( do_remote_decode=True, @@ -345,10 +341,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): remote_engine_id=remote_engine_id #HACK ) - metadata.add_new_req(req_id="tms", - block_ids=list(range(n_blocks_to_send)), - kv_transfer_params=xfer_params) - self.start_load_kv(metadata) + connector_metadata.add_new_req(req_id="tms", + block_ids=list( + range(n_blocks_to_send)), + kv_transfer_params=xfer_params) + self.start_load_kv(connector_metadata) # Wait for Receive to complete logger.debug("TMS START RECEIVE XFER") @@ -363,9 +360,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): end_time = time.time() execution_time = end_time - start_time logger.debug( - "Transfer Received. " - f"Duration: {1e3 * execution_time:.3f} ms " - f"Bandwidth: {debug_xfer_gb / execution_time:.3f} GB/s") + "Transfer Received. Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) if NIXL_ROLE == "SENDER": # Wait for Send to complete @@ -380,10 +376,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): time.sleep(1e-5) end_time = time.time() execution_time = end_time - start_time - logger.debug( - "Transfer Sent. " - f"Duration: {1e3 * execution_time:.3f} ms " - f"Bandwidth: {debug_xfer_gb / execution_time:.3f} GB/s") + logger.debug("Transfer Sent. Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) # Put some different stuff in there if NIXL_ROLE == "SENDER": @@ -393,13 +387,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for b in range(5): print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" #noqa ) print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" + f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" #noqa ) - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): engine_id = nixl_agent_meta.engine_id if engine_id in self._remote_agents: return @@ -425,23 +419,17 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): if tp_multiplier not in self.src_xfer_side_handles: # Create descs and xfer side handles. blocks_data = [] - for layer_id in range(self.num_layers): - # Both K and V. 
- # print(f"{len(self.kv_caches_base_addr[self.engine_id])=}") - # print(f"{len(self.kv_caches_base_addr[self.engine_id][layer_id])=}") - # print(f"{self.kv_caches_base_addr[self.engine_id][layer_id]=}") - for base_addr in self.kv_caches_base_addr[ - self.engine_id][layer_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - for i in range(tp_multiplier): - tp_multiplier_offset = i * dst_block_len - blocks_data.append((base_addr + block_offset + - tp_multiplier_offset, - dst_block_len, self.rank)) + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for i in range(tp_multiplier): + tp_multiplier_offset = tp_idx * dst_block_len + blocks_data.append( + (base_addr + block_offset + tp_multiplier_offset, + dst_block_len, self.rank)) + print(len(self.kv_caches_base_addr[self.engine_id])) logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, - self.rank * tp_multiplier + i) + len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") @@ -451,22 +439,19 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): # create dst xfer side handles self.dst_num_blocks[engine_id] = num_blocks blocks_data = [] - for layer_id in range(self.num_layers): - for base_addr in self.kv_caches_base_addr[engine_id][layer_id]: - for block_id in range(num_blocks): - block_offset = block_id * dst_block_len - blocks_data.append( - (base_addr + block_offset, dst_block_len, - self.rank * tp_multiplier)) + for base_addr in self.kv_caches_base_addr[engine_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append((base_addr + block_offset, dst_block_len, + self.rank * tp_multiplier)) logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, - self.rank * tp_multiplier + i) + len(blocks_data), engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[engine_id][i] = ( + self.dst_xfer_side_handles[engine_id][tp_idx] = ( self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][self.rank * tp_multiplier + i], - descs)) + self._remote_agents[engine_id][self.rank * tp_multiplier + + tp_idx], descs)) def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" @@ -494,7 +479,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: Returns: set of req_ids that have all done xfers """ - done_req_ids: str[str] = set() + done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): running_reqs = [] for handle in handles: @@ -557,14 +542,6 @@ def _read_blocks( if len(local_block_ids) == 0: return - # TODO(rob): understand ranges code. - local_ranges = self._get_ranges(local_block_ids) - - # Note(tms): commenting out staging code - # staging_ranges = self._get_ranges(staging_block_ids) - # _, staging_rearranging_ranges = self._get_same_length_ranges( - # local_ranges, staging_ranges) - # TODO: support TP multipliers. tp_multiplier = 1 remote_block_descs_ids = self._get_block_descs_ids( @@ -586,9 +563,6 @@ def _read_blocks( # NOTE(rob): we use the request_id as the notify msg, so we # must use the same request_id in both the p and d workers. 
- logger.debug( - f"Making prepped xfer {local_xfer_side_handle} - {remote_xfer_side_handle}." - ) handle = self.nixl_wrapper.make_prepped_xfer( "READ", local_xfer_side_handle, @@ -597,42 +571,13 @@ def _read_blocks( remote_block_descs_ids, notif_msg=request_id, ) - # NOTE(rob): we will check this is done in the next forward pass. - # Note: without this, the request handle's backendHandle won't be set. - # and we will fail during check_xfer_state + # Call transfer to begin the async transfer + # We will check this is done in the next forward pass. self.nixl_wrapper.transfer(handle) - logger.debug(f"Made prepped xfer {request_id} - {handle}.") self._recving_transfers[request_id].append(handle) - # NOTE(rob): this is actually pretty serious problem. - # We need to figure out if we can put the staging blocks on the P worker side. # noqa: E501 - # The staging blocks need to be on the side that sends. - - # for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): # noqa: E501 - # logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", # noqa: E501 - # self.kv_caches[0].shape, local_range, staging_range) - # for kv_cache in self.kv_caches: - # for cache in kv_cache: - # rearrange_tensors(cache[local_range[0]:local_range[1] + 1], # noqa: E501 - # cache[staging_range[0]:staging_range[1] + 1], # noqa: E501 - # tp_multiplier, "read") - - def _get_ranges(self, block_ids): - # This function should return a list of ranges of block ids that are contiguous # noqa: E501 - # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] # noqa: E501 - # The ranges are sorted by the starting block id - # The function should also make sure that the block ids are contiguous - # If the block ids are not contiguous, the function should raise an error # noqa: E501 - ranges = [] - for i in range(len(block_ids)): - if i == 0 or block_ids[i] != block_ids[i - 1] + 1: - ranges.append([block_ids[i], block_ids[i]]) - else: - ranges[-1][1] = block_ids[i] - return ranges - def _get_block_descs_ids(self, engine_id, layer_ids, @@ -649,29 +594,11 @@ def _get_block_descs_ids(self, descs_ids = [] if i is not None: - num_blocks = self.num_blocks - for layer_id in layer_ids: - for is_value in [0, 1]: - staging_range_idx = 0 - for block_id in block_ids: - if block_id > staging_ranges[staging_range_idx][ - 1] or block_id < staging_ranges[ - staging_range_idx][0]: - staging_range_idx += 1 - start_offset = staging_ranges[staging_range_idx][0] - i_offset = i * (staging_ranges[staging_range_idx][-1] - - start_offset + 1) - descs_ids.append( - layer_id * 2 * num_blocks * tp_multiplier + - is_value * num_blocks * tp_multiplier + - start_offset * tp_multiplier + i_offset + - (block_id - start_offset)) + raise NotImplementedError("Prefill and Decode instances must have " + "the same TP size.") else: - logger.debug(f"engine_id: {engine_id}") num_blocks = self.dst_num_blocks[engine_id] - for layer_id in layer_ids: - for is_value in [0, 1]: - for block_id in block_ids: - descs_ids.append(layer_id * 2 * num_blocks + - is_value * num_blocks + block_id) + for layer_id in 2 * layer_ids: + for block_id in block_ids: + descs_ids.append(layer_id * num_blocks + block_id) return descs_ids From 5e45d90fa9ef90d84da4c97d5892829b41f3de2e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 26 Apr 2025 22:26:40 +0000 Subject: [PATCH 062/119] WIP Signed-off-by: Tyler Michael Smith --- Justfile | 44 ++++ start_proxy.py | 200 ++++++++++++++++++ 
.../kv_connector/v1/nixl_connector.py | 25 ++- vllm/sampling_params.py | 2 +- vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/core/sched/scheduler.py | 48 ++++- vllm/v1/worker/gpu_model_runner.py | 6 +- 7 files changed, 311 insertions(+), 20 deletions(-) create mode 100644 Justfile create mode 100644 start_proxy.py diff --git a/Justfile b/Justfile new file mode 100644 index 000000000000..8fd0f0b44a3a --- /dev/null +++ b/Justfile @@ -0,0 +1,44 @@ +notes: + UCX_RNDV_THRESH=0 # Force rendezvous protocol for all messages + UCX_MEMTYPE_CACHE=n # Disable memory type caching + UCX_TLS=rc,ud,dc,cuda_copy,cuda_ipc,gdr_copy # Prioritize RDMA transports + UCX_ZCOPY_THRESH=0 # Force zero-copy for all sizes + +prefill: + UCX_LOG_LEVEL=debug \ + NIXL_ROLE="SENDER" \ + CUDA_VISIBLE_DEVICES=3 \ + VLLM_LOGGING_LEVEL="DEBUG" \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 \ + vllm serve meta-llama/Llama-3.2-1B-Instruct \ + --port 8100 \ + --enforce-eager \ + --load-format dummy \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' + +decode: + UCX_LOG_LEVEL=info \ + NIXL_ROLE="RECVER" \ + CUDA_VISIBLE_DEVICES=4 \ + VLLM_LOGGING_LEVEL="DEBUG" \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 \ + vllm serve meta-llama/Llama-3.2-1B-Instruct \ + --port 8200 \ + --enforce-eager \ + --load-format dummy \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' + +proxy: + python start_proxy.py --port 8192 --prefiller-port 8100 --decoder-port 8200 + +send_request: + curl -X POST http://localhost:8192/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ \ + "model": "meta-llama/Llama-3.2-1B-Instruct", \ + "prompt": "Generate a curl command to send to an openai server hosted at local_host:8192 with this as the prompt", \ + "max_tokens": 150, \ + "temperature": 0.7 \ + }' diff --git a/start_proxy.py b/start_proxy.py new file mode 100644 index 000000000000..4c7af86eea98 --- /dev/null +++ b/start_proxy.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8192) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['do_remote_decode'] = True + req_data["stream"] = False + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": "vllm-d-debug", + } + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": "vllm-d-debug", + } + req_data['do_remote_prefill'] = True + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + print(req_data) + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 089bd04c734c..be2a4761599b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -127,7 +127,8 @@ def get_finished(self) -> tuple[set[str], set[str]]: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None - self.connector_worker.start_load_kv() + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) def wait_for_layer_load(self, layer_name: str) -> None: """NixlConnector does not do layerwise saving.""" @@ -328,6 +329,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): else: raise Exception("SET NIXL_ROLE to SENDER OR RECVER") + # Very, very hacky + self.remote_engine_id = metadata.engine_id + # FOR DEBUG: try to send some shit if NIXL_ROLE == "RECVER": @@ -370,9 +374,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): start_time = time.time() while 
(not done): finished = self.get_finished() - # NOTE: Should fix discrepancy between bytes/str finished sets - # Here we have bytes. For receiver we have str. - done = b'tms' in finished[0] + done = "tms" in finished[0] time.sleep(1e-5) end_time = time.time() execution_time = end_time - start_time @@ -457,6 +459,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: """Get requests that are done sending or recving.""" done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "get_finished: %s requests done sending " + "and %s requests done recving", len(done_sending), + len(done_recving)) return done_sending, done_recving def _get_new_notifs(self) -> set[str]: @@ -468,7 +475,7 @@ def _get_new_notifs(self) -> set[str]: for req_ids in self.nixl_wrapper.get_new_notifs().values(): for req_id in req_ids: assert req_id not in notified_req_ids - notified_req_ids.add(req_id) + notified_req_ids.add(req_id.decode('utf-8')) return notified_req_ids def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: @@ -507,12 +514,16 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.requests.items(): # NOTE: this is non-blocking + logger.debug("start_load_kv for request " + req_id) self._read_blocks( local_block_ids=meta.block_ids, # TODO: support staging once we do heterogeneous TP staging_block_ids=meta.block_ids, - remote_block_ids=meta.remote_block_ids, - dst_engine_id=meta.remote_engine_id, + # DISGUSTING HACKs + #remote_block_ids=meta.remote_block_ids, + #dst_engine_id=meta.remote_engine_id, + remote_block_ids=meta.block_ids, + dst_engine_id=self.remote_engine_id, request_id=req_id, ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4ecebb808c29..a30658b43c9a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -95,7 +95,7 @@ def from_optional( @property def backend_name(self) -> str: """Return the backend name without any options. - + For example if the backend is "xgrammar:no-fallback", returns "xgrammar" """ return (self.backend or "").split(":")[0] diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33bb825a11a7..45110e34d523 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -188,7 +188,7 @@ def allocate_slots( new_computed_blocks: A list of new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. - This is used by spec decode proposers with kv-cache such + This is used by spec decode proposers with kv-cache such as eagle. skip_cache_blocks: Whether to skip cachings the blocks. This is used by P/D when allocating blocks that used in KV transfer @@ -383,7 +383,9 @@ def get_num_common_prefix_blocks( Returns: int: The number of common prefix blocks. 
""" - assert request.status == RequestStatus.RUNNING + assert request.status in [ + RequestStatus.RUNNING, RequestStatus.FINISHED_REMOTE_DECODE + ] blocks = self.req_to_blocks[request.request_id] num_common_blocks = 0 for block in blocks: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 995bc7512e22..3358a66ae740 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -32,6 +32,30 @@ logger = init_logger(__name__) +import sys +import traceback + + +def print_last_3_stack_levels(): + try: + # Either raise an exception or get the current stack + stack = traceback.extract_stack() + # Print only the last 3 levels + for frame in stack[-6:-3]: + print( + f"File: {frame.filename}, Line: {frame.lineno}, Function: {frame.name}" + ) + print(f" {frame.line}") + except Exception: + # Get the exception's traceback + tb_list = traceback.extract_tb(sys.exc_info()[2]) + # Print only the last 3 levels + for frame in tb_list[-3:]: + print( + f"File: {frame.filename}, Line: {frame.lineno}, Function: {frame.name}" + ) + print(f" {frame.line}") + class Scheduler(SchedulerInterface): @@ -433,13 +457,14 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) - # Allocate the encoder cache. - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + if not request.do_remote_prefill: + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -530,7 +555,9 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - self.requests[req_id].num_computed_tokens += num_scheduled_token + if req_id in self.requests: + self.requests[ + req_id].num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output @@ -808,6 +835,7 @@ def update_from_output( def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request + print(f"Adding {request.request_id} to requests") if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) @@ -839,9 +867,13 @@ def finish_requests( else: self.waiting.remove(request) request.status = finished_status + print(f"freeing request {req_id}") self._free_request(request) def _free_request(self, request: Request) -> None: + logger.debug(f"Freeing request {request.request_id}") + print_last_3_stack_levels() + assert request.is_finished() self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c79ecfdfab5d..74fcb61fa8d3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -393,6 +393,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the states of the running/resumed requests. 
for req_data in scheduler_output.scheduled_cached_reqs: req_id = req_data.req_id + if req_id not in self.requests: + print(f"{req_id} {self.requests}") + continue req_state = self.requests[req_id] # Update the cached states. @@ -1013,8 +1016,7 @@ def maybe_setup_kv_connector(): # These transfers are designed to be async and the requests # involved may be disjoint from the running requests. # Do this here to save a collective_rpc. - if get_forward_context().attn_metadata is not None: - kv_connector.start_load_kv(get_forward_context()) + kv_connector.start_load_kv(get_forward_context()) def maybe_wait_for_save(): if has_kv_transfer_group(): From 20a5491bd8b3d77c091194b942fbd0d143316a9d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 08:55:38 -0500 Subject: [PATCH 063/119] updated Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_model_runner.py | 55 ------------------- tests/v1/kv_connector/test_nixl_connector.py | 38 +++++++++++++ ...le.py => test_remote_prefill_scheduler.py} | 0 .../kv_transfer/kv_connector/v1/base.py | 1 + .../kv_connector/v1/nixl_connector.py | 25 ++++----- vllm/v1/core/sched/scheduler.py | 11 ++++ 6 files changed, 61 insertions(+), 69 deletions(-) delete mode 100644 tests/v1/kv_connector/test_model_runner.py create mode 100644 tests/v1/kv_connector/test_nixl_connector.py rename tests/v1/kv_connector/{test_remote_prefill_lifecycle.py => test_remote_prefill_scheduler.py} (100%) diff --git a/tests/v1/kv_connector/test_model_runner.py b/tests/v1/kv_connector/test_model_runner.py deleted file mode 100644 index a6124ceeeb81..000000000000 --- a/tests/v1/kv_connector/test_model_runner.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import torch - -from .utils import (create_model_runner, create_request, create_scheduler, - create_vllm_config) - - -def test_basic_remote_prefill(): - vllm_config = create_vllm_config() - scheduler = create_scheduler(vllm_config) - model_runner = create_model_runner(vllm_config=vllm_config, - device=torch.device(type="cpu")) - - NUM_TOKENS = 16 - - normal_request = create_request(request_id=0, num_tokens=NUM_TOKENS) - - remote_request = create_request( - request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True, - ) - - scheduler.add_request(normal_request) - scheduler.add_request(remote_request) - - scheduler_output = scheduler.schedule() - - # Both should be running, but only the normal request - # should have scheduled tokens. - assert len(scheduler.running) == 2 - assert scheduler_output.num_scheduled_tokens[ - normal_request.request_id] == NUM_TOKENS - assert scheduler_output.num_scheduled_tokens[ - remote_request.request_id] == 0 - - for scheduled_new_req in scheduler_output.scheduled_new_reqs: - # Remote request has all tokens computed externally. - if scheduled_new_req.req_id == remote_request.request_id: - assert scheduled_new_req.num_computed_tokens == NUM_TOKENS - 1 - # Normal request has no tokens computed externally. 
- if scheduled_new_req.req_id == normal_request.request_id: - assert scheduled_new_req.num_computed_tokens == 0 - - # model_runner.execute_model does: - # * _update_states - # * returns if no tokens scheduled - # * _prepare_inputs - model_runner._update_states(scheduler_output) - attn_metadata, logits_indices, spec_decode_metadata = ( - model_runner._prepare_inputs(scheduler_output)) - - print(f"{attn_metadata=}") - print(f"{logits_indices=}") - print(f"{spec_decode_metadata=}") diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/test_nixl_connector.py new file mode 100644 index 000000000000..ca16d66753e2 --- /dev/null +++ b/tests/v1/kv_connector/test_nixl_connector.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +from typing import Optional + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.request import RequestStatus, Request + +from .utils import create_request, create_scheduler, create_vllm_config + +def test_scheduler_worker_inferface(): + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers NixlConnectorMetdata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + print(f"{kv_connector_metadata.requests=}") + diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_scheduler.py similarity index 100% rename from tests/v1/kv_connector/test_remote_prefill_lifecycle.py rename to tests/v1/kv_connector/test_remote_prefill_scheduler.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 3fd8c3344e2e..95d3dfb7c841 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -201,6 +201,7 @@ def get_num_new_matched_tokens( @abstractmethod def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. 
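For third-party connectors, the widened scheduler-side hook means update_state_after_alloc now also receives the block ids that were just allocated for the request. Below is a minimal sketch of how a connector might use this, mirroring the NixlConnectorScheduler change in the next file diff; ToyConnectorScheduler, ToyReqMeta, and ToyConnectorMetadata are illustrative stand-ins rather than vLLM classes, and `request` is assumed to expose a request_id attribute.

# Illustrative only: remember (request, block_ids) for requests that still
# need an external KV load, then surface them in the per-step metadata.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyReqMeta:
    local_block_ids: list[int]


@dataclass
class ToyConnectorMetadata:
    requests: dict[str, ToyReqMeta] = field(default_factory=dict)


class ToyConnectorScheduler:

    def __init__(self):
        # req_id -> (request, block ids allocated for it this step)
        self._reqs_need_recv: dict[str, tuple[Any, list[int]]] = {}

    def update_state_after_alloc(self, request: Any, block_ids: list[int],
                                 num_external_tokens: int) -> None:
        # Only track requests that actually have KVs to pull in.
        if num_external_tokens > 0:
            self._reqs_need_recv[request.request_id] = (request, block_ids)

    def build_connector_meta(self) -> ToyConnectorMetadata:
        # Emit one entry per pending request, then reset for the next step.
        meta = ToyConnectorMetadata()
        for req_id, (_request, block_ids) in self._reqs_need_recv.items():
            meta.requests[req_id] = ToyReqMeta(local_block_ids=block_ids)
        self._reqs_need_recv.clear()
        return meta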
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0cef64cecd5a..28da7b1ef031 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -97,10 +97,11 @@ def get_num_new_matched_tokens(self, request: "Request", request, num_computed_tokens) def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, num_external_tokens) + request, block_ids, num_external_tokens) def build_connector_meta( self, @@ -175,11 +176,13 @@ def get_num_new_matched_tokens(self, request: "Request", return 0 def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): if request.do_remote_decode: pass if request.do_remote_prefill and num_external_tokens > 0: - self._reqs_need_recv[request.request_id] = request + self._reqs_need_recv[request.request_id] = ( + request, block_ids) def build_connector_meta( self, @@ -188,18 +191,12 @@ def build_connector_meta( meta = NixlConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. - for new_req in scheduler_output.scheduled_new_reqs: - req = self._reqs_need_recv.pop(new_req.req_id, None) - if req is not None: - meta.add_new_req( - request_id=new_req.req_id, - local_block_ids=new_req.block_ids, - kv_transfer_params=req.kv_transfer_params, - ) - - # Invariant: only new requests should need load - # and we should get all new requests each step(). - assert len(self._reqs_need_recv) == 0 + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) return meta diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 62c0b95c111a..3a014d61f5a6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -363,6 +363,16 @@ def schedule(self) -> SchedulerOutput: self.waiting.popleft() skipped_waiting_requests.appendleft(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [b.block_id for b in computed_blocks + new_blocks], + num_external_tokens, + ) continue # Number of tokens to be scheduled. 
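Note the ordering of the id list passed to the connector in the hunk above: prefix-cache hits (computed_blocks) come first, followed by the blocks allocated in this scheduling step, so the connector sees the request's blocks in prompt order. A small self-contained illustration of that list comprehension, using a simplified Block stand-in for vLLM's KVCacheBlock:

# Simplified illustration of [b.block_id for b in computed_blocks + new_blocks]
from dataclasses import dataclass


@dataclass
class Block:  # stand-in for vLLM's KVCacheBlock
    block_id: int


computed_blocks = [Block(0), Block(1)]        # prefix-cache hits
new_blocks = [Block(7), Block(8), Block(9)]   # allocated this step

block_ids = [b.block_id for b in computed_blocks + new_blocks]
assert block_ids == [0, 1, 7, 8, 9]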
@@ -404,6 +414,7 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, + [b.block_id for b in computed_blocks + new_blocks], num_external_tokens, ) From cee3c614ddd645d217802d7277d823bf1ab92c58 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 09:01:06 -0500 Subject: [PATCH 064/119] updated Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_nixl_connector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/test_nixl_connector.py index ca16d66753e2..b746978907ea 100644 --- a/tests/v1/kv_connector/test_nixl_connector.py +++ b/tests/v1/kv_connector/test_nixl_connector.py @@ -34,5 +34,9 @@ def test_scheduler_worker_inferface(): assert len(kv_connector_metadata.requests) == 1 assert request_id in kv_connector_metadata.requests - print(f"{kv_connector_metadata.requests=}") - + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.req_to_blocks[request_id]): + assert block_id == block.block_id From 597257129a8563b76c8f70d768c4293710f65513 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 09:06:37 -0500 Subject: [PATCH 065/119] updated on scheduler side Signed-off-by: Robert Shaw --- vllm/v1/core/sched/scheduler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3a014d61f5a6..06b9bbf18b52 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -32,7 +32,6 @@ logger = init_logger(__name__) - class Scheduler(SchedulerInterface): def __init__( @@ -768,12 +767,17 @@ def update_from_output( request.status = RequestStatus.FINISHED_REMOTE_DECODE self._free_request(request, skip_free_blocks=True) # TODO(rob): do this on a per-Connector basis. + remote_blocks = [ + block.block_id for block in + self.kv_cache_manager.req_to_blocks[request.request_id] + ] + kv_transfer_params = KVTransferParams( do_remote_prefill=True, # put the remote block ids here - remote_block_ids=[1, 2, 3], + remote_block_ids=remote_blocks, # put the enigne id here - remote_engine_id="abcdefg", + remote_engine_id=self.connector.engine_id, ) # Add EngineCoreOutput for this Request. 
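With this change, the transfer params attached to a finishing remote-decode request describe real state: the ids of the blocks that hold its computed KV cache and the engine that owns them, which a decode-side connector can then use as remote_block_ids / remote_engine_id when pulling the cache. A self-contained sketch of that payload follows, using simplified stand-ins (TransferParams and Block are illustrative, not the vLLM KVTransferParams and KVCacheBlock classes).

# Sketch: build the transfer params for a request whose KVs were computed
# locally and will be consumed by a remote decode instance.
from dataclasses import dataclass


@dataclass
class Block:  # stand-in for vLLM's KVCacheBlock
    block_id: int


@dataclass
class TransferParams:  # stand-in for vLLM's KVTransferParams
    do_remote_prefill: bool
    remote_block_ids: list[int]
    remote_engine_id: str


def build_transfer_params(req_to_blocks: dict[str, list[Block]],
                          request_id: str, engine_id: str) -> TransferParams:
    # The remote side pulls exactly these physical blocks from this engine.
    remote_blocks = [b.block_id for b in req_to_blocks[request_id]]
    return TransferParams(do_remote_prefill=True,
                          remote_block_ids=remote_blocks,
                          remote_engine_id=engine_id)


params = build_transfer_params({"req-0": [Block(4), Block(5), Block(6)]},
                               "req-0", engine_id="prefill-engine-0")
assert params.remote_block_ids == [4, 5, 6]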
From 1b69d33dab2bc4296cd88b20ecd78de4ea6bc035 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 09:14:37 -0500 Subject: [PATCH 066/119] updated Signed-off-by: Robert Shaw --- .../test_remote_decode_scheduler.py | 259 ++++++++++++++++++ .../test_remote_prefill_scheduler.py | 35 +-- tests/v1/kv_connector/utils.py | 29 +- vllm/v1/core/sched/scheduler.py | 1 + 4 files changed, 294 insertions(+), 30 deletions(-) create mode 100644 tests/v1/kv_connector/test_remote_decode_scheduler.py diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py new file mode 100644 index 000000000000..f02462eae264 --- /dev/null +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +from typing import Optional + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.request import RequestStatus, Request + +from .utils import create_request, create_scheduler, create_vllm_config + +def test_basic_remote_prefill_cycle(): + """Test Remote Prefills Lifecycle.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed + # or scheduled tokens. + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks + < START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output) + assert len(engine_core_outputs.outputs) == 0 + + # STEP (2): + # (2a): schedule(): nothing happens! + scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy( + EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_KV_req_ids) + + # (3a): schedule(): this should actually schedule. 
+ scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. + num_hashed_blocks = 0 + for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + +def test_interleaved_remote_prefill_cycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True + ) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = make_model_runner_output( + [request_local_a]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 2: Add a local and remote request. + scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = make_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = make_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = make_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id] + ) + scheduler.update_from_output(scheduler_output, + model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. 
+ scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + +def test_remote_prefill_no_prefix_cache_uncomputed_blocks(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + request_local = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + do_remote_prefill=False, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output( + scheduler_output, + EMPTY_MODEL_RUNNER_OUTPUT + ) + assert len(scheduler.waiting) == 1 + + # Schedule the local prefill request. This should + # cause blocks to be cached, but separately from + scheduler.add_request(request_local) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + local_blocks = scheduler.kv_cache_manager.req_to_blocks[request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.req_to_blocks[request_remote.request_id] + + # Local should have cached blocks (but not all due to preallocate). + num_hashed_blocks = 0 + for block in local_blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += ( + 1 if block._block_hash is not None else 0) + assert num_hashed_blocks > 0 + + # Remote blocks should not be cached. 
+ for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_remote_prefill_no_blocks_available(): + """ + letTest whether we properly handle no blocks available + """ + pass \ No newline at end of file diff --git a/tests/v1/kv_connector/test_remote_prefill_scheduler.py b/tests/v1/kv_connector/test_remote_prefill_scheduler.py index a951871b9146..98a65904e7ea 100644 --- a/tests/v1/kv_connector/test_remote_prefill_scheduler.py +++ b/tests/v1/kv_connector/test_remote_prefill_scheduler.py @@ -2,10 +2,11 @@ import copy from typing import Optional -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.request import RequestStatus, Request -from .utils import create_request, create_scheduler, create_vllm_config +from .utils import (create_request, create_scheduler, + create_vllm_config, create_model_runner_output) def test_basic_remote_prefill_cycle(): """Test Remote Prefills Lifecycle.""" @@ -97,28 +98,6 @@ def test_basic_remote_prefill_cycle(): assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) -def make_model_runner_output( - reqs: list[Request], - finished_sending: Optional[list[str]] = None, - finished_recving: Optional[list[str]] = None, -) -> ModelRunnerOutput: - req_ids = [req.request_id for req in reqs] - req_id_to_index = { - req_id: idx for idx, req_id in enumerate(req_ids) - } - sampled_token_ids = [[0] for _ in req_ids] - - return ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - finished_sending=finished_sending, - finished_recving=finished_recving, - ) - def test_interleaved_remote_prefill_cycle(): """Test Remote Prefills Work Well With Other Requests.""" @@ -149,7 +128,7 @@ def test_interleaved_remote_prefill_cycle(): scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 - model_runner_output = make_model_runner_output( + model_runner_output = create_model_runner_output( [request_local_a]) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -163,7 +142,7 @@ def test_interleaved_remote_prefill_cycle(): assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_cached_reqs) == 1 - model_runner_output = make_model_runner_output( + model_runner_output = create_model_runner_output( [request_local_a, request_local_b]) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -175,7 +154,7 @@ def test_interleaved_remote_prefill_cycle(): assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_cached_reqs) == 2 - model_runner_output = make_model_runner_output( + model_runner_output = create_model_runner_output( reqs=[request_local_a, request_local_b]) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -191,7 +170,7 @@ def test_interleaved_remote_prefill_cycle(): assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_cached_reqs) == 2 - model_runner_output = make_model_runner_output( + model_runner_output = create_model_runner_output( [request_local_a, request_local_b], finished_recving=[request_remote.request_id] ) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 17a24f590392..4fdc5beef78a 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -1,4 +1,5 @@ # 
SPDX-License-Identifier: Apache-2.0 +from typing import Optional import torch from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, @@ -7,13 +8,13 @@ from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.worker.gpu_model_runner import GPUModelRunner EOS_TOKEN_ID = 50256 - def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, @@ -98,7 +99,9 @@ def create_request( ) -> list[Request]: if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = KVTransferParams(do_remote_prefill=True, ) + kv_transfer_params = KVTransferParams( + do_remote_prefill=True + ) elif do_remote_prefill: kv_transfer_params = KVTransferParams( do_remote_prefill=True, @@ -132,3 +135,25 @@ def create_request( eos_token_id=EOS_TOKEN_ID, arrival_time=0, ) + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, +) -> ModelRunnerOutput: + req_ids = [req.request_id for req in reqs] + req_id_to_index = { + req_id: idx for idx, req_id in enumerate(req_ids) + } + sampled_token_ids = [[0] for _ in req_ids] + + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 06b9bbf18b52..ac27a96a855a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -766,6 +766,7 @@ def update_from_output( if request.do_remote_decode and not stopped: request.status = RequestStatus.FINISHED_REMOTE_DECODE self._free_request(request, skip_free_blocks=True) + # TODO(rob): do this on a per-Connector basis. 
remote_blocks = [ block.block_id for block in From 8adf1ad71347c19682944bb4daa1ae583953c1d6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 13:14:14 -0500 Subject: [PATCH 067/119] updated Signed-off-by: Robert Shaw --- .../test_remote_decode_scheduler.py | 252 ++---------------- tests/v1/kv_connector/utils.py | 14 +- 2 files changed, 26 insertions(+), 240 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py index f02462eae264..8a891fe16ecf 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import Optional -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput -from vllm.v1.request import RequestStatus, Request +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import RequestStatus, FinishReason -from .utils import create_request, create_scheduler, create_vllm_config +from .utils import (create_request, create_scheduler, + create_vllm_config, create_model_runner_output) -def test_basic_remote_prefill_cycle(): - """Test Remote Prefills Lifecycle.""" +def test_basic_remote_decode_cycle(): + """Test Remote Decode Lifecycle.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -22,7 +22,7 @@ def test_basic_remote_prefill_cycle(): request = create_request(request_id=1, num_tokens=NUM_TOKENS, - do_remote_prefill=True) + do_remote_decode=True) scheduler.add_request(request) request_id = request.request_id @@ -30,230 +30,26 @@ def test_basic_remote_prefill_cycle(): # STEP (1): # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 - # Nothing running and empty scheduler output. - assert len(scheduler.running) == 0 - assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 - assert len(scheduler_output.num_scheduled_tokens) == 0 - assert scheduler_output.total_num_scheduled_tokens == 0 - - # Req waiting for KVs with no computed - # or scheduled tokens. - assert len(scheduler.waiting) == 1 - assert request in scheduler.waiting - assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) - assert (request.num_computed_tokens == 0) - - # ... but should have (uncached) blocks allocated to it. - block_pool = scheduler.kv_cache_manager.block_pool - assert (block_pool.free_block_queue.num_free_blocks - < START_FREE_BLOCK_QUEUE_SIZE) - assert len(block_pool.cached_block_hash_to_block) == 0 - for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: - assert block._block_hash is None - - # (1b): forward() - model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() engine_core_outputs = scheduler.update_from_output( scheduler_output, model_runner_output) - assert len(engine_core_outputs.outputs) == 0 - - # STEP (2): - # (2a): schedule(): nothing happens! - scheduler_output = scheduler.schedule() - assert len(scheduler.waiting) == 1 - assert len(scheduler.running) == 0 - - # (2b): forward(): request finishes recv. 
- model_runner_output = copy.deepcopy( - EMPTY_MODEL_RUNNER_OUTPUT) - model_runner_output.finished_recving = [request_id] - - # (2c): update_from_output(): - engine_core_outputs = scheduler.update_from_output( - scheduler_output, model_runner_output) - assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_KV_req_ids) - - # (3a): schedule(): this should actually schedule. - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 1 - - # Confirm the block are actually allocated. - num_hashed_blocks = 0 - for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: - assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) - assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS - - # Confirm the rest of the prompt is scheduled in this step. - scheduled_req = scheduler_output.scheduled_new_reqs[0] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] - num_computed_tokens = scheduled_req.num_computed_tokens - total_prompt_tokens = len(scheduled_req.prompt_token_ids) - assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) - - -def test_interleaved_remote_prefill_cycle(): - """Test Remote Prefills Work Well With Other Requests.""" - - vllm_config = create_vllm_config() - scheduler = create_scheduler(vllm_config) - - # 2 Full Blocks and 1 Half Block. - BLOCK_SIZE = vllm_config.cache_config.block_size - NUM_EXTERNAL_FULL_BLOCKS = 2 - NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request( - request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True - ) - request_local_a = create_request( - request_id=2, - num_tokens=NUM_TOKENS, - ) - request_local_b = create_request( - request_id=3, - num_tokens=NUM_TOKENS, - ) - - # STEP 1: Regular request is running. - scheduler.add_request(request_local_a) - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 1 - - model_runner_output = make_model_runner_output( - [request_local_a]) - scheduler.update_from_output(scheduler_output, - model_runner_output) - - # STEP 2: Add a local and remote request. - scheduler.add_request(request_local_b) - scheduler.add_request(request_remote) - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 1 - - model_runner_output = make_model_runner_output( - [request_local_a, request_local_b]) - scheduler.update_from_output(scheduler_output, - model_runner_output) - - # STEP 3: continue running, KVs not arrived yet. - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 - - model_runner_output = make_model_runner_output( - reqs=[request_local_a, request_local_b]) - scheduler.update_from_output(scheduler_output, - model_runner_output) - assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 - - # STEP 4: KVs arrive. 
- scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 - assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 - - model_runner_output = make_model_runner_output( - [request_local_a, request_local_b], - finished_recving=[request_remote.request_id] - ) - scheduler.update_from_output(scheduler_output, - model_runner_output) - - # STEP 5: RECVed KVs are sent to ModelRunner. - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 3 - assert len(scheduler.waiting) == 0 - assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 2 - - -def test_remote_prefill_no_prefix_cache_uncomputed_blocks(): - """ - With P/D, blocks can be allocated but uncomputed for - multiple engine steps. This test confirms that we do - not accidentally have cache hits against uncomputed - blocks. - """ - - vllm_config = create_vllm_config() - scheduler = create_scheduler(vllm_config) - - vllm_config = create_vllm_config() - scheduler = create_scheduler(vllm_config) - - # 2 and a half full external blocks. - BLOCK_SIZE = vllm_config.cache_config.block_size - NUM_EXTERNAL_FULL_BLOCKS = 2 - NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - - # Both of these requests have prompts like [1,1,1,1,1, ...] - request_remote = create_request( - request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True, - use_all_1s_for_prompt_tokens=True, - ) - - request_local = create_request( - request_id=2, - num_tokens=NUM_TOKENS, - do_remote_prefill=False, - use_all_1s_for_prompt_tokens=True, - ) - - # Schedule the remote prefill request. This should not - # cause any blocks to be cached. - scheduler.add_request(request_remote) - scheduler_output = scheduler.schedule() - scheduler.update_from_output( - scheduler_output, - EMPTY_MODEL_RUNNER_OUTPUT - ) - assert len(scheduler.waiting) == 1 - - # Schedule the local prefill request. This should - # cause blocks to be cached, but separately from - scheduler.add_request(request_local) - scheduler_output = scheduler.schedule() - assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 - - local_blocks = scheduler.kv_cache_manager.req_to_blocks[request_local.request_id] - remote_blocks = scheduler.kv_cache_manager.req_to_blocks[request_remote.request_id] - - # Local should have cached blocks (but not all due to preallocate). - num_hashed_blocks = 0 - for block in local_blocks: - assert block.ref_cnt == 1 - num_hashed_blocks += ( - 1 if block._block_hash is not None else 0) - assert num_hashed_blocks > 0 - - # Remote blocks should not be cached. - for block in remote_blocks: - assert block.ref_cnt == 1 - assert block._block_hash is None - - -def test_remote_prefill_no_blocks_available(): - """ - letTest whether we properly handle no blocks available - """ - pass \ No newline at end of file + # Ensure the request is finished after 1 tokens. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_REMOTE_DECODE + blocks = scheduler.kv_cache_manager.req_to_blocks[request_id] + output = engine_core_outputs.outputs[0] + assert output.finish_reason == FinishReason.REMOTE_DECODE + + # Ensure the return gives the proper transfer params. 
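    # (These are exactly the values a P/D proxy forwards to the decode
    #  instance as remote_block_ids / remote_engine_id, so its NIXL
    #  connector knows which remote blocks to pull.)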
+ remote_block_ids = output.kv_transfer_params.remote_block_ids + for remote_block_id, block in zip(remote_block_ids, blocks): + assert remote_block_id == block.block_id + assert (output.kv_transfer_params.remote_engine_id == + scheduler.connector.engine_id) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 4fdc5beef78a..5c3d0fdeada1 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -56,16 +56,6 @@ def create_vllm_config( ) -def create_model_runner( - vllm_config: VllmConfig, - device: torch.device, -) -> GPUModelRunner: - return GPUModelRunner( - vllm_config=vllm_config, - device=device, - ) - - def create_scheduler( vllm_config: VllmConfig, num_blocks: int = 10000, @@ -96,11 +86,11 @@ def create_request( do_remote_decode: bool = False, do_remote_prefill: bool = False, use_all_1s_for_prompt_tokens: bool = False, -) -> list[Request]: +) -> Request: if do_remote_decode: assert not do_remote_prefill kv_transfer_params = KVTransferParams( - do_remote_prefill=True + do_remote_decode=True ) elif do_remote_prefill: kv_transfer_params = KVTransferParams( From 21ab3d9df3fd6923bd29c308fbbb715b1f25d89d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 13:21:55 -0500 Subject: [PATCH 068/119] updated Signed-off-by: Robert Shaw --- vllm/entrypoints/openai/protocol.py | 10 ++++++++++ vllm/sampling_params.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c56a15af1367..5a14f6023f01 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -816,6 +816,14 @@ class CompletionRequest(OpenAIBaseModel): default=False, description="KVTransfer parameters used for disaggregated serving.") + remote_engine_id: Optional[str] = Field( + default=None, + description="Remote engine id.") + + remote_block_ids: Optional[list[int]] = Field( + default=None, + description="Remote block ids.") + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -916,6 +924,8 @@ def to_sampling_params( kv_transfer_params = KVTransferParams.from_optional( do_remote_decode=self.do_remote_decode, do_remote_prefill=self.do_remote_prefill, + remote_engine_id=self.remote_engine_id, + remote_block_ids=self.remote_block_ids, ) return SamplingParams.from_optional( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4ecebb808c29..27a3772bac5f 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -39,8 +39,12 @@ class KVTransferParams( do_remote_prefill: bool = False @staticmethod - def from_optional(do_remote_decode: bool, - do_remote_prefill: bool) -> Optional["KVTransferParams"]: + def from_optional( + do_remote_decode: bool, + do_remote_prefill: bool, + remote_engine_id: Optional[str], + remote_block_ids: Optional[list[int]], + ) -> Optional["KVTransferParams"]: if do_remote_decode and do_remote_prefill: raise ValueError( "Cannot do both remote prefill and remote decode.") @@ -48,6 +52,8 @@ def from_optional(do_remote_decode: bool, return KVTransferParams( do_remote_decode=do_remote_decode, do_remote_prefill=do_remote_prefill, + remote_engine_id=remote_engine_id, + remote_block_ids=remote_block_ids, ) else: return None From 3a27bbc7fa14fc33d44f47ccb060568d0c268da0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 13:50:39 -0500 Subject: [PATCH 069/119] updated Signed-off-by: Robert Shaw --- vllm/entrypoints/openai/serving_completion.py 
| 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..57e75ff956a4 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -475,13 +475,16 @@ def request_output_to_completion_response( ) request_metadata.final_usage_info = usage - + + assert len(final_res_batch) == 1 return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + remote_engine_id=final_res_batch[0].kv_transfer_params.remote_engine_id, + remote_block_ids=final_res_batch[0].kv_transfer_params.remote_block_ids, ) def _create_completion_logprobs( From f252df98364e22be563c87fc7baf8c091a4c5787 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 13:53:25 -0500 Subject: [PATCH 070/119] updated Signed-off-by: Robert Shaw --- examples/disagg_proxy_server.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py index 2639409a1522..fdb153f1d813 100644 --- a/examples/disagg_proxy_server.py +++ b/examples/disagg_proxy_server.py @@ -98,12 +98,15 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, async def stream_service_response(client: httpx.AsyncClient, endpoint: str, - req_data: dict): + req_data: dict, remote_block_ids: list[int], + remote_engine_id: str): """ Asynchronously stream the response from a service using a persistent client. """ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} req_data['do_remote_prefill'] = True + req_data["remote_block_ids"] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() @@ -121,8 +124,13 @@ async def handle_completions(request: Request): req_data = await request.json() # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, "/completions", - req_data) + response = await send_request_to_service( + app.state.prefill_client, + "/completions", + req_data + ) + remote_block_ids = response.remote_block_ids + remote_engine_id = response.remote_engine_id et = time.time() stats_calculator.add(et - st) @@ -131,7 +139,8 @@ async def handle_completions(request: Request): async def generate_stream(): async for chunk in stream_service_response(app.state.decode_client, "/completions", - req_data): + req_data, + ): yield chunk return StreamingResponse(generate_stream(), From 81048030c7f1b3ba0674f6cfb3caf7a47f57ed23 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 13:58:45 -0500 Subject: [PATCH 071/119] updated Signed-off-by: Robert Shaw --- examples/disagg_proxy_server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py index fdb153f1d813..e696fabb24dd 100644 --- a/examples/disagg_proxy_server.py +++ b/examples/disagg_proxy_server.py @@ -137,10 +137,12 @@ async def handle_completions(request: Request): # Stream response from decode service async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/completions", - req_data, - ): + async for chunk in stream_service_response( + app.state.decode_client, + "/completions", + req_data, + remote_block_ids=remote_block_ids, + 
remote_engine_id=remote_engine_id): yield chunk return StreamingResponse(generate_stream(), From 10bbe219472ec2a5c119fc3c96eefdeabbfb2db7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 27 Apr 2025 19:00:20 +0000 Subject: [PATCH 072/119] Hacking away Signed-off-by: Tyler Michael Smith --- start_proxy.py | 29 ++++++++++- .../kv_connector/v1/nixl_connector.py | 13 ++--- vllm/entrypoints/openai/protocol.py | 17 +++++-- vllm/v1/core/sched/scheduler.py | 6 +++ vllm/v1/engine/output_processor.py | 10 ++-- vllm/v1/request.py | 3 ++ vllm/v1/worker/gpu_model_runner.py | 48 ++++++++++++++----- 7 files changed, 100 insertions(+), 26 deletions(-) diff --git a/start_proxy.py b/start_proxy.py index 4c7af86eea98..8e365e19bb22 100644 --- a/start_proxy.py +++ b/start_proxy.py @@ -116,6 +116,32 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str, async for chunk in response.aiter_bytes(): yield chunk +async def send_request_to_prefill_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['do_remote_decode'] = True + req_data["stream"] = False + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": "vllm-d-debug", + } + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + + # Extract and print the actual response content + try: + response_json = response.json() + print(f"Prefill Request Content: {req_data}") + print(f"Prefill Response Content: {response_json}") + except Exception as e: + print(f"Could not parse prefill response as JSON: {e}") + print(f"Raw prefill response text: {response.text}") + + return response + @app.post("/v1/completions") async def handle_completions(request: Request): @@ -128,8 +154,9 @@ async def handle_completions(request: Request): print(req_data) # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, "/completions", + response = await send_request_to_prefill_service(app.state.prefill_client, "/completions", req_data) + print(response) et = time.time() stats_calculator.add(et - st) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4d6b88086711..f0252b9ef222 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -71,8 +71,10 @@ def add_new_req( kv_transfer_params: KVTransferParams, ): assert request_id not in self.requests - assert kv_transfer_params.remote_block_ids is not None assert kv_transfer_params.remote_engine_id is not None + print(kv_transfer_params.remote_engine_id) + assert kv_transfer_params.remote_block_ids is not None + print("HERE") self.requests[request_id] = ReqMeta( local_block_ids=local_block_ids, @@ -450,7 +452,6 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): blocks_data.append( (base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) - print(len(self.kv_caches_base_addr[self.engine_id])) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) @@ -535,10 +536,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.requests.items(): # NOTE: this is non-blocking - logger.debug("start_load_kv for request " + req_id) - 
print(meta.local_block_ids) - print(meta.remote_block_ids) - print(meta.remote_engine_id) + logger.debug("start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", + req_id, len(meta.local_block_ids), + len(meta.remote_block_ids), meta.remote_engine_id) self._read_blocks( local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c56a15af1367..139e900ba11b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -918,6 +918,11 @@ def to_sampling_params( do_remote_prefill=self.do_remote_prefill, ) + import os + print("Setting sampling params in protocol.py") + NIXL_ROLE = os.getenv("NIXL_ROLE") + kv_transfer_params.remote_engine_id = str(NIXL_ROLE) + " set_in_protocol.py" + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -1347,6 +1352,10 @@ class ChatCompletionResponse(OpenAIBaseModel): usage: UsageInfo prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + # Add these fields for KV transfer + remote_engine_id: Optional[str] = None + remote_block_ids: Optional[list[int]] = Field(default_factory=list) + class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None @@ -1606,9 +1615,9 @@ class TranscriptionRequest(OpenAIBaseModel): # doc: begin-transcription-extra-params stream: Optional[bool] = False - """Custom field not present in the original OpenAI definition. When set, + """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat - Completion endpoint. + Completion endpoint. """ # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False @@ -1626,7 +1635,7 @@ class TranscriptionRequest(OpenAIBaseModel): """ top_p: Optional[float] = None - """Enables nucleus (top-p) sampling, where tokens are selected from the + """Enables nucleus (top-p) sampling, where tokens are selected from the smallest possible set whose cumulative probability exceeds `p`. """ @@ -1634,7 +1643,7 @@ class TranscriptionRequest(OpenAIBaseModel): """Limits sampling to the `k` most probable tokens at each step.""" min_p: Optional[float] = None - """Filters out tokens with a probability lower than `min_p`, ensuring a + """Filters out tokens with a probability lower than `min_p`, ensuring a minimum likelihood threshold during sampling. 
""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 99a6e6dc1297..2631d030efed 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -686,6 +686,12 @@ def update_from_output( new_running.append(request) continue + + if not req_id in model_runner_output.req_id_to_index: + print(req_id) + print(model_runner_output.req_id_to_index) + continue + req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1de8e8994a86..b1b691f1eb23 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -187,6 +187,8 @@ def _new_request_output( else: prompt_logprobs = self.logprobs_processor.prompt_logprobs + print(f"Creating request output with kv_transfer_params {kv_transfer_params.remote_engine_id}") + return RequestOutput( request_id=request_id, prompt=self.prompt, @@ -305,22 +307,22 @@ def process_outputs( 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. ****************** NOTE FOR DEVELOPERS ****************** vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from within the loop below. - + ********************************************************** """ diff --git a/vllm/v1/request.py b/vllm/v1/request.py index bd0d57ee4c8e..a498caee5c66 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -68,6 +68,9 @@ def __init__( self.do_remote_prefill = ( False if sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_prefill) + + #TODO: need to get the remote_engine_id and + # remote block_ids self.kv_transfer_params = sampling_params.kv_transfer_params # Sanity check diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a2d47cf104cb..0ce2e0038c92 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1045,19 +1045,27 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: return output # Prepare the decoder inputs. - attn_metadata, logits_indices, spec_decode_metadata = ( - self._prepare_inputs(scheduler_output)) + num_reqs = self.input_batch.num_reqs num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) + if num_reqs > 0: + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. 
+ num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + # This may happen when there are outstanding KV transfers + num_input_tokens = 1 + attn_metadata = None + logits_indices = None + spec_decode_metadata = None # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1124,6 +1132,24 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: # For mid-pipeline stages, return the hidden states. return hidden_states + if logits_indices is None: + # HACK(tms): Early exit + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) From 65ea91fb7297222288348bdb13708d8c38ab9e6b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 14:56:24 -0500 Subject: [PATCH 073/119] cleanup Signed-off-by: Robert Shaw --- .../kv_connector/test_remote_decode_scheduler.py | 15 +++++++-------- tests/v1/kv_connector/utils.py | 12 ++++++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py index 8a891fe16ecf..af943d2b291a 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -43,13 +43,12 @@ def test_basic_remote_decode_cycle(): # Ensure the request is finished after 1 tokens. assert request.is_finished() assert request.status == RequestStatus.FINISHED_REMOTE_DECODE - blocks = scheduler.kv_cache_manager.req_to_blocks[request_id] output = engine_core_outputs.outputs[0] assert output.finish_reason == FinishReason.REMOTE_DECODE - - # Ensure the return gives the proper transfer params. - remote_block_ids = output.kv_transfer_params.remote_block_ids - for remote_block_id, block in zip(remote_block_ids, blocks): - assert remote_block_id == block.block_id - assert (output.kv_transfer_params.remote_engine_id == - scheduler.connector.engine_id) + assert output.kv_transfer_params is not None + + # Ensure blocks are not freed. 
+ blocks = scheduler.kv_cache_manager.req_to_blocks[request_id] + for block in blocks: + assert block.ref_cnt == 1 + diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 5c3d0fdeada1..3fbfc098e06c 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -11,7 +11,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.worker.gpu_model_runner import GPUModelRunner EOS_TOKEN_ID = 50256 @@ -21,6 +20,7 @@ def create_vllm_config( max_num_batched_tokens: int = 64, block_size: int = 16, ) -> VllmConfig: + """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, @@ -60,6 +60,7 @@ def create_scheduler( vllm_config: VllmConfig, num_blocks: int = 10000, ) -> Scheduler: + """Initialize Scheduler For Testing.""" block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests @@ -87,6 +88,8 @@ def create_request( do_remote_prefill: bool = False, use_all_1s_for_prompt_tokens: bool = False, ) -> Request: + """Make dummy request for testing.""" + if do_remote_decode: assert not do_remote_prefill kv_transfer_params = KVTransferParams( @@ -95,8 +98,8 @@ def create_request( elif do_remote_prefill: kv_transfer_params = KVTransferParams( do_remote_prefill=True, - remote_engine_id="abc", - remote_block_ids=[1, 2, 3], + remote_engine_id="remote_engine_id", + remote_block_ids=[1,2,3], ) else: kv_transfer_params = None @@ -106,7 +109,6 @@ def create_request( kv_transfer_params=kv_transfer_params, ) - if use_all_1s_for_prompt_tokens: prompt_token_ids = [1] * num_tokens else: @@ -131,6 +133,8 @@ def create_model_runner_output( finished_sending: Optional[list[str]] = None, finished_recving: Optional[list[str]] = None, ) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + req_ids = [req.request_id for req in reqs] req_id_to_index = { req_id: idx for idx, req_id in enumerate(req_ids) From f2550ef2d8d059b9a6825b5afe78cdc54842f143 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 15:06:39 -0500 Subject: [PATCH 074/119] ensure request removed from running list Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_remote_decode_scheduler.py | 11 +++++++++-- vllm/v1/core/sched/scheduler.py | 3 +-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py index af943d2b291a..412514c010ec 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -47,8 +47,15 @@ def test_basic_remote_decode_cycle(): assert output.finish_reason == FinishReason.REMOTE_DECODE assert output.kv_transfer_params is not None - # Ensure blocks are not freed. + # Request freed in Scheduler and in Persistent Batch. + # This causes the request to be freed in the scheduler. + + + # This causes the request to be freed in the PB on next step(). + assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + + # ... but blocks should not be freed. 
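    # (The blocks must stay pinned so the remote decode instance can still
    #  read them over NIXL; they are only freed once the worker reports this
    #  request id via finished_sending.)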
blocks = scheduler.kv_cache_manager.req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 - diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ac27a96a855a..ff4c0a594759 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -766,6 +766,7 @@ def update_from_output( if request.do_remote_decode and not stopped: request.status = RequestStatus.FINISHED_REMOTE_DECODE self._free_request(request, skip_free_blocks=True) + stopped = True # TODO(rob): do this on a per-Connector basis. remote_blocks = [ @@ -775,9 +776,7 @@ def update_from_output( kv_transfer_params = KVTransferParams( do_remote_prefill=True, - # put the remote block ids here remote_block_ids=remote_blocks, - # put the enigne id here remote_engine_id=self.connector.engine_id, ) From 985bac3493c4439564d8bf2dbcce9d7414829f96 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 27 Apr 2025 20:48:01 +0000 Subject: [PATCH 075/119] Runs E2E. Garbage output. Crashes on 2nd request Signed-off-by: Tyler Michael Smith --- Justfile | 4 ++-- examples/disagg_proxy_server.py | 20 ++++++++++++------- vllm/config.py | 18 +++++++++++------ .../kv_connector/v1/nixl_connector.py | 19 +++++++++++------- vllm/entrypoints/openai/protocol.py | 11 +--------- vllm/entrypoints/openai/serving_completion.py | 14 ++++++++++--- vllm/v1/core/sched/scheduler.py | 6 ++++-- vllm/v1/engine/output_processor.py | 2 -- vllm/v1/outputs.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 6 +++++- 10 files changed, 62 insertions(+), 42 deletions(-) diff --git a/Justfile b/Justfile index 8fd0f0b44a3a..bb734fba207e 100644 --- a/Justfile +++ b/Justfile @@ -31,14 +31,14 @@ decode: --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' proxy: - python start_proxy.py --port 8192 --prefiller-port 8100 --decoder-port 8200 + python examples/disagg_proxy_server.py --port 8192 send_request: curl -X POST http://localhost:8192/v1/completions \ -H "Content-Type: application/json" \ -d '{ \ "model": "meta-llama/Llama-3.2-1B-Instruct", \ - "prompt": "Generate a curl command to send to an openai server hosted at local_host:8192 with this as the prompt", \ + "prompt": "Generate a curl command to send to an openai server hosted at local_host:8192 with this as the", \ "max_tokens": 150, \ "temperature": 0.7 \ }' diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py index e696fabb24dd..0d36a87e1a20 100644 --- a/examples/disagg_proxy_server.py +++ b/examples/disagg_proxy_server.py @@ -123,14 +123,20 @@ async def handle_completions(request: Request): try: req_data = await request.json() - # Send request to prefill service, ignore the response + # Send request to prefill service response = await send_request_to_service( - app.state.prefill_client, - "/completions", - req_data - ) - remote_block_ids = response.remote_block_ids - remote_engine_id = response.remote_engine_id + app.state.prefill_client, "/completions", req_data) + + # Extract the needed fields + response_json = response.json() + remote_block_ids = response_json.get('remote_block_ids', []) + remote_engine_id = response_json.get('remote_engine_id', '') + print("Prefiller response:\n" + str(response_json)) + + # Add these to the request data for the decoder + req_data['remote_block_ids'] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id + print(f"{req_data}") et = time.time() stats_calculator.add(et - st) diff --git a/vllm/config.py b/vllm/config.py index 2bf8a18250ce..6947e252bb14 
100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -8,6 +8,7 @@ import json import sys import textwrap +import uuid import warnings from collections import Counter from contextlib import contextmanager @@ -120,7 +121,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: def pairwise(iterable): """ Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - + Can be removed when Python 3.9 support is dropped. """ iterator = iter(iterable) @@ -266,7 +267,7 @@ class ModelConfig: config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running + . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the @@ -1606,7 +1607,7 @@ class ParallelConfig: """The full name of the worker class to use. If "auto", the worker class will be determined based on the platform.""" sd_worker_cls: str = "auto" - """The full name of the worker class to use for speculative decofing. + """The full name of the worker class to use for speculative decofing. If "auto", the worker class will be determined based on the platform.""" worker_extension_cls: str = "" """The full name of the worker extension class to use. The worker extension @@ -1797,13 +1798,13 @@ class SchedulerConfig: max_num_batched_tokens: int = None # type: ignore """Maximum number of tokens to be processed in a single iteration. - + This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" max_num_seqs: int = None # type: ignore """Maximum number of sequences to be processed in a single iteration. - + This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" @@ -1849,7 +1850,7 @@ class SchedulerConfig: # TODO (ywang96): Make this configurable. max_num_encoder_input_tokens: int = field(init=False) """Multimodal encoder compute budget, only used in V1. - + NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" @@ -3195,6 +3196,11 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + # Engine ID for the KV transfers. + # Note(tms): sticking this here so the engine_id is consistent between + # scheduler-side and worker-side of the KVConnector + engine_id : str = str(uuid.uuid4()) + # The device used by kv connector to buffer the KV cache. # Currently only support 'cuda'. 
kv_buffer_device: Optional[str] = "cuda" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f0252b9ef222..dd2a1a89d717 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -72,9 +72,7 @@ def add_new_req( ): assert request_id not in self.requests assert kv_transfer_params.remote_engine_id is not None - print(kv_transfer_params.remote_engine_id) assert kv_transfer_params.remote_block_ids is not None - print("HERE") self.requests[request_id] = ReqMeta( local_block_ids=local_block_ids, @@ -85,7 +83,7 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - self.engine_id = uuid.uuid4() + self.engine_id = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: self.connector_scheduler : Optional[NixlConnectorScheduler] = \ @@ -158,6 +156,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id + logger.info("Initializing NIXL Scheduler " + engine_id) # Requests that need to start recv. # New requests are added by update_state_after_alloc in @@ -207,6 +206,10 @@ def build_connector_meta( local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + return meta @@ -218,6 +221,7 @@ def __init__(self, engine_id: str): logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker " + engine_id) # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) @@ -424,6 +428,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): return num_blocks = nixl_agent_meta.num_blocks + logger.debug("Adding remote agent " + engine_id + " " + str(num_blocks)) agent_names = [] agent_name = self.nixl_wrapper.add_remote_agent( @@ -538,8 +543,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): # NOTE: this is non-blocking logger.debug("start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", - req_id, len(meta.local_block_ids), - len(meta.remote_block_ids), meta.remote_engine_id) + req_id, meta.remote_engine_id, + len(meta.local_block_ids), + len(meta.remote_block_ids)) self._read_blocks( local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, @@ -597,13 +603,12 @@ def _read_blocks( local_block_descs_ids, remote_xfer_side_handle, remote_block_descs_ids, - notif_msg=request_id, + notif_msg=request_id.encode("utf-8"), ) # Call transfer to begin the async transfer # We will check this is done in the next forward pass. 
self.nixl_wrapper.transfer(handle) - self._recving_transfers[request_id].append(handle) def _get_block_descs_ids(self, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 728c6cd91478..04a85499e88c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -819,7 +819,7 @@ class CompletionRequest(OpenAIBaseModel): remote_engine_id: Optional[str] = Field( default=None, description="Remote engine id.") - + remote_block_ids: Optional[list[int]] = Field( default=None, description="Remote block ids.") @@ -928,11 +928,6 @@ def to_sampling_params( remote_block_ids=self.remote_block_ids, ) - import os - print("Setting sampling params in protocol.py") - NIXL_ROLE = os.getenv("NIXL_ROLE") - kv_transfer_params.remote_engine_id = str(NIXL_ROLE) + " set_in_protocol.py" - return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -1362,10 +1357,6 @@ class ChatCompletionResponse(OpenAIBaseModel): usage: UsageInfo prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None - # Add these fields for KV transfer - remote_engine_id: Optional[str] = None - remote_block_ids: Optional[list[int]] = Field(default_factory=list) - class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 57e75ff956a4..f011581b8909 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ import jinja2 from fastapi import Request +from torch._C import NoneType from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -475,7 +476,14 @@ def request_output_to_completion_response( ) request_metadata.final_usage_info = usage - + + if final_res_batch[0].kv_transfer_params is not None: + remote_engine_id=final_res_batch[0].kv_transfer_params.remote_engine_id + remote_block_ids=final_res_batch[0].kv_transfer_params.remote_block_ids + else: + remote_engine_id=None + remote_block_ids=None + assert len(final_res_batch) == 1 return CompletionResponse( id=request_id, @@ -483,8 +491,8 @@ def request_output_to_completion_response( model=model_name, choices=choices, usage=usage, - remote_engine_id=final_res_batch[0].kv_transfer_params.remote_engine_id, - remote_block_ids=final_res_batch[0].kv_transfer_params.remote_block_ids, + remote_engine_id=remote_engine_id, + remote_block_ids=remote_block_ids, ) def _create_completion_logprobs( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2631d030efed..8c1b1d6c7d50 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -781,6 +781,8 @@ def update_from_output( block.block_id for block in self.kv_cache_manager.req_to_blocks[request.request_id] ] + # HACK(tms) - we're off by one between prefill an decode + remote_blocks.pop() kv_transfer_params = KVTransferParams( do_remote_prefill=True, @@ -813,8 +815,10 @@ def update_from_output( # P/D: update recv and send status from last step. 
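
# How the worker surfaces completed NIXL transfers to this code path: a
# sketch mirroring the scheduler tests in tests/v1/kv_connector (the helper
# below is illustrative; EMPTY_MODEL_RUNNER_OUTPUT and update_from_output
# are the real interfaces used here).
import copy

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT


def report_finished_transfers(scheduler, scheduler_output,
                              finished_sending, finished_recving):
    # Emulate a step where no new tokens were sampled but some NIXL
    # transfers completed since the last forward pass.
    worker_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    worker_output.finished_sending = list(finished_sending)
    worker_output.finished_recving = list(finished_recving)
    # finished_sending -> the blocks of those requests can finally be freed;
    # finished_recving -> those requests become schedulable on the next step.
    return scheduler.update_from_output(scheduler_output, worker_output)
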
for req_id in (model_runner_output.finished_recving or []): + logger.debug("FINISHED RECVING: " + req_id) self.finished_recving_KV_req_ids.add(req_id) for req_id in (model_runner_output.finished_sending or []): + logger.debug("FINISHED SENDING: " + req_id) self._free_blocks(self.requests[req_id]) self.running = new_running @@ -832,7 +836,6 @@ def update_from_output( def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request - print(f"Adding {request.request_id} to requests") if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) @@ -864,7 +867,6 @@ def finish_requests( else: self.waiting.remove(request) request.status = finished_status - print(f"freeing request {req_id}") self._free_request(request) def _free_request(self, request: Request, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index b1b691f1eb23..df8a2f270d82 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -187,8 +187,6 @@ def _new_request_output( else: prompt_logprobs = self.logprobs_processor.prompt_logprobs - print(f"Creating request output with kv_transfer_params {kv_transfer_params.remote_engine_id}") - return RequestOutput( request_id=request_id, prompt=self.prompt, diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 24052e01f006..baed401ac8b5 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -101,8 +101,8 @@ class ModelRunnerOutput: prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] # [req_ids] - finished_sending: Optional[list[str]] = None - finished_recving: Optional[list[str]] = None + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ce2e0038c92..351fea8c301d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1039,7 +1039,7 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: finished_sending, finished_recving = maybe_get_finished() # Return empty ModelRunnerOutput if there's no work to do. 
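        # (Even when nothing is scheduled, finished_sending / finished_recving
        #  must still be reported so the scheduler can free the sender's blocks
        #  and wake requests that were waiting for remote KVs.)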
output = EMPTY_MODEL_RUNNER_OUTPUT - if len(finished_sending) > 0 or len(finished_sending) > 0: + if len(finished_sending) > 0 or len(finished_recving) > 0: output.finished_sending = finished_sending output.finished_recving = finished_recving return output @@ -1062,6 +1062,7 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: attn_metadata.num_input_tokens = num_input_tokens else: # This may happen when there are outstanding KV transfers + print("tyler hack area " + str(scheduler_output.total_num_scheduled_tokens)) num_input_tokens = 1 attn_metadata = None logits_indices = None @@ -1139,6 +1140,9 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() + if len(finished_recving) > 0: + logger.debug(finished_recving) + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, From bf37a7d500c094f0426a917328cbb7737cc382f6 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 27 Apr 2025 20:48:33 +0000 Subject: [PATCH 076/119] update Signed-off-by: Tyler Michael Smith --- examples/disagg_proxy_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py index 0d36a87e1a20..3684336ce4bb 100644 --- a/examples/disagg_proxy_server.py +++ b/examples/disagg_proxy_server.py @@ -136,7 +136,6 @@ async def handle_completions(request: Request): # Add these to the request data for the decoder req_data['remote_block_ids'] = remote_block_ids req_data['remote_engine_id'] = remote_engine_id - print(f"{req_data}") et = time.time() stats_calculator.add(et - st) From ebe12632f956369bfeac8f7000f36d32b19b4de7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 15:56:01 -0500 Subject: [PATCH 077/119] updated Signed-off-by: Robert Shaw --- .../test_remote_decode_scheduler.py | 18 ++++++++--- tests/v1/kv_connector/utils.py | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py index 412514c010ec..2b89b5a7405f 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -47,15 +47,23 @@ def test_basic_remote_decode_cycle(): assert output.finish_reason == FinishReason.REMOTE_DECODE assert output.kv_transfer_params is not None - # Request freed in Scheduler and in Persistent Batch. - # This causes the request to be freed in the scheduler. - - - # This causes the request to be freed in the PB on next step(). + # Request freed in Scheduler and in Persistent Batch ... assert request_id in scheduler.finished_req_ids assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. 
blocks = scheduler.kv_cache_manager.req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 + + # STEP (2): + # (2a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + + # model_runner_output = copy.deepcopy( + # EMPTY_MODEL_RUNNER_OUTPUT) + # model_runner_output.finished_sending diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 3fbfc098e06c..5e1687ebb530 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -14,6 +14,38 @@ EOS_TOKEN_ID = 50256 +def assert_scheduler_empty(scheduler: Scheduler): + """Assert Scheduler Is Empty.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.scheduled_req_ids) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_KV_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + assert ( + len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block) == 0) + + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.block_hash is None + assert block.ref_cnt == 0 + + + def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, From a008aa3494a9f6a8fed4fccc43f2481bd34aebe8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:14:38 -0500 Subject: [PATCH 078/119] updated Signed-off-by: Robert Shaw --- .../test_remote_decode_scheduler.py | 39 +++++++++++++++---- tests/v1/kv_connector/utils.py | 9 +++-- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_scheduler.py index 2b89b5a7405f..2ad1af853211 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_scheduler.py @@ -5,7 +5,8 @@ from vllm.v1.request import RequestStatus, FinishReason from .utils import (create_request, create_scheduler, - create_vllm_config, create_model_runner_output) + create_vllm_config, create_model_runner_output, + assert_scheduler_empty) def test_basic_remote_decode_cycle(): """Test Remote Decode Lifecycle.""" @@ -27,7 +28,7 @@ def test_basic_remote_decode_cycle(): scheduler.add_request(request) request_id = request.request_id - # STEP (1): + # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 @@ -57,13 +58,37 @@ def test_basic_remote_decode_cycle(): for block in blocks: assert block.ref_cnt == 1 - # STEP (2): - # (2a): schedule() + # STEP (2): Send Finished to PB. + # (2a): schedule() - pass finished request to PB. 
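    # ("PB" = the worker-side persistent batch: the finished request id is
    #  carried in scheduler_output.finished_req_ids so the GPU model runner
    #  drops it from its input batch, while the KV blocks stay allocated
    #  until the NIXL transfer completes.)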
scheduler_output = scheduler.schedule() assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 1 + assert request_id in scheduler_output.finished_req_ids assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 - # model_runner_output = copy.deepcopy( - # EMPTY_MODEL_RUNNER_OUTPUT) - # model_runner_output.finished_sending + # (2b): execute_model() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (2c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP (3): Finished sending. + # (3a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (3b): execute_model() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_id] + + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 5e1687ebb530..6a27d060c764 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -37,15 +37,16 @@ def assert_scheduler_empty(scheduler: Scheduler): scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) - assert ( - len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block) == 0) + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. 
for block in scheduler.kv_cache_manager.block_pool.blocks: - assert block.block_hash is None assert block.ref_cnt == 0 + # assert block._block_hash is None + # assert ( + # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block) == 0) - def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, From 195dceb49057f67a55b107d821fec1c19e1c193b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:16:36 -0500 Subject: [PATCH 079/119] rename files Signed-off-by: Robert Shaw --- ...te_decode_scheduler.py => test_remote_decode_lifecycle.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename tests/v1/kv_connector/{test_remote_decode_scheduler.py => test_remote_decode_lifecycle.py} (97%) diff --git a/tests/v1/kv_connector/test_remote_decode_scheduler.py b/tests/v1/kv_connector/test_remote_decode_lifecycle.py similarity index 97% rename from tests/v1/kv_connector/test_remote_decode_scheduler.py rename to tests/v1/kv_connector/test_remote_decode_lifecycle.py index 2ad1af853211..e3485cddedbe 100644 --- a/tests/v1/kv_connector/test_remote_decode_scheduler.py +++ b/tests/v1/kv_connector/test_remote_decode_lifecycle.py @@ -8,8 +8,8 @@ create_vllm_config, create_model_runner_output, assert_scheduler_empty) -def test_basic_remote_decode_cycle(): - """Test Remote Decode Lifecycle.""" +def test_remote_decode_cycle(): + """Test lifecycle of a Remote Decode request.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) From e2cc365ad0dac2339144dc95d3842bf7b9b434b9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:17:37 -0500 Subject: [PATCH 080/119] updated Signed-off-by: Robert Shaw --- ...mote_prefill_scheduler.py => test_remote_prefill_lifecycle.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/v1/kv_connector/{test_remote_prefill_scheduler.py => test_remote_prefill_lifecycle.py} (100%) diff --git a/tests/v1/kv_connector/test_remote_prefill_scheduler.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py similarity index 100% rename from tests/v1/kv_connector/test_remote_prefill_scheduler.py rename to tests/v1/kv_connector/test_remote_prefill_lifecycle.py From b4b64feff5b39e9de8517966be42c0daa2fcf2c6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:38:22 -0500 Subject: [PATCH 081/119] updated Signed-off-by: Robert Shaw --- .../test_remote_prefill_lifecycle.py | 48 ++++++++++++++++--- tests/v1/kv_connector/utils.py | 10 +++- vllm/v1/core/sched/scheduler.py | 1 + 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py index 98a65904e7ea..42a7a417e0c5 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import Optional from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import RequestStatus, Request +from vllm.v1.request import RequestStatus, FinishReason from .utils import (create_request, create_scheduler, - create_vllm_config, create_model_runner_output) + create_vllm_config, create_model_runner_output, + assert_scheduler_empty) -def test_basic_remote_prefill_cycle(): +def test_basic_lifecycle(): """Test Remote Prefills Lifecycle.""" vllm_config = create_vllm_config() @@ -79,6 +79,7 @@ def test_basic_remote_prefill_cycle(): assert len(scheduler.waiting) == 1 
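    # (The request stays waiting even though its KVs have arrived; the next
    #  schedule() call sees it in finished_recving_KV_req_ids and schedules
    #  the remainder of the prompt.)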
assert (request_id in scheduler.finished_recving_KV_req_ids) + # STEP (3): # (3a): schedule(): this should actually schedule. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 @@ -97,8 +98,27 @@ def test_basic_remote_prefill_cycle(): total_prompt_tokens = len(scheduled_req.prompt_token_ids) assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) -def test_interleaved_remote_prefill_cycle(): + # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + [request], use_eos=True) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_interleaved_lifecycle(): """Test Remote Prefills Work Well With Other Requests.""" vllm_config = create_vllm_config() @@ -184,8 +204,24 @@ def test_interleaved_remote_prefill_cycle(): assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_cached_reqs) == 2 + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote] + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 6: Hit EOS and free. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote], + use_eos=True, + ) + scheduler.update_from_output( + scheduler_output, model_runner_output) + scheduler.schedule() + assert_scheduler_empty(scheduler) + -def test_remote_prefill_no_prefix_cache_uncomputed_blocks(): +def test_no_spurious_prefix_caching(): """ With P/D, blocks can be allocated but uncomputed for multiple engine steps. This test confirms that we do diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 6a27d060c764..d0ae1cc2a734 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -165,15 +165,21 @@ def create_model_runner_output( reqs: list[Request], finished_sending: Optional[list[str]] = None, finished_recving: Optional[list[str]] = None, + use_eos: bool = False, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" + # Make request data. req_ids = [req.request_id for req in reqs] req_id_to_index = { req_id: idx for idx, req_id in enumerate(req_ids) } - sampled_token_ids = [[0] for _ in req_ids] - + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. 
return ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_id_to_index, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ff4c0a594759..c69692edc6bb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -305,6 +305,7 @@ def schedule(self) -> SchedulerOutput: num_tokens=0, num_computed_tokens=(len(request.all_token_ids) - 1) ) + self.finished_recving_KV_req_ids.remove(request.request_id) request.status = RequestStatus.WAITING self.kv_cache_manager.free(request) else: From 6686397b87ab9adc88f59c058f29fea184c394d5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:39:36 -0500 Subject: [PATCH 082/119] updated Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_remote_decode_lifecycle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/test_remote_decode_lifecycle.py index e3485cddedbe..32d64d6d5328 100644 --- a/tests/v1/kv_connector/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_decode_lifecycle.py @@ -8,7 +8,7 @@ create_vllm_config, create_model_runner_output, assert_scheduler_empty) -def test_remote_decode_cycle(): +def test_basic_lifecycle(): """Test lifecycle of a Remote Decode request.""" vllm_config = create_vllm_config() From 8736043a0c5a6b42b31fbc88bc3ff3dc730201c8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:39:48 -0500 Subject: [PATCH 083/119] updated Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_remote_decode_lifecycle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/test_remote_decode_lifecycle.py index 32d64d6d5328..c3c01d127754 100644 --- a/tests/v1/kv_connector/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_decode_lifecycle.py @@ -18,8 +18,6 @@ def test_basic_lifecycle(): BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - START_FREE_BLOCK_QUEUE_SIZE = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) request = create_request(request_id=1, num_tokens=NUM_TOKENS, From dcbf6e5648e60e3af02e925785aafbc49aa518c8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 16:41:31 -0500 Subject: [PATCH 084/119] updated Signed-off-by: Robert Shaw --- tests/v1/kv_connector/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index d0ae1cc2a734..825e6186b5ae 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -15,7 +15,7 @@ EOS_TOKEN_ID = 50256 def assert_scheduler_empty(scheduler: Scheduler): - """Assert Scheduler Is Empty.""" + """Confirm the scheduler is "empty" - i.e. no leaks.""" # Scheduler Metadata. 
assert len(scheduler.requests) == 0 assert len(scheduler.waiting) == 0 From 7c8e21af011e3a347cd0d8215ab8262c94f5ef64 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 27 Apr 2025 17:42:10 -0500 Subject: [PATCH 085/119] update Signed-off-by: Robert Shaw --- tests/v1/kv_connector/test_remote_prefill_lifecycle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py index 42a7a417e0c5..73315e6121ae 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -39,8 +39,7 @@ def test_basic_lifecycle(): assert len(scheduler_output.num_scheduled_tokens) == 0 assert scheduler_output.total_num_scheduled_tokens == 0 - # Req waiting for KVs with no computed - # or scheduled tokens. + # Req waiting for KVs with no computed/scheduled toks ... assert len(scheduler.waiting) == 1 assert request in scheduler.waiting assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) From a4855d2b27b71834ae0c7ea22c116941d75d8bdd Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 27 Apr 2025 22:44:42 +0000 Subject: [PATCH 086/119] Second request no longer crashes Signed-off-by: Tyler Michael Smith --- examples/disagg_proxy_server.py | 22 ++++++++++++++----- .../kv_connector/v1/nixl_connector.py | 2 +- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py index 3684336ce4bb..85365ef9b6c8 100644 --- a/examples/disagg_proxy_server.py +++ b/examples/disagg_proxy_server.py @@ -3,6 +3,7 @@ import argparse import os import time +import uuid from contextlib import asynccontextmanager import httpx @@ -83,14 +84,17 @@ def parse_args(): async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict): + req_data: dict, request_id: str): """ Send a request to a service using a persistent client. """ req_data = req_data.copy() req_data['do_remote_decode'] = True req_data["stream"] = False - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() @@ -99,11 +103,14 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, async def stream_service_response(client: httpx.AsyncClient, endpoint: str, req_data: dict, remote_block_ids: list[int], - remote_engine_id: str): + remote_engine_id: str, request_id: str): """ Asynchronously stream the response from a service using a persistent client. 
""" - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } req_data['do_remote_prefill'] = True req_data["remote_block_ids"] = remote_block_ids req_data['remote_engine_id'] = remote_engine_id @@ -123,9 +130,11 @@ async def handle_completions(request: Request): try: req_data = await request.json() + request_id = str(uuid.uuid4()) + # Send request to prefill service response = await send_request_to_service( - app.state.prefill_client, "/completions", req_data) + app.state.prefill_client, "/completions", req_data, request_id) # Extract the needed fields response_json = response.json() @@ -147,7 +156,8 @@ async def generate_stream(): "/completions", req_data, remote_block_ids=remote_block_ids, - remote_engine_id=remote_engine_id): + remote_engine_id=remote_engine_id, + request_id=request_id): yield chunk return StreamingResponse(generate_stream(), diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index dd2a1a89d717..516066a4bc9f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -502,7 +502,7 @@ def _get_new_notifs(self) -> list[str]: for req_ids in self.nixl_wrapper.get_new_notifs().values(): for req_id in req_ids: assert req_id not in notified_req_ids - notified_req_ids.append(req_id.decode('utf-8')) + notified_req_ids.append(req_id.decode("utf-8")) return notified_req_ids def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: From c5b3053605ea419734d728c040d5a18b7c168b4a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 27 Apr 2025 23:05:48 +0000 Subject: [PATCH 087/119] Remove gpu_model_runner hacks Signed-off-by: Tyler Michael Smith --- vllm/v1/worker/gpu_model_runner.py | 53 +++++++----------------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 351fea8c301d..a6076f6ab606 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1045,28 +1045,20 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: return output # Prepare the decoder inputs. - num_reqs = self.input_batch.num_reqs num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if num_reqs > 0: - attn_metadata, logits_indices, spec_decode_metadata = ( - self._prepare_inputs(scheduler_output)) - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) else: - # This may happen when there are outstanding KV transfers - print("tyler hack area " + str(scheduler_output.total_num_scheduled_tokens)) - num_input_tokens = 1 - attn_metadata = None - logits_indices = None - spec_decode_metadata = None + # Eager mode. 
+ num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1133,27 +1125,6 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: # For mid-pipeline stages, return the hidden states. return hidden_states - if logits_indices is None: - # HACK(tms): Early exit - - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - - if len(finished_recving) > 0: - logger.debug(finished_recving) - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - finished_sending=finished_sending, - finished_recving=finished_recving, - ) - hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) From 75028190e8f2d6d8ad78691d8e000addcd2d9499 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 28 Apr 2025 01:15:41 +0000 Subject: [PATCH 088/119] Clean up Justfile Signed-off-by: Tyler Michael Smith --- Justfile | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Justfile b/Justfile index bb734fba207e..ab94c6819426 100644 --- a/Justfile +++ b/Justfile @@ -1,9 +1,3 @@ -notes: - UCX_RNDV_THRESH=0 # Force rendezvous protocol for all messages - UCX_MEMTYPE_CACHE=n # Disable memory type caching - UCX_TLS=rc,ud,dc,cuda_copy,cuda_ipc,gdr_copy # Prioritize RDMA transports - UCX_ZCOPY_THRESH=0 # Force zero-copy for all sizes - prefill: UCX_LOG_LEVEL=debug \ NIXL_ROLE="SENDER" \ @@ -14,7 +8,6 @@ prefill: vllm serve meta-llama/Llama-3.2-1B-Instruct \ --port 8100 \ --enforce-eager \ - --load-format dummy \ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' decode: @@ -27,7 +20,6 @@ decode: vllm serve meta-llama/Llama-3.2-1B-Instruct \ --port 8200 \ --enforce-eager \ - --load-format dummy \ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' proxy: From 7768b96387e1077593ab76e897f1862489c83622 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 28 Apr 2025 18:25:33 +0000 Subject: [PATCH 089/119] [Bugfix] Stale finished requests in EMPTY_MODEL_RUNNER_OUTPUT Signed-off-by: Tyler Michael Smith --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a6076f6ab606..35fccd654d6e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import gc import time import weakref @@ -1038,7 +1039,7 @@ def maybe_get_finished() -> tuple[list[str], list[str]]: maybe_wait_for_save() finished_sending, finished_recving = maybe_get_finished() # Return empty ModelRunnerOutput if there's no work to do. 
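The bug fixed just below is the usual shared-mutable-sentinel trap: assigning the module-level EMPTY_MODEL_RUNNER_OUTPUT and then mutating its finished_sending/finished_recving fields pollutes every later step that returns the same sentinel, so stale request ids keep reappearing. A minimal reproduction with invented names:

    import copy
    from dataclasses import dataclass, field


    @dataclass
    class Output:
        finished_sending: list[str] = field(default_factory=list)


    EMPTY_OUTPUT = Output()


    def buggy_step(finished: list[str]) -> Output:
        out = EMPTY_OUTPUT                   # aliases the shared sentinel
        out.finished_sending += finished
        return out


    def fixed_step(finished: list[str]) -> Output:
        out = copy.deepcopy(EMPTY_OUTPUT)    # private copy for this step
        out.finished_sending += finished
        return out


    buggy_step(["req-1"])
    assert EMPTY_OUTPUT.finished_sending == ["req-1"]   # sentinel polluted

    EMPTY_OUTPUT.finished_sending.clear()
    fixed_step(["req-1"])
    assert EMPTY_OUTPUT.finished_sending == []          # sentinel untouched

Taking a deep copy per step, as the hunk below does, keeps the shared constant pristine.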
- output = EMPTY_MODEL_RUNNER_OUTPUT + output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) if len(finished_sending) > 0 or len(finished_recving) > 0: output.finished_sending = finished_sending output.finished_recving = finished_recving From a5950b7e1ccfbba26c767a1b6070852ceda60e03 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 28 Apr 2025 18:38:39 +0000 Subject: [PATCH 090/119] update Signed-off-by: Tyler Michael Smith --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 516066a4bc9f..01388aab6ad4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -493,16 +493,16 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_recving)) return done_sending, done_recving - def _get_new_notifs(self) -> list[str]: + def _get_new_notifs(self) -> set[str]: """Get req_ids which got a remote xfer message.""" - notified_req_ids: list[str] = [] + notified_req_ids: set[str] = set() # TODO: handle the TP case (N notifies for TP=N). # See: vllm/worker/worker_base.py L476 in DynamoPR. for req_ids in self.nixl_wrapper.get_new_notifs().values(): for req_id in req_ids: assert req_id not in notified_req_ids - notified_req_ids.append(req_id.decode("utf-8")) + notified_req_ids.add(req_id.decode("utf-8")) return notified_req_ids def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 35fccd654d6e..a4ae59e9d3fc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1024,7 +1024,7 @@ def maybe_wait_for_save(): kv_connector = get_kv_transfer_group() kv_connector.wait_for_save() - def maybe_get_finished() -> tuple[list[str], list[str]]: + def maybe_get_finished() -> tuple[set[str], set[str]]: if has_kv_transfer_group(): kv_connector = get_kv_transfer_group() return kv_connector.get_finished() From 610a35716d24beea1d888143009b4542ab60110d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 28 Apr 2025 19:14:20 +0000 Subject: [PATCH 091/119] justfile edits Signed-off-by: Tyler Michael Smith --- Justfile | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Justfile b/Justfile index ab94c6819426..859b3abcbb50 100644 --- a/Justfile +++ b/Justfile @@ -5,9 +5,10 @@ prefill: VLLM_LOGGING_LEVEL="DEBUG" \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - vllm serve meta-llama/Llama-3.2-1B-Instruct \ + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --port 8100 \ --enforce-eager \ + --disable-log-requests \ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' decode: @@ -17,9 +18,10 @@ decode: VLLM_LOGGING_LEVEL="DEBUG" \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - vllm serve meta-llama/Llama-3.2-1B-Instruct \ + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --port 8200 \ --enforce-eager \ + --disable-log-requests \ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' proxy: @@ -29,7 +31,7 @@ send_request: curl -X POST http://localhost:8192/v1/completions \ -H "Content-Type: application/json" \ -d '{ \ - "model": "meta-llama/Llama-3.2-1B-Instruct", \ + "model": "meta-llama/Llama-3.1-8B-Instruct", \ 
"prompt": "Generate a curl command to send to an openai server hosted at local_host:8192 with this as the", \ "max_tokens": 150, \ "temperature": 0.7 \ From 5b026ab24dc6f70930367e70d3ba95747fc4de18 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 28 Apr 2025 20:22:24 +0000 Subject: [PATCH 092/119] Update Signed-off-by: Tyler Michael Smith --- Justfile | 2 +- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 +++++ vllm/v1/core/sched/scheduler.py | 6 +++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Justfile b/Justfile index 859b3abcbb50..771d1afec13c 100644 --- a/Justfile +++ b/Justfile @@ -32,7 +32,7 @@ send_request: -H "Content-Type: application/json" \ -d '{ \ "model": "meta-llama/Llama-3.1-8B-Instruct", \ - "prompt": "Generate a curl command to send to an openai server hosted at local_host:8192 with this as the", \ + "prompt": "EXPLAIN KERMIT THE FROG", \ "max_tokens": 150, \ "temperature": 0.7 \ }' diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 01388aab6ad4..998ff3f4d141 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -572,6 +572,11 @@ def _read_blocks( # NOTE(rob): we could potentially do the rearranging during the load_kv! + # Note(tms): The remote_block_ids only contain full computed blocks, + # while the local_block_ids are all blocks allocated for this request, + # so truncate the local_block_ids to account for this. + if len(remote_block_ids) < len(local_block_ids): + local_block_ids = local_block_ids[:len(remote_block_ids)] assert len(local_block_ids) == len(remote_block_ids) if len(local_block_ids) == 0: return diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3f4c7fddf53c..72d3030447f2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -779,12 +779,12 @@ def update_from_output( stopped = True # TODO(rob): do this on a per-Connector basis. + # TODO(tms): Should this be get_computed_blocks to only send full blocks? remote_blocks = [ block.block_id for block in - self.kv_cache_manager.req_to_blocks[request.request_id] + #self.kv_cache_manager.req_to_blocks[request.request_id] + self.kv_cache_manager.get_computed_blocks(request)[0] ] - # HACK(tms) - we're off by one between prefill an decode - remote_blocks.pop() kv_transfer_params = KVTransferParams( do_remote_prefill=True, From f2fadd6722c5323327a3adf01ce5e58756954dd6 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 29 Apr 2025 18:29:16 +0000 Subject: [PATCH 093/119] Fixes - lm_eval gsm8k has correctness Signed-off-by: Tyler Michael Smith --- .../kv_connector/v1/nixl_connector.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 998ff3f4d141..347e603d661b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -233,7 +233,6 @@ def __init__(self, engine_id: str): self.rank = 0 # KV Caches and nixl tracking data. 
- self.num_layers: int = 0 self.kv_caches: dict[str, torch.Tensor] = {} # Map of engine_id -> kv_caches_base_addr @@ -243,6 +242,10 @@ def __init__(self, engine_id: str): # list[list[tuple[int, int]]]] self.kv_caches_base_addr: dict[str, list[int]] = {} + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + # Map of tp_mult -> nixl_prepped_dlist_handle (int). self.src_xfer_side_handles: dict[int, int] = {} # Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)]. @@ -272,7 +275,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len = kv_elem_size * math.prod(first_kv_cache.shape[-3:]) logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) - self.num_layers = len(kv_caches) self.num_blocks = num_blocks self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -291,6 +293,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data.append((base_addr, region_len, self.rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -618,16 +621,16 @@ def _read_blocks( def _get_block_descs_ids(self, engine_id, - layer_ids, + region_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None): - if layer_ids == "all": - layer_ids = list(range(self.num_layers)) + if region_ids == "all": + region_ids = range(self.num_regions) if block_ids == "all": - block_ids = list(range(self.num_blocks)) + block_ids = range(self.num_blocks) descs_ids = [] @@ -636,7 +639,7 @@ def _get_block_descs_ids(self, "the same TP size.") else: num_blocks = self.dst_num_blocks[engine_id] - for layer_id in 2 * layer_ids: + for reg_id in region_ids: for block_id in block_ids: - descs_ids.append(layer_id * num_blocks + block_id) + descs_ids.append(reg_id * num_blocks + block_id) return descs_ids From 4060f86a0677f56be01e21bf4b1d480a37aa8d5b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 29 Apr 2025 20:35:48 +0000 Subject: [PATCH 094/119] "just delete the assert" Signed-off-by: Tyler Michael Smith --- vllm/v1/core/block_pool.py | 2 ++ vllm/v1/core/kv_cache_manager.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 74f3f7852c9a..ffbe6d74e868 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -117,6 +117,8 @@ def cache_full_blocks( prev_block_hash_value = prev_block.block_hash.hash_value for i, blk in enumerate(new_full_blocks): + if blk.block_hash is not None: + continue assert blk.block_hash is None if i < len(new_block_hashes): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 54f348ef54f5..ff5485932dcb 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -301,7 +301,7 @@ def cache_blocks( new_computed_blocks: Optional[list[KVCacheBlock]] = None, ): if new_computed_blocks is None: - new_computed_blocks = [] + new_computed_blocks = [] req_blocks = self.req_to_blocks[request.request_id] From bfe9d1957be6dbb09c46807ec83ecebf62e761d2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 29 Apr 2025 22:21:25 +0000 Subject: [PATCH 095/119] fixup precommit issues Signed-off-by: Tyler Michael Smith --- .../kv_connector/v1/nixl_connector.py | 19 ++++---- 
vllm/v1/core/sched/scheduler.py | 48 +++++++++---------- vllm/v1/outputs.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 4 files changed, 35 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 347e603d661b..1ee9567aab97 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -156,7 +156,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id - logger.info("Initializing NIXL Scheduler " + engine_id) + logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv. # New requests are added by update_state_after_alloc in @@ -189,8 +189,7 @@ def update_state_after_alloc(self, request: "Request", if request.do_remote_decode: pass if request.do_remote_prefill and num_external_tokens > 0: - self._reqs_need_recv[request.request_id] = ( - request, block_ids) + self._reqs_need_recv[request.request_id] = (request, block_ids) def build_connector_meta( self, @@ -221,7 +220,7 @@ def __init__(self, engine_id: str): logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") logger.info("Initializing NIXL wrapper") - logger.info("Initializing NIXL worker " + engine_id) + logger.info("Initializing NIXL worker %s", engine_id) # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) @@ -431,7 +430,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): return num_blocks = nixl_agent_meta.num_blocks - logger.debug("Adding remote agent " + engine_id + " " + str(num_blocks)) + logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks)) agent_names = [] agent_name = self.nixl_wrapper.add_remote_agent( @@ -544,11 +543,11 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.requests.items(): # NOTE: this is non-blocking - logger.debug("start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", - req_id, meta.remote_engine_id, - len(meta.local_block_ids), - len(meta.remote_block_ids)) + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) self._read_blocks( local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 72d3030447f2..1a66259f4154 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -32,6 +32,7 @@ logger = init_logger(__name__) + class Scheduler(SchedulerInterface): def __init__( @@ -303,9 +304,10 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.cache_blocks( request, num_tokens=0, - num_computed_tokens=(len(request.all_token_ids) - 1) - ) - self.finished_recving_KV_req_ids.remove(request.request_id) + num_computed_tokens=(len(request.all_token_ids) - + 1)) + self.finished_recving_KV_req_ids.remove( + request.request_id) request.status = RequestStatus.WAITING self.kv_cache_manager.free(request) else: @@ -380,8 +382,8 @@ def schedule(self) -> SchedulerOutput: # `request.num_prompt_tokens` to consider the resumed reqs, # which have output tokens. 
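The scheduler hunk here is only reflowed, but the budgeting it implements is worth spelling out: a request is offered its remaining prompt tokens, optionally clipped by the long-prefill threshold, and always capped by the step's token budget. A hedged restatement, not the vLLM function itself:

    def tokens_to_schedule(num_tokens: int,
                           num_computed_tokens: int,
                           token_budget: int,
                           long_prefill_threshold: int = 0) -> int:
        """How many new tokens to schedule for one request this step."""
        num_new = num_tokens - num_computed_tokens
        if 0 < long_prefill_threshold < num_new:
            num_new = long_prefill_threshold
        return min(num_new, token_budget)


    # 1000-token prompt, 256 computed, 512-token budget, 128-token cap.
    assert tokens_to_schedule(1000, 256, 512, 128) == 128
    assert tokens_to_schedule(1000, 256, 512) == 512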
num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): num_new_tokens = ( self.scheduler_config.long_prefill_token_threshold) num_new_tokens = min(num_new_tokens, token_budget) @@ -390,10 +392,9 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -447,14 +448,13 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if not request.do_remote_prefill: - if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) - # Allocate the encoder cache. - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + if not request.do_remote_prefill and encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -687,8 +687,7 @@ def update_from_output( new_running.append(request) continue - - if not req_id in model_runner_output.req_id_to_index: + if req_id not in model_runner_output.req_id_to_index: print(req_id) print(model_runner_output.req_id_to_index) continue @@ -779,17 +778,15 @@ def update_from_output( stopped = True # TODO(rob): do this on a per-Connector basis. - # TODO(tms): Should this be get_computed_blocks to only send full blocks? remote_blocks = [ block.block_id for block in - #self.kv_cache_manager.req_to_blocks[request.request_id] self.kv_cache_manager.get_computed_blocks(request)[0] ] kv_transfer_params = KVTransferParams( do_remote_prefill=True, remote_block_ids=remote_blocks, - remote_engine_id=self.connector.engine_id, + remote_engine_id=self.vllm_config.engine_id, ) # Add EngineCoreOutput for this Request. @@ -815,10 +812,10 @@ def update_from_output( # P/D: update recv and send status from last step. 
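For the remote-decode hand-off above, the prefill side replies with just enough for the decoder to pull the KV data: its engine id plus the ids of the fully computed blocks. A rough sketch of that payload; the dataclass below is a simplified stand-in for the KVTransferParams used in this series, and the helper is invented.

    from dataclasses import dataclass, field


    @dataclass
    class KVTransferParams:
        do_remote_prefill: bool = False
        do_remote_decode: bool = False
        remote_engine_id: str = ""
        remote_block_ids: list[int] = field(default_factory=list)


    def build_prefill_response(engine_id: str,
                               computed_block_ids: list[int]) -> KVTransferParams:
        # Only fully computed blocks are advertised to the decode side.
        return KVTransferParams(do_remote_prefill=True,
                                remote_engine_id=engine_id,
                                remote_block_ids=computed_block_ids)


    params = build_prefill_response("prefill-0", [7, 8, 9])
    # A proxy forwards these fields verbatim to the decode instance.
    decode_request_extras = {
        "do_remote_prefill": True,
        "remote_engine_id": params.remote_engine_id,
        "remote_block_ids": params.remote_block_ids,
    }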
for req_id in (model_runner_output.finished_recving or []): - logger.debug("FINISHED RECVING: " + req_id) + logger.debug("FINISHED RECVING: %s", req_id) self.finished_recving_KV_req_ids.add(req_id) for req_id in (model_runner_output.finished_sending or []): - logger.debug("FINISHED SENDING: " + req_id) + logger.debug("FINISHED SENDING: %s", req_id) self._free_blocks(self.requests[req_id]) self.running = new_running @@ -869,7 +866,8 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request, + def _free_request(self, + request: Request, skip_free_blocks: bool = False) -> None: assert request.is_finished() self.encoder_cache_manager.free(request) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index baed401ac8b5..e8ce0df5ed8d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -111,5 +111,5 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - finished_sending=[], - finished_recving=[]) + finished_sending=None, + finished_recving=None) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a4ae59e9d3fc..f987f2db4359 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1029,7 +1029,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: kv_connector = get_kv_transfer_group() return kv_connector.get_finished() else: - return [], [] + return set(), set() self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: From ced529ad9d93439e4aa938c44f2555eaefda395b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 29 Apr 2025 22:46:15 +0000 Subject: [PATCH 096/119] Fixes Signed-off-by: Tyler Michael Smith --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++- vllm/v1/core/sched/scheduler.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1ee9567aab97..cdef95f18b98 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -275,6 +275,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) self.num_blocks = num_blocks + self.dst_num_blocks[self.engine_id] = num_blocks self.kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -592,7 +593,7 @@ def _read_blocks( # Read the data from the remote. for i in range(tp_multiplier): local_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, + self.engine_id, "all", local_block_ids, i=None, #TODO: Enable both tp_multiplier and staging_ranges. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1a66259f4154..82759d28b20c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -783,10 +783,11 @@ def update_from_output( self.kv_cache_manager.get_computed_blocks(request)[0] ] + engine_id = self.vllm_config.kv_transfer_config.engine_id kv_transfer_params = KVTransferParams( do_remote_prefill=True, remote_block_ids=remote_blocks, - remote_engine_id=self.vllm_config.engine_id, + remote_engine_id=engine_id, ) # Add EngineCoreOutput for this Request. 
From 83f2872557a0341539b8f5a53845081f4c0b34fc Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:59:09 -0400 Subject: [PATCH 097/119] updated (#12) Signed-off-by: rshaw@neuralmagic.com --- .../openai_completion_client.py | 6 ++++-- vllm/v1/core/block_pool.py | 2 -- vllm/v1/core/kv_cache_manager.py | 19 +++++++++++++------ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index b31ebcccce3b..1f8a7e5b078c 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -6,6 +6,9 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8192/v1" +PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501 +PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, " # noqa: E501 + def main(): client = OpenAI( @@ -21,8 +24,7 @@ def main(): stream = True completion = client.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", - prompt= - "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501 + prompt=PROMPT, echo=False, stream=stream) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ffbe6d74e868..74f3f7852c9a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -117,8 +117,6 @@ def cache_full_blocks( prev_block_hash_value = prev_block.block_hash.hash_value for i, blk in enumerate(new_full_blocks): - if blk.block_hash is not None: - continue assert blk.block_hash is None if i < len(new_block_hashes): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ff5485932dcb..41d9f1b65c23 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -279,12 +279,19 @@ def allocate_slots( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if not self.enable_caching or skip_cache_blocks: - # If self.enable_caching, this is true since can only - # get to this codepath when we have never been scheduled. - assert request.request_id not in self.num_cached_block + if not self.enable_caching: return new_blocks + if skip_cache_blocks: + # NOTE(rob): this assert is valid because we only call + # skip_cache_blocks=True on the first time of WAITING + # during a P/D setup. + assert request.request_id not in self.num_cached_block + # NOTE(rob): this is necessary so we don't double + # cache a block after is has finished recving. 
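The kv_cache_manager change here is about not caching the same prefix twice for a remote-prefill request: the blocks that will be filled by the transfer are pre-accounted in num_cached_block, and the normal caching pass later starts after them. A toy version of that bookkeeping, with invented names:

    class PrefixCacheBookkeeping:
        """Tracks how many leading blocks of a request are already cached."""

        def __init__(self):
            self.num_cached_block: dict[str, int] = {}
            self.cached: list[tuple[str, int]] = []   # (req_id, block index)

        def allocate_for_remote_prefill(self, req_id: str, num_blocks: int):
            # Blocks that the remote transfer will fill are accounted for up
            # front; as in the patch, cache_blocks() is not called here.
            self.num_cached_block[req_id] = num_blocks

        def cache_blocks(self, req_id: str, num_full_blocks: int):
            start = self.num_cached_block.get(req_id, 0)
            for idx in range(start, num_full_blocks):
                self.cached.append((req_id, idx))
            self.num_cached_block[req_id] = max(start, num_full_blocks)


    bk = PrefixCacheBookkeeping()
    bk.allocate_for_remote_prefill("req-1", num_blocks=3)
    bk.cache_blocks("req-1", num_full_blocks=5)   # only blocks 3 and 4 are new
    assert bk.cached == [("req-1", 3), ("req-1", 4)]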
+ self.num_cached_block[request.request_id] = len( + new_computed_blocks) + self.cache_blocks( request=request, num_tokens=num_tokens, @@ -313,8 +320,8 @@ def cache_blocks( # Speculated tokens might be rejected in the future, so we do # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. - num_full_blocks_after_append = ( - num_computed_tokens + num_tokens - len(request.spec_token_ids)) // self.block_size + num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size self.block_pool.cache_full_blocks( request=request, From e853b3ceeaf256a08495c20fee5dade9c6e72741 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 29 Apr 2025 22:19:43 -0400 Subject: [PATCH 098/119] Add Accuracy Test (#13) * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com --------- Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/run_accuracy_test.sh | 52 ++++++++ tests/v1/kv_connector/test_accuracy.py | 28 ++++ tests/v1/kv_connector/toy_proxy_server.py | 145 +++++++++++++++++++++ 3 files changed, 225 insertions(+) create mode 100644 tests/v1/kv_connector/run_accuracy_test.sh create mode 100644 tests/v1/kv_connector/test_accuracy.py create mode 100644 tests/v1/kv_connector/toy_proxy_server.py diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/run_accuracy_test.sh new file mode 100644 index 000000000000..0aab60e4adca --- /dev/null +++ b/tests/v1/kv_connector/run_accuracy_test.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +set -xe + +# Model to run. +MODEL_NAME=Qwen/Qwen3-0.6B + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Prefill instance. +CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \ + --port 8100 \ + --enforce-eager \ + --disable-log-requests \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + +# Decode instance. +CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \ + --port 8200 \ + --enforce-eager \ + --disable-log-requests \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# Proxy server. +python toy_proxy_server.py --port 8192 & + +# Run lm eval. 
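The wait_for_server helper above is plain readiness polling before lm_eval is pointed at the proxy. An equivalent Python sketch is below; the /v1/models probe and the timeout are illustrative choices, not what the script itself uses.

    import time
    import urllib.error
    import urllib.request


    def wait_for_server(port: int, timeout_s: float = 1200.0) -> bool:
        """Poll a local server until it answers or the timeout expires."""
        deadline = time.monotonic() + timeout_s
        url = f"http://localhost:{port}/v1/models"
        while time.monotonic() < deadline:
            try:
                with urllib.request.urlopen(url, timeout=1):
                    return True
            except (urllib.error.URLError, OSError):
                time.sleep(1)
        return False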
+python3 -m pytest -s -x test_accuracy.py diff --git a/tests/v1/kv_connector/test_accuracy.py b/tests/v1/kv_connector/test_accuracy.py new file mode 100644 index 000000000000..60878a664eb9 --- /dev/null +++ b/tests/v1/kv_connector/test_accuracy.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +import lm_eval + +MODEL_NAME = "Qwen/Qwen3-0.6B" +NUM_CONCURRENT = 100 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 +EXPECTED_VALUE = 0.41 + + +def test_accuracy(): + """Run the end to end accuracy test.""" + + model_args = (f"model={MODEL_NAME}," + f"base_url=http://localhost:8192/v1/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/toy_proxy_server.py b/tests/v1/kv_connector/toy_proxy_server.py new file mode 100644 index 000000000000..89e3c4493fdb --- /dev/null +++ b/tests/v1/kv_connector/toy_proxy_server.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a persistent client. 
+ """ + req_data = req_data.copy() + req_data['do_remote_decode'] = True + req_data["stream"] = False + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict, remote_block_ids: list[int], + remote_engine_id: str, request_id: str): + """ + Asynchronously stream the response from a service using a persistent client. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + req_data['do_remote_prefill'] = True + req_data["remote_block_ids"] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + + request_id = str(uuid.uuid4()) + + # Send request to prefill service + response = await send_request_to_service(app.state.prefill_client, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + remote_block_ids = response_json.get('remote_block_ids', []) + remote_engine_id = response_json.get('remote_engine_id', '') + + # Add these to the request data for the decoder + req_data['remote_block_ids'] = remote_block_ids + req_data['remote_engine_id'] = remote_engine_id + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + app.state.decode_client, + "/completions", + req_data, + remote_block_ids=remote_block_ids, + remote_engine_id=remote_engine_id, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) From 1c45ed1c9efc32975dd6bb2682f8d679d2454c91 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Thu, 1 May 2025 16:34:48 -0400 Subject: [PATCH 099/119] Preemption Bugfixes (#15) * stash fixed double free issue Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * fixed issue Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com * updatrd Signed-off-by: rshaw@neuralmagic.com --------- Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/nixl_connector.py | 16 +++++++--------- 
vllm/v1/core/sched/scheduler.py | 2 ++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index cdef95f18b98..a9d4606c4f60 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -172,16 +172,12 @@ def get_num_new_matched_tokens(self, request: "Request", assert num_computed_tokens % self.block_size == 0 if request.do_remote_prefill: - # NOTE: subtract 1 since we compute the last token - # here so that we can sample the first token. - num_prompt_tokens = len(request.prompt_token_ids) - 1 - - # Round down to a full block shape. - num_external_blocks = num_prompt_tokens // self.block_size + num_external_blocks = len( + request.prompt_token_ids) // self.block_size rounded_num_prompt_tokens = num_external_blocks * self.block_size return max(rounded_num_prompt_tokens - num_computed_tokens, 0) - else: - return 0 + + return 0 def update_state_after_alloc(self, request: "Request", block_ids: list[int], @@ -310,7 +306,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # For debug, SENDER puts some stuff in the KV caches # so the RECVER can check it - n_blocks_to_send = 4096 + n_blocks_to_send = min(4096, kv_caches[first_layer_name].shape[1]) debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9 print(f"gb {debug_xfer_gb} -- block_len {self.block_len}") if NIXL_ROLE == "SENDER": @@ -581,6 +577,8 @@ def _read_blocks( if len(remote_block_ids) < len(local_block_ids): local_block_ids = local_block_ids[:len(remote_block_ids)] assert len(local_block_ids) == len(remote_block_ids) + + # NOTE(rob): this can cause the remote blocks to not be freed? if len(local_block_ids) == 0: return diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 82759d28b20c..0b72ea390464 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -375,6 +375,8 @@ def schedule(self) -> SchedulerOutput: [b.block_id for b in computed_blocks + new_blocks], num_external_tokens, ) + # We should only trigger a KV transfer once per request. + request.do_remote_prefill = False continue # Number of tokens to be scheduled. From a45a6947d3ef625153073b002a4e6204693080bb Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Thu, 1 May 2025 17:09:17 -0400 Subject: [PATCH 100/119] updated (#16) Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 41d9f1b65c23..33cc33beb2ae 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -291,6 +291,7 @@ def allocate_slots( # cache a block after is has finished recving. 
self.num_cached_block[request.request_id] = len( new_computed_blocks) + return new_blocks self.cache_blocks( request=request, From 2f9a3f3f4b93f6c2167aa107d2f2e160264bcdf3 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Thu, 1 May 2025 20:27:46 -0400 Subject: [PATCH 101/119] Fix Bad Merge | Fix Memory Leak in Upstream (#18) * updated Signed-off-by: rshaw@neuralmagic.com * fix merge Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com * updated Signed-off-by: rshaw@neuralmagic.com --------- Signed-off-by: rshaw@neuralmagic.com --- tests/v1/kv_connector/utils.py | 42 ++++++++++++++------------------- vllm/config.py | 19 ++++++++++++--- vllm/v1/core/sched/scheduler.py | 10 ++++++-- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 825e6186b5ae..409b5ac69e97 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional + import torch from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, @@ -14,17 +15,17 @@ EOS_TOKEN_ID = 50256 + def assert_scheduler_empty(scheduler: Scheduler): """Confirm the scheduler is "empty" - i.e. no leaks.""" # Scheduler Metadata. assert len(scheduler.requests) == 0 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 0 - assert len(scheduler.scheduled_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_recving_KV_req_ids) == 0 assert len(scheduler._cached_reqs_data) == 0 - + # EncoderCacheManager. assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.cached) == 0 @@ -37,15 +38,16 @@ def assert_scheduler_empty(scheduler: Scheduler): scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) - + # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. 
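The "lazily evict" note above is why the hash assertions in this helper stay commented out: freeing a block only drops its ref count, and the hash is cleared when the block is actually reused for something else. A toy pool showing that behavior, with invented names:

    class Block:
        def __init__(self):
            self.ref_cnt = 0
            self.block_hash = None


    class LazyPool:
        """Freed blocks keep their hash until they are actually reused."""

        def __init__(self, num_blocks: int):
            self.blocks = [Block() for _ in range(num_blocks)]
            self.free_queue = list(self.blocks)

        def allocate(self) -> Block:
            blk = self.free_queue.pop(0)
            blk.block_hash = None   # evicted only now, on reuse
            blk.ref_cnt = 1
            return blk

        def free(self, blk: Block):
            blk.ref_cnt -= 1        # the hash is intentionally left in place
            self.free_queue.append(blk)


    pool = LazyPool(2)
    blk = pool.allocate()
    blk.block_hash = "hash-of-a-full-block"
    pool.free(blk)
    # After a full request lifecycle: ref counts are zero, hashes may remain.
    assert all(b.ref_cnt == 0 for b in pool.blocks)
    assert blk.block_hash is not None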
for block in scheduler.kv_cache_manager.block_pool.blocks: assert block.ref_cnt == 0 # assert block._block_hash is None # assert ( - # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block) == 0) - + # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + # ) == 0) + def create_vllm_config( model: str = "facebook/opt-125m", @@ -80,13 +82,11 @@ def create_vllm_config( kv_connector="NixlConnector", kv_role="kv_both", ) - return VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - device_config=DeviceConfig("cpu") - ) + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) def create_scheduler( @@ -125,14 +125,12 @@ def create_request( if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = KVTransferParams( - do_remote_decode=True - ) + kv_transfer_params = KVTransferParams(do_remote_decode=True) elif do_remote_prefill: kv_transfer_params = KVTransferParams( do_remote_prefill=True, remote_engine_id="remote_engine_id", - remote_block_ids=[1,2,3], + remote_block_ids=[1, 2, 3], ) else: kv_transfer_params = None @@ -145,13 +143,10 @@ def create_request( if use_all_1s_for_prompt_tokens: prompt_token_ids = [1] * num_tokens else: - prompt_token_ids = [ - i * request_id for i in range(num_tokens) - ] - + prompt_token_ids = [i * request_id for i in range(num_tokens)] + return Request( request_id=f"id-{request_id}", - prompt=None, prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, multi_modal_inputs=None, @@ -161,6 +156,7 @@ def create_request( arrival_time=0, ) + def create_model_runner_output( reqs: list[Request], finished_sending: Optional[list[str]] = None, @@ -171,9 +167,7 @@ def create_model_runner_output( # Make request data. req_ids = [req.request_id for req in reqs] - req_id_to_index = { - req_id: idx for idx, req_id in enumerate(req_ids) - } + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} # Make sampled tokens. sampled_token = EOS_TOKEN_ID if use_eos else 0 diff --git a/vllm/config.py b/vllm/config.py index 09c2faf5e7bf..7fde4880e473 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1904,6 +1904,13 @@ class SchedulerConfig: """Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.""" + cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) + """Cuda graph capture sizes, default is 512. + 1. if one value is provided, then the capture list would follow the pattern: + [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] + 2. more than one value (e.g. 
1 2 128) is provided, + then the capture list will follow the provided list.""" + enable_chunked_prefill: bool = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -4241,13 +4248,19 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = [] if self.model_config is not None and \ not self.model_config.enforce_eager: - batch_size_capture_list = [1, 2, 4 - ] + [i for i in range(8, 513, 8)] + cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + batch_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + elif len(cuda_graph_sizes) > 1: + batch_size_capture_list = sorted(cuda_graph_sizes) + else: + raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") if self.parallel_config.tensor_parallel_size > 1 and \ self.compilation_config.pass_config.enable_sequence_parallelism: batch_size_capture_list = \ self.update_sizes_for_sequence_parallelism(batch_size_capture_list) - max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ size for size in batch_size_capture_list diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7a62147523e7..ae098352f1e5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -705,6 +705,7 @@ def update_from_output( prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + stopped_set: set[str] = set() new_running: list[Request] = [] outputs: list[EngineCoreOutput] = [] spec_decoding_stats: Optional[SpecDecodingStats] = None @@ -849,6 +850,8 @@ def update_from_output( if not stopped: new_running.append(request) + else: + stopped_set.add(request.request_id) # P/D: update recv and send status from last step. for req_id in (model_runner_output.finished_recving or []): @@ -858,9 +861,12 @@ def update_from_output( logger.debug("FINISHED SENDING: %s", req_id) self._free_blocks(self.requests[req_id]) - # Return the cached request data to the queue so they can be reused. + # Return the cached request data to the queue so they can + # be reused. Note: we cannot add stopped requests to this + # since they are already freed above! 
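Looping back to the cuda_graph_sizes knob added in this patch's config.py hunk: the expansion into a capture list is small enough to show on its own. The helper below is a sketch mirroring the docstring, not the actual vLLM code.

    def build_capture_list(cuda_graph_sizes: list[int]) -> list[int]:
        """Expand the cuda_graph_sizes setting into capture batch sizes."""
        if len(cuda_graph_sizes) == 1:
            # Single value N: capture 1, 2, 4, then multiples of 8 up to N.
            return [1, 2, 4] + list(range(8, cuda_graph_sizes[0] + 1, 8))
        if len(cuda_graph_sizes) > 1:
            # Explicit list: capture exactly these sizes, sorted.
            return sorted(cuda_graph_sizes)
        raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")


    assert build_capture_list([32]) == [1, 2, 4, 8, 16, 24, 32]
    assert build_capture_list([1, 2, 128]) == [1, 2, 128]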
for req_data in scheduler_output.scheduled_cached_reqs: - self._cached_reqs_data[req_data.req_id].append(req_data) + if req_data.req_id not in stopped_set: + self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running engine_core_outputs = EngineCoreOutputs( From 113527b2e7b73d6884254ce77c1353c2205910f4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 14:12:28 +0000 Subject: [PATCH 102/119] clean up justfile, examples Signed-off-by: Tyler Michael Smith --- Justfile | 38 ------------------- examples/offline_inference/basic/basic.py | 13 +++---- .../openai_completion_client.py | 19 +++++----- 3 files changed, 14 insertions(+), 56 deletions(-) delete mode 100644 Justfile diff --git a/Justfile b/Justfile deleted file mode 100644 index 771d1afec13c..000000000000 --- a/Justfile +++ /dev/null @@ -1,38 +0,0 @@ -prefill: - UCX_LOG_LEVEL=debug \ - NIXL_ROLE="SENDER" \ - CUDA_VISIBLE_DEVICES=3 \ - VLLM_LOGGING_LEVEL="DEBUG" \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - vllm serve meta-llama/Llama-3.1-8B-Instruct \ - --port 8100 \ - --enforce-eager \ - --disable-log-requests \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' - -decode: - UCX_LOG_LEVEL=info \ - NIXL_ROLE="RECVER" \ - CUDA_VISIBLE_DEVICES=4 \ - VLLM_LOGGING_LEVEL="DEBUG" \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - vllm serve meta-llama/Llama-3.1-8B-Instruct \ - --port 8200 \ - --enforce-eager \ - --disable-log-requests \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' - -proxy: - python examples/disagg_proxy_server.py --port 8192 - -send_request: - curl -X POST http://localhost:8192/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ \ - "model": "meta-llama/Llama-3.1-8B-Instruct", \ - "prompt": "EXPLAIN KERMIT THE FROG", \ - "max_tokens": 150, \ - "temperature": 0.7 \ - }' diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 60148bfd62c8..ae5ae7cb4834 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -4,10 +4,10 @@ # Sample prompts. prompts = [ - "Hello, my name is Robert and I work for Red Hat software", - "The president of the United States is Joe Biden who is ", - "The capital of France is different from the capital of USA because", - "The future of AI is open source because there is a race to the bottom", + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", ] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) @@ -15,10 +15,7 @@ def main(): # Create an LLM. - llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True, - max_num_batched_tokens=16, - max_num_seqs=8) + llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 1f8a7e5b078c..6ab7619bff19 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -4,10 +4,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8192/v1" - -PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501 -PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, " # noqa: E501 +openai_api_base = "http://localhost:8000/v1" def main(): @@ -17,16 +14,18 @@ def main(): base_url=openai_api_base, ) - # models = client.models.list() - # model = models.data[0].id + models = client.models.list() + model = models.data[0].id # Completion API - stream = True + stream = False completion = client.completions.create( - model="meta-llama/Llama-3.1-8B-Instruct", - prompt=PROMPT, + model=model, + prompt="A robot may not injure a human being", echo=False, - stream=stream) + n=2, + stream=stream, + logprobs=3) print("-" * 50) print("Completion results:") From 5f8b280b4418e4ae690edff0cc496d7ad3506ae3 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:05:47 +0000 Subject: [PATCH 103/119] more cleanup Signed-off-by: Tyler Michael Smith --- examples/proxy_example.sh | 75 --------------------------------------- 1 file changed, 75 deletions(-) delete mode 100644 examples/proxy_example.sh diff --git a/examples/proxy_example.sh b/examples/proxy_example.sh deleted file mode 100644 index 029291ecaaee..000000000000 --- a/examples/proxy_example.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -if [[ $# -lt 1 ]]; then - echo "Usage: $0 [model]" - exit 1 -fi - -if [[ $# -eq 1 ]]; then - echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" - MODEL="meta-llama/Llama-3.1-8B-Instruct" -else - echo "Using model: $2" - MODEL=$2 -fi - - -if [[ $1 == "prefill" ]]; then - - UCX_TLS=cuda_ipc,cuda_copy,tcp \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - vllm serve Qwen/Qwen2.5-1.5B-Instruct \ - --port 8100 \ - --enforce-eager \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' - - UCX_TLS=cuda_ipc,cuda_copy,tcp \ - LMCACHE_CONFIG_FILE=$prefill_config_file \ - LMCACHE_USE_EXPERIMENTAL=True \ - VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - vllm serve $MODEL \ - --port 8100 \ - --enforce-eager \ - --kv-transfer-config \ - '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' - # Potential Env vars and cmdline options - # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG - # --enforce-eager -- Enforce eager mode - -elif [[ $1 == "decode" ]]; then - # Decoder listens on port 8200 - decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml - - UCX_TLS=cuda_ipc,cuda_copy,tcp \ - LMCACHE_CONFIG_FILE=$decode_config_file \ - LMCACHE_USE_EXPERIMENTAL=True \ - VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - VLLM_WORKER_MULTIPROC_METHOD=spawn \ - vllm serve $MODEL \ - --port 8200 \ - --enforce-eager \ - --kv-transfer-config \ - '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' - - # Potential Env vars 
and cmdline options - # LMCACHE_LOG_LEVEL=DEBUG -- Set log level to DEBUG - # --enforce-eager -- Enforce eager mode - -elif [[ $1 == "proxy" ]]; then - # Proxy listens on port 9000 - python3 $SCRIPT_DIR/disagg_proxy_server.py \ - --host localhost \ - --port 9000 \ - --prefiller-host localhost \ - --prefiller-port 8100 \ - --decoder-host localhost \ - --decoder-port 8200 - -else - echo "Invalid role: $1" - echo "Should be either prefill, decode, or proxy" - exit 1 -fi From 79e7b2a54364050f69dba0b18b793bb37ba67188 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:07:17 +0000 Subject: [PATCH 104/119] more cleanup Signed-off-by: Tyler Michael Smith --- examples/disagg_proxy_server.py | 219 -------------------------------- 1 file changed, 219 deletions(-) delete mode 100644 examples/disagg_proxy_server.py diff --git a/examples/disagg_proxy_server.py b/examples/disagg_proxy_server.py deleted file mode 100644 index 85365ef9b6c8..000000000000 --- a/examples/disagg_proxy_server.py +++ /dev/null @@ -1,219 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import os -import time -import uuid -from contextlib import asynccontextmanager - -import httpx -import numpy as np -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """ - Lifespan context manager to handle startup and shutdown events. - """ - # Startup: Initialize clients - prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' - decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' - - app.state.prefill_client = httpx.AsyncClient(timeout=None, - base_url=prefiller_base_url) - app.state.decode_client = httpx.AsyncClient(timeout=None, - base_url=decoder_base_url) - - yield - - # Shutdown: Close clients - await app.state.prefill_client.aclose() - await app.state.decode_client.aclose() - - -# Update FastAPI app initialization to use lifespan -app = FastAPI(lifespan=lifespan) - - -class StatsCalculator: - - def __init__(self): - self._stats = [] - self._last_log_time = time.time() - - def add(self, value): - self._stats.append(value) - if time.time() - self._last_log_time > 5: - self._log_stats() - self._last_log_time = time.time() - - def _log_stats(self): - # Print average, median, and 99th percentile - np_arr = np.array(self._stats) - output_str = f"\nNum requests: {len(self._stats)}" + \ - "\nPrefill node TTFT stats:" + \ - f"\n - Average (ms): {np.mean(np_arr)}" + \ - f"\n - Median (ms): {np.median(np_arr)}" + \ - f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" - print("===============================", output_str, - "===============================") - - -stats_calculator = StatsCalculator() -counter = 0 - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--prefiller-host", type=str, default="localhost") - parser.add_argument("--prefiller-port", type=int, default=8100) - parser.add_argument("--decoder-host", type=str, default="localhost") - parser.add_argument("--decoder-port", type=int, default=8200) - args = parser.parse_args() - return args - - -# Initialize variables to hold the persistent clients -app.state.prefill_client = None -app.state.decode_client = None - - -async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict, request_id: 
str): - """ - Send a request to a service using a persistent client. - """ - req_data = req_data.copy() - req_data['do_remote_decode'] = True - req_data["stream"] = False - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id - } - response = await client.post(endpoint, json=req_data, headers=headers) - response.raise_for_status() - - return response - - -async def stream_service_response(client: httpx.AsyncClient, endpoint: str, - req_data: dict, remote_block_ids: list[int], - remote_engine_id: str, request_id: str): - """ - Asynchronously stream the response from a service using a persistent client. - """ - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id - } - req_data['do_remote_prefill'] = True - req_data["remote_block_ids"] = remote_block_ids - req_data['remote_engine_id'] = remote_engine_id - async with client.stream("POST", endpoint, json=req_data, - headers=headers) as response: - response.raise_for_status() - async for chunk in response.aiter_bytes(): - yield chunk - - -@app.post("/v1/completions") -async def handle_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - - request_id = str(uuid.uuid4()) - - # Send request to prefill service - response = await send_request_to_service( - app.state.prefill_client, "/completions", req_data, request_id) - - # Extract the needed fields - response_json = response.json() - remote_block_ids = response_json.get('remote_block_ids', []) - remote_engine_id = response_json.get('remote_engine_id', '') - print("Prefiller response:\n" + str(response_json)) - - # Add these to the request data for the decoder - req_data['remote_block_ids'] = remote_block_ids - req_data['remote_engine_id'] = remote_engine_id - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response( - app.state.decode_client, - "/completions", - req_data, - remote_block_ids=remote_block_ids, - remote_engine_id=remote_engine_id, - request_id=request_id): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - " - completions endpoint") - print(e) - print("".join(traceback.format_exception(*exc_info))) - raise - - -@app.post("/v1/chat/completions") -async def handle_chat_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - - # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, - "/chat/completions", req_data) - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/chat/completions", - req_data): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server " - " - chat completions endpoint") - print(e) - print("".join(traceback.format_exception(*exc_info))) - raise - - -if __name__ == '__main__': - global global_args - global_args = 
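The proxy removed in this patch documents the P/D handoff contract used throughout the series: the request forwarded to the prefill instance is marked do_remote_decode=True and forced to stream=False so its JSON body can be read back, and the remote_block_ids / remote_engine_id returned by the prefiller are copied into the decode request together with do_remote_prefill=True. A minimal sketch of that payload plumbing, assuming the same field names as handle_completions above (the helper names are illustrative):

    def to_prefill_request(req: dict) -> dict:
        # Ask the prefill instance to produce KV and return transfer metadata.
        prefill_req = req.copy()
        prefill_req["do_remote_decode"] = True  # decode happens on another instance
        prefill_req["stream"] = False           # need the complete JSON response body
        return prefill_req

    def to_decode_request(req: dict, prefill_response: dict) -> dict:
        # Tell the decode instance where to pull the prefilled KV blocks from.
        decode_req = req.copy()
        decode_req["do_remote_prefill"] = True
        decode_req["remote_block_ids"] = prefill_response.get("remote_block_ids", [])
        decode_req["remote_engine_id"] = prefill_response.get("remote_engine_id", "")
        return decode_req

The same two fields reappear later in the series as request/response extras in vllm/entrypoints/openai/protocol.py.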
parse_args() - - import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) From e8ab67865a8fec08e2bc28c75a585ccc5821827f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:08:12 +0000 Subject: [PATCH 105/119] more cleanup Signed-off-by: Tyler Michael Smith --- start_proxy.py | 227 ------------------------------------------------- 1 file changed, 227 deletions(-) delete mode 100644 start_proxy.py diff --git a/start_proxy.py b/start_proxy.py deleted file mode 100644 index 8e365e19bb22..000000000000 --- a/start_proxy.py +++ /dev/null @@ -1,227 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import os -import time -from contextlib import asynccontextmanager - -import httpx -import numpy as np -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """ - Lifespan context manager to handle startup and shutdown events. - """ - # Startup: Initialize clients - prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' - decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' - - app.state.prefill_client = httpx.AsyncClient(timeout=None, - base_url=prefiller_base_url) - app.state.decode_client = httpx.AsyncClient(timeout=None, - base_url=decoder_base_url) - - yield - - # Shutdown: Close clients - await app.state.prefill_client.aclose() - await app.state.decode_client.aclose() - - -# Update FastAPI app initialization to use lifespan -app = FastAPI(lifespan=lifespan) - - -class StatsCalculator: - - def __init__(self): - self._stats = [] - self._last_log_time = time.time() - - def add(self, value): - self._stats.append(value) - if time.time() - self._last_log_time > 5: - self._log_stats() - self._last_log_time = time.time() - - def _log_stats(self): - # Print average, median, and 99th percentile - np_arr = np.array(self._stats) - output_str = f"\nNum requests: {len(self._stats)}" + \ - "\nPrefill node TTFT stats:" + \ - f"\n - Average (ms): {np.mean(np_arr)}" + \ - f"\n - Median (ms): {np.median(np_arr)}" + \ - f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" - print("===============================", output_str, - "===============================") - - -stats_calculator = StatsCalculator() -counter = 0 - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--port", type=int, default=8192) - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--prefiller-host", type=str, default="localhost") - parser.add_argument("--prefiller-port", type=int, default=8100) - parser.add_argument("--decoder-host", type=str, default="localhost") - parser.add_argument("--decoder-port", type=int, default=8200) - args = parser.parse_args() - return args - - -# Initialize variables to hold the persistent clients -app.state.prefill_client = None -app.state.decode_client = None - - -async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict): - """ - Send a request to a service using a persistent client. 
- """ - req_data = req_data.copy() - req_data['do_remote_decode'] = True - req_data["stream"] = False - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": "vllm-d-debug", - } - response = await client.post(endpoint, json=req_data, headers=headers) - response.raise_for_status() - - return response - - -async def stream_service_response(client: httpx.AsyncClient, endpoint: str, - req_data: dict): - """ - Asynchronously stream the response from a service using a persistent client. - """ - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": "vllm-d-debug", - } - req_data['do_remote_prefill'] = True - async with client.stream("POST", endpoint, json=req_data, - headers=headers) as response: - response.raise_for_status() - async for chunk in response.aiter_bytes(): - yield chunk - -async def send_request_to_prefill_service(client: httpx.AsyncClient, endpoint: str, - req_data: dict): - """ - Send a request to a service using a persistent client. - """ - req_data = req_data.copy() - req_data['do_remote_decode'] = True - req_data["stream"] = False - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": "vllm-d-debug", - } - response = await client.post(endpoint, json=req_data, headers=headers) - response.raise_for_status() - - # Extract and print the actual response content - try: - response_json = response.json() - print(f"Prefill Request Content: {req_data}") - print(f"Prefill Response Content: {response_json}") - except Exception as e: - print(f"Could not parse prefill response as JSON: {e}") - print(f"Raw prefill response text: {response.text}") - - return response - - -@app.post("/v1/completions") -async def handle_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - print(req_data) - - # Send request to prefill service, ignore the response - response = await send_request_to_prefill_service(app.state.prefill_client, "/completions", - req_data) - print(response) - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/completions", - req_data): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - " - completions endpoint") - print(e) - print("".join(traceback.format_exception(*exc_info))) - raise - - -@app.post("/v1/chat/completions") -async def handle_chat_completions(request: Request): - global counter, stats_calculator - counter += 1 - - st = time.time() - try: - req_data = await request.json() - - # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, - "/chat/completions", req_data) - - et = time.time() - stats_calculator.add(et - st) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response(app.state.decode_client, - "/chat/completions", - req_data): - yield chunk - - return StreamingResponse(generate_stream(), - media_type="application/json") - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server " - " - chat completions endpoint") - print(e) - 
print("".join(traceback.format_exception(*exc_info))) - raise - - -if __name__ == '__main__': - global global_args - global_args = parse_args() - - import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) From 969daa9d56290161664d34c3beb51c8553b2e9e2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:09:19 +0000 Subject: [PATCH 106/119] more cleanup Signed-off-by: Tyler Michael Smith --- .../kv_connector/v1/kv_rearrange.py | 119 ------------------ 1 file changed, 119 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py b/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py deleted file mode 100644 index c55e4de8f7d9..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/kv_rearrange.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def rearrange_kernel_read( - t1_ptr, - t2_ptr, - N, - B, - H, - C, - d, - tensor_subset_size, - block_size, - token_size, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - curr_n = offsets // block_size - curr_b = offsets // token_size % B - curr_h = offsets // C % H - curr_c = offsets % C - - src_pos = offsets - - tp_group = curr_h * d // H - dst_h = curr_h % (H // d) - tp_group_offset = curr_n * (block_size // - d) + curr_b * (H // d) * C + dst_h * C + curr_c - - dst_pos = tensor_subset_size * tp_group + tp_group_offset - - tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) - - -@triton.jit -def rearrange_kernel_write( - t1_ptr, - t2_ptr, - N, - B, - H, - C, - d, - tensor_subset_size, - block_size, - token_size, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - curr_n = offsets // block_size - curr_b = offsets // token_size % B - curr_h = offsets // C % H - curr_c = offsets % C - - src_pos = offsets - - tp_group = curr_h * d // H - dst_h = curr_h % (H // d) - tp_group_offset = curr_n * (block_size // - d) + curr_b * (H // d) * C + dst_h * C + curr_c - - dst_pos = tensor_subset_size * tp_group + tp_group_offset - - tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) - - -def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, - direction: str): - N, B, H, C = t1.shape - - assert t2.shape == (N, B, H, - C), "Destination tensor must have same shape as source" - assert H % d == 0, "H must be divisible by d" - - block_size = B * H * C - token_size = H * C - tensor_size = N * block_size - tensor_subset_size = tensor_size // d - - BLOCK_SIZE = 1024 - grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE, ) - - if direction == "read": - rearrange_kernel_read[grid](t1, - t2, - N, - B, - H, - C, - d, - tensor_subset_size, - block_size, - token_size, - BLOCK_SIZE=BLOCK_SIZE) - elif direction == "write": - rearrange_kernel_write[grid](t1, - t2, - N, - B, - H, - C, - d, - tensor_subset_size, - block_size, - token_size, - BLOCK_SIZE=BLOCK_SIZE) - else: - raise ValueError(f"Invalid direction: {direction}") From f9a3f3a95da4f44d40bd51186fb6281a400eac4b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:16:25 +0000 Subject: [PATCH 107/119] More cleanup Signed-off-by: Tyler Michael Smith --- vllm/config.py | 2 +- vllm/platforms/cpu.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 3 --- 3 files changed, 2 
insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7fde4880e473..94e8f9990348 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1638,7 +1638,7 @@ def data_parallel_rank_local(self, value: int) -> None: """Use expert parallelism instead of tensor parallelism for MoE layers.""" max_parallel_loading_workers: Optional[int] = None - """Maximum number of parallal loading workers when loading model + """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor parallel and large models.""" diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 47a48126ed5c..70553354a060 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -43,8 +43,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, logger.info("Using CPU MLA backend.") return "vllm.attention.backends.cpu_mla.CPUMLABackend" logger.info("Using Torch SDPA backend.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - # return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" + return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8c3857bb6029..403a288f9e90 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -401,9 +401,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the states of the running/resumed requests. for req_data in scheduler_output.scheduled_cached_reqs: req_id = req_data.req_id - if req_id not in self.requests: - print(f"{req_id} {self.requests}") - continue req_state = self.requests[req_id] # Update the cached states. From aec447cc22e9779c15d662ec788a02ae57fd6780 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 17:26:45 +0000 Subject: [PATCH 108/119] more cleanup Signed-off-by: Tyler Michael Smith --- tests/v1/kv_connector/run_accuracy_test.sh | 10 ---------- vllm/config.py | 8 ++++---- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/run_accuracy_test.sh index 0aab60e4adca..3ff1bbe1c162 100644 --- a/tests/v1/kv_connector/run_accuracy_test.sh +++ b/tests/v1/kv_connector/run_accuracy_test.sh @@ -8,16 +8,6 @@ MODEL_NAME=Qwen/Qwen3-0.6B # Trap the SIGINT signal (triggered by Ctrl+C) trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT -# Cleanup function -cleanup() { - echo "Caught Ctrl+C, cleaning up..." - # Cleanup commands - pgrep python | xargs kill -9 - pkill -f python - echo "Cleanup complete. Exiting." - exit 0 -} - # Waits for vLLM to start. wait_for_server() { local port=$1 diff --git a/vllm/config.py b/vllm/config.py index 94e8f9990348..cd56b842c8be 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1900,10 +1900,6 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - delay_factor: float = 0.0 - """Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt.""" - cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) """Cuda graph capture sizes, default is 512. 1. if one value is provided, then the capture list would follow the pattern: @@ -1911,6 +1907,10 @@ class SchedulerConfig: 2. more than one value (e.g. 
1 2 128) is provided, then the capture list will follow the provided list.""" + delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" + enable_chunked_prefill: bool = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" From ac68a75a1ba87fa872a173a7113364dbe0a5b078 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 19:33:36 +0000 Subject: [PATCH 109/119] more cleanup, precommit fixes Signed-off-by: Tyler Michael Smith --- tests/v1/core/test_scheduler.py | 138 +++++++++++++++++- .../kv_transfer/kv_connector/v1/base.py | 9 +- .../kv_connector/v1/lmcache_connector.py | 1 + .../v1/shared_storage_connector.py | 1 + 4 files changed, 140 insertions(+), 9 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 87673ad3ddc2..ee4e95856f23 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,13 +3,145 @@ from unittest.mock import Mock import pytest +import torch -from vllm.multimodal.inputs import PlaceholderRange -from vllm.tests.v1.utils import EOS_TOKEN_ID, create_requests, create_scheduler +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import RequestStatus +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, +) -> Scheduler: + '''Create scheduler under test. 
+ + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + speculative_config: Optional[SpeculativeConfig] = None + if num_speculative_tokens is not None: + speculative_config = SpeculativeConfig( + model="ngram", num_speculative_tokens=num_speculative_tokens) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests def test_add_requests(): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95d3dfb7c841..dfc49f5adde7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -61,11 +61,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connector_metadata = KVConnectorMetadata() self._vllm_config = 
vllm_config self._role = role - - def register_kv_caches( - self, - kv_caches: dict[str, torch.Tensor] - ): + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). @@ -73,7 +70,7 @@ def register_kv_caches( Args: kv_caches: dictionary of layer names, kv cache """ - pass + return @property def role(self) -> KVConnectorRole: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..8f86b72f9cff 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -110,6 +110,7 @@ def get_num_new_matched_tokens( request, num_computed_tokens) def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..94bf53d90c91 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -260,6 +260,7 @@ def get_num_new_matched_tokens( return num_tokens_to_check - num_computed_tokens def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): """ Update KVConnector state after block allocation. From e72245b423f312bb01709951e4bc7a83766a2768 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 19:43:20 +0000 Subject: [PATCH 110/119] More cleanup Signed-off-by: Tyler Michael Smith --- vllm/entrypoints/openai/serving_completion.py | 11 ++++++----- vllm/v1/core/kv_cache_manager.py | 4 ++-- vllm/v1/core/sched/scheduler.py | 9 ++------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index f011581b8909..b8d463357387 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,6 @@ import jinja2 from fastapi import Request -from torch._C import NoneType from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -478,11 +477,13 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage if final_res_batch[0].kv_transfer_params is not None: - remote_engine_id=final_res_batch[0].kv_transfer_params.remote_engine_id - remote_block_ids=final_res_batch[0].kv_transfer_params.remote_block_ids + remote_engine_id = final_res_batch[ + 0].kv_transfer_params.remote_engine_id + remote_block_ids = final_res_batch[ + 0].kv_transfer_params.remote_block_ids else: - remote_engine_id=None - remote_block_ids=None + remote_engine_id = None + remote_block_ids = None assert len(final_res_batch) == 1 return CompletionResponse( diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e78bf2527270..56b0b0796408 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -179,8 +179,8 @@ def allocate_slots( num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. - skip_cache_blocks: Whether to skip cachings the blocks. 
This is - used by P/D when allocating blocks that used in KV transfer + skip_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer which will complete in a future step. Blocks layout: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ae098352f1e5..c31b74e8eb07 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -721,11 +721,6 @@ def update_from_output( new_running.append(request) continue - if req_id not in model_runner_output.req_id_to_index: - print(req_id) - print(model_runner_output.req_id_to_index) - continue - req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] @@ -855,10 +850,10 @@ def update_from_output( # P/D: update recv and send status from last step. for req_id in (model_runner_output.finished_recving or []): - logger.debug("FINISHED RECVING: %s", req_id) + logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_KV_req_ids.add(req_id) for req_id in (model_runner_output.finished_sending or []): - logger.debug("FINISHED SENDING: %s", req_id) + logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) # Return the cached request data to the queue so they can From cd2aa72257380705563dbb0685cb2e92e5ef8a3e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 19:55:57 +0000 Subject: [PATCH 111/119] run_accuracy_test.sh UX Signed-off-by: Tyler Michael Smith --- tests/v1/kv_connector/run_accuracy_test.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/run_accuracy_test.sh index 3ff1bbe1c162..9679a070525f 100644 --- a/tests/v1/kv_connector/run_accuracy_test.sh +++ b/tests/v1/kv_connector/run_accuracy_test.sh @@ -5,6 +5,9 @@ set -xe # Model to run. MODEL_NAME=Qwen/Qwen3-0.6B +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + # Trap the SIGINT signal (triggered by Ctrl+C) trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT @@ -36,7 +39,7 @@ wait_for_server 8100 wait_for_server 8200 # Proxy server. -python toy_proxy_server.py --port 8192 & +python ${GIT_ROOT}/tests/v1/kv_connector/toy_proxy_server.py --port 8192 & # Run lm eval. 
-python3 -m pytest -s -x test_accuracy.py +python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/test_accuracy.py From 10183d5630f7c5c6d26970d1b2349950a2a211c5 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 20:10:44 +0000 Subject: [PATCH 112/119] squash warnings Signed-off-by: Tyler Michael Smith --- vllm/entrypoints/openai/protocol.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9868db8d61a5..4fb82f38e477 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -872,11 +872,13 @@ class CompletionRequest(OpenAIBaseModel): default=False, description="KVTransfer parameters used for disaggregated serving.") - remote_engine_id: Optional[str] = Field(default=None, - description="Remote engine id.") + remote_engine_id: Optional[str] = Field( + default=None, + description="Remote engine id for prefill-decode disaggregation.") remote_block_ids: Optional[list[int]] = Field( - default=None, description="Remote block ids.") + default=None, + description="Remote block ids for prefill-decode disaggregation.") # doc: end-completion-extra-params @@ -1262,6 +1264,12 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + remote_engine_id: Optional[str] = Field( + default=None, + description="Remote engine id for prefill-decode disaggregation.") + remote_block_ids: Optional[list[int]] = Field( + default=None, + description="Remote block ids for prefill-decode disaggregation.") class CompletionResponseStreamChoice(OpenAIBaseModel): From 9eb978713a833c90d4ea293d868578421c92c950 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 20:14:10 +0000 Subject: [PATCH 113/119] pre-commit Signed-off-by: Tyler Michael Smith --- tests/v1/kv_connector/test_nixl_connector.py | 15 ++- .../test_remote_decode_lifecycle.py | 18 ++-- .../test_remote_prefill_lifecycle.py | 98 ++++++++----------- 3 files changed, 57 insertions(+), 74 deletions(-) diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/test_nixl_connector.py index b746978907ea..684823408c94 100644 --- a/tests/v1/kv_connector/test_nixl_connector.py +++ b/tests/v1/kv_connector/test_nixl_connector.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import copy -from typing import Optional from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorMetadata) -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput -from vllm.v1.request import RequestStatus, Request from .utils import create_request, create_scheduler, create_vllm_config + def test_scheduler_worker_inferface(): vllm_config = create_vllm_config() @@ -18,7 +15,7 @@ def test_scheduler_worker_inferface(): BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -31,12 +28,12 @@ def test_scheduler_worker_inferface(): kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - + assert len(kv_connector_metadata.requests) == 1 assert request_id in kv_connector_metadata.requests req_meta = kv_connector_metadata.requests[request_id] - + for block_id, block in zip( - req_meta.local_block_ids, - 
scheduler.kv_cache_manager.req_to_blocks[request_id]): + req_meta.local_block_ids, + scheduler.kv_cache_manager.req_to_blocks[request_id]): assert block_id == block.block_id diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/test_remote_decode_lifecycle.py index c3c01d127754..bfe97efeb3ee 100644 --- a/tests/v1/kv_connector/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_decode_lifecycle.py @@ -2,23 +2,23 @@ import copy from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import RequestStatus, FinishReason +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) -from .utils import (create_request, create_scheduler, - create_vllm_config, create_model_runner_output, - assert_scheduler_empty) def test_basic_lifecycle(): """Test lifecycle of a Remote Decode request.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - + # 2 Full Blocks and 1 Half Block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -36,9 +36,9 @@ def test_basic_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output( - scheduler_output, model_runner_output) - + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + # Ensure the request is finished after 1 tokens. assert request.is_finished() assert request.status == RequestStatus.FINISHED_REMOTE_DECODE diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py index 73315e6121ae..ba614c44e105 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -2,25 +2,25 @@ import copy from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -from vllm.v1.request import RequestStatus, FinishReason +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) -from .utils import (create_request, create_scheduler, - create_vllm_config, create_model_runner_output, - assert_scheduler_empty) def test_basic_lifecycle(): """Test Remote Prefills Lifecycle.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - + # 2 Full Blocks and 1 Half Block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) START_FREE_BLOCK_QUEUE_SIZE = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -44,7 +44,7 @@ def test_basic_lifecycle(): assert request in scheduler.waiting assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) assert (request.num_computed_tokens == 0) - + # ... but should have (uncached) blocks allocated to it. 
block_pool = scheduler.kv_cache_manager.block_pool assert (block_pool.free_block_queue.num_free_blocks @@ -57,8 +57,8 @@ def test_basic_lifecycle(): model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output( - scheduler_output, model_runner_output) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) assert len(engine_core_outputs.outputs) == 0 # STEP (2): @@ -68,13 +68,12 @@ def test_basic_lifecycle(): assert len(scheduler.running) == 0 # (2b): forward(): request finishes recv. - model_runner_output = copy.deepcopy( - EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.finished_recving = [request_id] - # (2c): update_from_output(): - engine_core_outputs = scheduler.update_from_output( - scheduler_output, model_runner_output) + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) assert len(scheduler.waiting) == 1 assert (request_id in scheduler.finished_recving_KV_req_ids) @@ -82,7 +81,7 @@ def test_basic_lifecycle(): # (3a): schedule(): this should actually schedule. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 - + # Confirm the block are actually allocated. num_hashed_blocks = 0 for block in scheduler.kv_cache_manager.req_to_blocks[request_id]: @@ -104,10 +103,9 @@ def test_basic_lifecycle(): # Step (4): Hit EOS. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output( - [request], use_eos=True) - engine_core_outputs = scheduler.update_from_output( - scheduler_output, model_runner_output) + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) scheduler.schedule() outputs = engine_core_outputs.outputs @@ -122,17 +120,15 @@ def test_interleaved_lifecycle(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - + # 2 Full Blocks and 1 Half Block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - - request_remote = create_request( - request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True - ) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) request_local_a = create_request( request_id=2, num_tokens=NUM_TOKENS, @@ -147,11 +143,9 @@ def test_interleaved_lifecycle(): scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 - model_runner_output = create_model_runner_output( - [request_local_a]) - scheduler.update_from_output(scheduler_output, - model_runner_output) - + model_runner_output = create_model_runner_output([request_local_a]) + scheduler.update_from_output(scheduler_output, model_runner_output) + # STEP 2: Add a local and remote request. scheduler.add_request(request_local_b) scheduler.add_request(request_remote) @@ -163,8 +157,7 @@ def test_interleaved_lifecycle(): model_runner_output = create_model_runner_output( [request_local_a, request_local_b]) - scheduler.update_from_output(scheduler_output, - model_runner_output) + scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 3: continue running, KVs not arrived yet. 
scheduler_output = scheduler.schedule() @@ -175,8 +168,7 @@ def test_interleaved_lifecycle(): model_runner_output = create_model_runner_output( reqs=[request_local_a, request_local_b]) - scheduler.update_from_output(scheduler_output, - model_runner_output) + scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 @@ -191,10 +183,8 @@ def test_interleaved_lifecycle(): model_runner_output = create_model_runner_output( [request_local_a, request_local_b], - finished_recving=[request_remote.request_id] - ) - scheduler.update_from_output(scheduler_output, - model_runner_output) + finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 5: RECVed KVs are sent to ModelRunner. scheduler_output = scheduler.schedule() @@ -204,8 +194,7 @@ def test_interleaved_lifecycle(): assert len(scheduler_output.scheduled_cached_reqs) == 2 model_runner_output = create_model_runner_output( - [request_local_a, request_local_b, request_remote] - ) + [request_local_a, request_local_b, request_remote]) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 6: Hit EOS and free. @@ -214,8 +203,7 @@ def test_interleaved_lifecycle(): [request_local_a, request_local_b, request_remote], use_eos=True, ) - scheduler.update_from_output( - scheduler_output, model_runner_output) + scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.schedule() assert_scheduler_empty(scheduler) @@ -233,12 +221,12 @@ def test_no_spurious_prefix_caching(): vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) - + # 2 and a half full external blocks. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - + # Both of these requests have prompts like [1,1,1,1,1, ...] request_remote = create_request( request_id=1, @@ -258,30 +246,28 @@ def test_no_spurious_prefix_caching(): # cause any blocks to be cached. scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() - scheduler.update_from_output( - scheduler_output, - EMPTY_MODEL_RUNNER_OUTPUT - ) + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) assert len(scheduler.waiting) == 1 # Schedule the local prefill request. This should - # cause blocks to be cached, but separately from + # cause blocks to be cached, but separately from scheduler.add_request(request_local) scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - local_blocks = scheduler.kv_cache_manager.req_to_blocks[request_local.request_id] - remote_blocks = scheduler.kv_cache_manager.req_to_blocks[request_remote.request_id] + local_blocks = scheduler.kv_cache_manager.req_to_blocks[ + request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.req_to_blocks[ + request_remote.request_id] # Local should have cached blocks (but not all due to preallocate). num_hashed_blocks = 0 for block in local_blocks: assert block.ref_cnt == 1 - num_hashed_blocks += ( - 1 if block._block_hash is not None else 0) + num_hashed_blocks += (1 if block._block_hash is not None else 0) assert num_hashed_blocks > 0 - + # Remote blocks should not be cached. 
for block in remote_blocks: assert block.ref_cnt == 1 @@ -292,4 +278,4 @@ def test_remote_prefill_no_blocks_available(): """ letTest whether we properly handle no blocks available """ - pass \ No newline at end of file + pass From aeef78bda950822fa04baa963f1f4a0fa733dd8c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 20:27:37 +0000 Subject: [PATCH 114/119] update Signed-off-by: Tyler Michael Smith --- requirements/test.txt | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index a27d89a2acc7..e42ad984d796 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,10 +27,6 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration -async-timeout==5.0.1 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -130,11 +126,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # hypothesis - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -691,13 +682,8 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers -toml==0.10.2 - # via datamodel-code-generator tomli==2.2.1 - # via - # black - # pytest - # schemathesis + # via schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0 @@ -769,16 +755,12 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via - # anyio - # black # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 From 5306d5b543e0b3760b347599c2270eab33c61268 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 20:37:26 +0000 Subject: [PATCH 115/119] Add get_finished to base kv connector Signed-off-by: mgoin --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index dfc49f5adde7..ca9e19156719 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -72,6 +72,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return + def get_finished(self) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + return set(), set() + @property def role(self) -> KVConnectorRole: return self._role From a16f2bec9b1b2b33c7211eeb60756b7f14be57c4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 20:39:17 +0000 Subject: [PATCH 116/119] revert test.txt Signed-off-by: Tyler Michael Smith --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index e42ad984d796..d4c92f15025f 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -623,6 +623,7 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter + # torch # triton shellingham==1.5.4 # via typer From 445b010631652b539d32127a35caaf276ce00b88 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 20:47:51 +0000 Subject: [PATCH 117/119] cleanup Signed-off-by: Tyler Michael Smith --- vllm/v1/request.py | 3 --- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index c855e389e6d3..837359a4b6a4 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -68,9 +68,6 @@ def __init__( self.do_remote_prefill = ( False if 
sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_prefill) - - #TODO: need to get the remote_engine_id and - # remote block_ids self.kv_transfer_params = sampling_params.kv_transfer_params # Sanity check diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 403a288f9e90..462aa0b8d4fd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1054,9 +1054,12 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: maybe_setup_kv_connector() maybe_wait_for_save() finished_sending, finished_recving = maybe_get_finished() + # Return empty ModelRunnerOutput if there's no work to do. - output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + output = EMPTY_MODEL_RUNNER_OUTPUT + if len(finished_sending) > 0 or len(finished_recving) > 0: + output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) output.finished_sending = finished_sending output.finished_recving = finished_recving return output From 14a4c64d3f4204d9458de8427c0a2b887d68c3e3 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 2 May 2025 21:15:03 +0000 Subject: [PATCH 118/119] Cleanup Signed-off-by: Tyler Michael Smith --- .../kv_connector/v1/nixl_connector.py | 172 +++++++++--------- 1 file changed, 88 insertions(+), 84 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a9d4606c4f60..3414cb67379d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import math +import os import time import uuid from collections import defaultdict @@ -15,6 +16,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.logger import init_logger from vllm.sampling_params import KVTransferParams +from vllm.utils import round_down from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: @@ -119,7 +121,6 @@ def build_connector_meta( ############################################################ # Worker Side Methods ############################################################ - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) @@ -172,9 +173,8 @@ def get_num_new_matched_tokens(self, request: "Request", assert num_computed_tokens % self.block_size == 0 if request.do_remote_prefill: - num_external_blocks = len( - request.prompt_token_ids) // self.block_size - rounded_num_prompt_tokens = num_external_blocks * self.block_size + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) return max(rounded_num_prompt_tokens - num_computed_tokens, 0) return 0 @@ -298,29 +298,31 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._registered_descs.append(descs) - # THIS IS FOR DEBUG and INSECURE - import os + # THIS IS FOR DEV _ctx = zmq.Context() # type: ignore _side_channel = _ctx.socket(zmq.PAIR) # type: ignore NIXL_ROLE = os.getenv("NIXL_ROLE") - # For debug, SENDER puts some stuff in the KV caches - # so the RECVER can check it - n_blocks_to_send = min(4096, kv_caches[first_layer_name].shape[1]) - debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9 - print(f"gb {debug_xfer_gb} -- block_len {self.block_len}") - if NIXL_ROLE == "SENDER": - for b in range(n_blocks_to_send): - kv_caches[first_layer_name][0, 
b, 0, 0, 0] = b + 100.0 - kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 200.0 - for b in range(5): - print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" #noqa - ) - print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" #noqa - ) - remote_engine_id = None # HACK for debug send + # FOR DEV, SENDER puts data in its KV caches so the RECVER can check it + if os.environ.get("VLLM_DEBUG_INITIAL_NIXL_PD_XFER") is not None: + n_blocks_to_send = min(4096, kv_caches[first_layer_name].shape[1]) + debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9 + logger.debug( + "Starting initial NIXL PD XFER: Total %s GB, Block len %s KB", + debug_xfer_gb, self.block_len / 1024) + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 100.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 200.0 + + for b in range(5): + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (0, b, 0, 0, 0), + kv_caches[first_layer_name][0, b, 0, 0, 0]) + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (1, b, 0, 0, 0), + kv_caches[first_layer_name][1, b, 0, 0, 0]) + remote_engine_id = None # HACK for debug send if NIXL_ROLE == "SENDER": _side_channel.connect("tcp://localhost:5577") @@ -352,74 +354,76 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): remote_engine_id = metadata.engine_id #HACK self.add_remote_agent(metadata) - print("SENDING ACK") + logger.debug("SENDING ACK") _side_channel.send(b"ack") else: raise Exception("SET NIXL_ROLE to SENDER OR RECVER") - # FOR DEBUG: try to send some shit - - if NIXL_ROLE == "RECVER": - logger.debug("Sending blocks") - connector_metadata = NixlConnectorMetadata() - assert remote_engine_id is not None - xfer_params = KVTransferParams( - do_remote_decode=True, - do_remote_prefill=False, - remote_block_ids=list(range(n_blocks_to_send)), - remote_engine_id=remote_engine_id #HACK - ) + # FOR DEV, SENDER puts data in its KV caches so the RECVER can check it + + if os.environ.get("VLLM_DEBUG_INITIAL_NIXL_PD_XFER") is not None: + initial_xfer_req_id = "initial_xfer_req_id" + + if NIXL_ROLE == "RECVER": + logger.debug("SENDING BLOCKS") + connector_metadata = NixlConnectorMetadata() + assert remote_engine_id is not None + xfer_params = KVTransferParams( + do_remote_decode=True, + do_remote_prefill=False, + remote_block_ids=list(range(n_blocks_to_send)), + remote_engine_id=remote_engine_id #HACK + ) + + connector_metadata.add_new_req(request_id=initial_xfer_req_id, + local_block_ids=list( + range(n_blocks_to_send)), + kv_transfer_params=xfer_params) + self.start_load_kv(connector_metadata) + + # Wait for Receive to complete + logger.debug("START RECEIVE XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + done = initial_xfer_req_id in finished[1] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Received. 
Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) - connector_metadata.add_new_req(request_id="tms", - local_block_ids=list( - range(n_blocks_to_send)), - kv_transfer_params=xfer_params) - self.start_load_kv(connector_metadata) - - # Wait for Receive to complete - logger.debug("TMS START RECEIVE XFER") - done = False - start_time = time.time() - while (not done): - finished = self.get_finished() - # NOTE: Should fix discrepancy between bytes/str finished sets - # Here we have str. For sender we have bytes. - done = "tms" in finished[1] - time.sleep(1e-5) - end_time = time.time() - execution_time = end_time - start_time - logger.debug( - "Transfer Received. Duration: %f ms Bandwidth %f GB/s", - 1e3 * execution_time, debug_xfer_gb / execution_time) - - if NIXL_ROLE == "SENDER": - # Wait for Send to complete - logger.debug("TMS START SEND XFER") - done = False - start_time = time.time() - while (not done): - finished = self.get_finished() - done = "tms" in finished[0] - time.sleep(1e-5) - end_time = time.time() - execution_time = end_time - start_time - logger.debug("Transfer Sent. Duration: %f ms Bandwidth %f GB/s", - 1e3 * execution_time, debug_xfer_gb / execution_time) - - # Put some different stuff in there if NIXL_ROLE == "SENDER": - for b in range(n_blocks_to_send): - kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 300.0 - kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 400.0 - - for b in range(5): - print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][0, b, 0, 0, 0]}" #noqa - ) - print( - f"{NIXL_ROLE} KV_CACHE block b val {kv_caches[first_layer_name][1, b, 0, 0, 0]}" #noqa - ) + # Wait for Send to complete + logger.debug("START SEND XFER") + done = False + start_time = time.time() + while (not done): + finished = self.get_finished() + done = initial_xfer_req_id in finished[0] + time.sleep(1e-5) + end_time = time.time() + execution_time = end_time - start_time + logger.debug( + "Transfer Sent. 
Duration: %f ms Bandwidth %f GB/s", + 1e3 * execution_time, debug_xfer_gb / execution_time) + + # Put some different stuff in there + if NIXL_ROLE == "SENDER": + for b in range(n_blocks_to_send): + kv_caches[first_layer_name][0, b, 0, 0, 0] = b + 300.0 + kv_caches[first_layer_name][1, b, 0, 0, 0] = b + 400.0 + + for b in range(5): + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (0, b, 0, 0, 0), + kv_caches[first_layer_name][0, b, 0, 0, 0]) + logger.debug("%s KV_CACHE coord %s val %f", NIXL_ROLE, + (1, b, 0, 0, 0), + kv_caches[first_layer_name][1, b, 0, 0, 0]) def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): engine_id = nixl_agent_meta.engine_id From fc7d8ad6f25dab59ddc8fb5f3057f6e795e58096 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 3 May 2025 14:47:25 +0000 Subject: [PATCH 119/119] review comments Signed-off-by: Tyler Michael Smith --- .../test_remote_prefill_lifecycle.py | 7 ----- tests/v1/kv_connector/utils.py | 4 --- .../kv_connector/v1/nixl_connector.py | 27 ++++++++----------- vllm/entrypoints/openai/serving_completion.py | 9 +++---- vllm/sampling_params.py | 11 ++++---- vllm/v1/core/kv_cache_manager.py | 5 ++-- vllm/v1/core/sched/scheduler.py | 26 +++++++++++------- 7 files changed, 38 insertions(+), 51 deletions(-) diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py index ba614c44e105..91fcbf53fb3d 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/test_remote_prefill_lifecycle.py @@ -272,10 +272,3 @@ def test_no_spurious_prefix_caching(): for block in remote_blocks: assert block.ref_cnt == 1 assert block._block_hash is None - - -def test_remote_prefill_no_blocks_available(): - """ - letTest whether we properly handle no blocks available - """ - pass diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/utils.py index 409b5ac69e97..0387cd58ab0f 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/utils.py @@ -43,10 +43,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # value, etc will remain since we lazily evict for prefix cache. 
for block in scheduler.kv_cache_manager.block_pool.blocks: assert block.ref_cnt == 0 - # assert block._block_hash is None - # assert ( - # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block - # ) == 0) def create_vllm_config( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3414cb67379d..6a75f1522461 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -182,8 +182,6 @@ def get_num_new_matched_tokens(self, request: "Request", def update_state_after_alloc(self, request: "Request", block_ids: list[int], num_external_tokens: int): - if request.do_remote_decode: - pass if request.do_remote_prefill and num_external_tokens > 0: self._reqs_need_recv[request.request_id] = (request, block_ids) @@ -333,12 +331,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, ) - encoder = msgspec.msgpack.Encoder() - encoded_data = encoder.encode(metadata) + encoded_data = msgspec.msgpack.encode(metadata) size_in_bytes = len(encoded_data) logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes)) - _side_channel.send(encoder.encode(metadata)) + _side_channel.send(encoded_data) logger.debug("WAITING ON RECV") ack = _side_channel.recv() @@ -433,10 +430,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0): num_blocks = nixl_agent_meta.num_blocks logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks)) - agent_names = [] - agent_name = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) - agent_names.append(agent_name) + agent_names = [ + self.nixl_wrapper.add_remote_agent(nixl_agent_meta.agent_metadata) + ] self._remote_agents[engine_id] = agent_names self.kv_caches_base_addr[ @@ -578,8 +574,7 @@ def _read_blocks( # Note(tms): The remote_block_ids only contain full computed blocks, # while the local_block_ids are all blocks allocated for this request, # so truncate the local_block_ids to account for this. - if len(remote_block_ids) < len(local_block_ids): - local_block_ids = local_block_ids[:len(remote_block_ids)] + del local_block_ids[len(remote_block_ids):] assert len(local_block_ids) == len(remote_block_ids) # NOTE(rob): this can cause the remote blocks to not be freed? 
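For reference, a minimal standalone sketch (toy values, illustrative names only; not part of the patch) of the in-place truncation idiom adopted in _read_blocks above — deleting a slice that starts at or past the end of the list is a no-op, so the earlier length check is not needed:

    # Illustrative only: local_block_ids stands for every block allocated for
    # the request, remote_block_ids for the fully computed remote blocks.
    local_block_ids = [0, 1, 2, 3, 4, 5]
    remote_block_ids = [10, 11, 12, 13]

    # In-place truncation; a no-op when both lists already have equal length.
    del local_block_ids[len(remote_block_ids):]

    assert len(local_block_ids) == len(remote_block_ids)  # [0, 1, 2, 3]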
@@ -639,9 +634,9 @@ def _get_block_descs_ids(self, if i is not None: raise NotImplementedError("Prefill and Decode instances must have " "the same TP size.") - else: - num_blocks = self.dst_num_blocks[engine_id] - for reg_id in region_ids: - for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) + + num_blocks = self.dst_num_blocks[engine_id] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) return descs_ids diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b8d463357387..42180b81119f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -476,11 +476,10 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage - if final_res_batch[0].kv_transfer_params is not None: - remote_engine_id = final_res_batch[ - 0].kv_transfer_params.remote_engine_id - remote_block_ids = final_res_batch[ - 0].kv_transfer_params.remote_block_ids + kv_transfer_params = final_res_batch[0].kv_transfer_params + if kv_transfer_params is not None: + remote_engine_id = kv_transfer_params.remote_engine_id + remote_block_ids = kv_transfer_params.remote_block_ids else: remote_engine_id = None remote_block_ids = None diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 30a7184b59cd..6f10ba3d5fd3 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -49,15 +49,14 @@ def from_optional( if do_remote_decode and do_remote_prefill: raise ValueError( "Cannot do both remote prefill and remote decode.") - elif do_remote_decode or do_remote_prefill: + if do_remote_decode or do_remote_prefill: return KVTransferParams( do_remote_decode=do_remote_decode, do_remote_prefill=do_remote_prefill, remote_engine_id=remote_engine_id, remote_block_ids=remote_block_ids, ) - else: - return None + return None # maybe make msgspec? @@ -219,9 +218,9 @@ class SamplingParams( logits_processors: list of functions that modify logits based on previously generated tokens, and optionally prompt tokens as a first argument. - truncate_prompt_tokens: If set to -1, will use the truncation size - supported by the model. If set to an integer k, will use only - the last k tokens from the prompt (i.e., left truncation). + truncate_prompt_tokens: If set to -1, will use the truncation size + supported by the model. If set to an integer k, will use only + the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). guided_decoding: If provided, the engine will construct a guided decoding logits processor from these parameters. Defaults to None. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 56b0b0796408..12c55be00375 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -399,9 +399,8 @@ def get_num_common_prefix_blocks( Returns: int: The number of common prefix blocks. 
""" - assert request.status in [ - RequestStatus.RUNNING, RequestStatus.FINISHED_REMOTE_DECODE - ] + assert request.status in (RequestStatus.RUNNING, + RequestStatus.FINISHED_REMOTE_DECODE) blocks = self.req_to_blocks[request.request_id] num_common_blocks = 0 for block in blocks: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c31b74e8eb07..9f5e476bb1eb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools import time from collections import defaultdict, deque from collections.abc import Iterable @@ -98,7 +99,7 @@ def __init__( self.finished_req_ids: set[str] = set() # Requests in states for tracking KV transfers for P/D disagg - self.finished_recving_KV_req_ids: set[str] = set() + self.finished_recving_kv_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. @@ -315,7 +316,7 @@ def schedule(self) -> SchedulerOutput: # Skip request if the remote KV recv is still waiting # for the requests to arrive. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - if request.request_id in self.finished_recving_KV_req_ids: + if request.request_id in self.finished_recving_kv_req_ids: assert self.kv_cache_manager.enable_caching # Now that the KVs have been recved, we can cache # them and set num_computed_tokens. @@ -324,7 +325,7 @@ def schedule(self) -> SchedulerOutput: num_tokens=0, num_computed_tokens=(len(request.all_token_ids) - 1)) - self.finished_recving_KV_req_ids.remove( + self.finished_recving_kv_req_ids.remove( request.request_id) request.status = RequestStatus.WAITING self.kv_cache_manager.free(request) @@ -369,7 +370,7 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens += num_external_tokens - if (request.do_remote_prefill and num_external_tokens > 0): + if request.do_remote_prefill and num_external_tokens > 0: # Allocate slots for the external tokens, but skip # caching until after the KV transfer is done. new_blocks = self.kv_cache_manager.allocate_slots( @@ -391,7 +392,10 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, - [b.block_id for b in computed_blocks + new_blocks], + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], num_external_tokens, ) # We should only trigger a KV transfer once per request. @@ -439,7 +443,10 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, - [b.block_id for b in computed_blocks + new_blocks], + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], num_external_tokens, ) @@ -573,9 +580,8 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - if req_id in self.requests: - self.requests[ - req_id].num_computed_tokens += num_scheduled_token + if req := self.requests.get(req_id): + req.num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output @@ -851,7 +857,7 @@ def update_from_output( # P/D: update recv and send status from last step. 
for req_id in (model_runner_output.finished_recving or []): logger.debug("Finished recving KV transfer for request %s", req_id) - self.finished_recving_KV_req_ids.add(req_id) + self.finished_recving_kv_req_ids.add(req_id) for req_id in (model_runner_output.finished_sending or []): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id])
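For reference, a minimal standalone sketch (toy data, illustrative names only; not part of the patch) of two idioms adopted in the scheduler changes above: itertools.chain to iterate both block lists without concatenating them, and the walrus operator to fold the membership test and lookup on the request map into a single dict access:

    import itertools

    # Toy stand-ins for scheduler state; the real code uses KVCacheBlock and
    # Request objects rather than ints and dicts.
    computed_blocks = [1, 2]
    new_blocks = [3, 4]
    requests = {"req-0": {"num_computed_tokens": 0}}
    num_scheduled_tokens = {"req-0": 8, "req-finished": 4}

    # chain() avoids building a temporary computed_blocks + new_blocks list.
    block_ids = list(itertools.chain(computed_blocks, new_blocks))
    assert block_ids == [1, 2, 3, 4]

    # One dict lookup instead of an `in` check followed by indexing; ids that
    # are no longer tracked are simply skipped.
    for req_id, num_scheduled in num_scheduled_tokens.items():
        if req := requests.get(req_id):
            req["num_computed_tokens"] += num_scheduled

    assert requests["req-0"]["num_computed_tokens"] == 8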