diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py
index 6ab7619bff19..f1e6175268de 100644
--- a/examples/online_serving/openai_completion_client.py
+++ b/examples/online_serving/openai_completion_client.py
@@ -4,7 +4,7 @@
 
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
-openai_api_base = "http://localhost:8000/v1"
+openai_api_base = "http://localhost:9000/v1"
 
 
 def main():
@@ -14,18 +14,17 @@ def main():
         base_url=openai_api_base,
     )
 
-    models = client.models.list()
-    model = models.data[0].id
+    # models = client.models.list()
+    # model = models.data[0].id
 
     # Completion API
     stream = False
     completion = client.completions.create(
-        model=model,
-        prompt="A robot may not injure a human being",
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        prompt=
+        "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means",  # noqa: E501
         echo=False,
-        n=2,
-        stream=stream,
-        logprobs=3)
+        stream=stream)
 
     print("-" * 50)
     print("Completion results:")
diff --git a/examples/online_serving/openai_completion_client2.py b/examples/online_serving/openai_completion_client2.py
new file mode 100644
index 000000000000..fb6d0d120b5e
--- /dev/null
+++ b/examples/online_serving/openai_completion_client2.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8100/v1"
+
+
+def main():
+    client = OpenAI(
+        # defaults to os.environ.get("OPENAI_API_KEY")
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    # models = client.models.list()
+    # model = models.data[0].id
+
+    # Completion API
+    stream = False
+    completion = client.completions.create(
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        prompt="The quick brown fox jumped",
+        echo=False,
+        stream=stream)
+
+    print("-" * 50)
+    print("Completion results:")
+    if stream:
+        for c in completion:
+            print(c)
+    else:
+        print(completion)
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh
index 43b0b59c88f8..89f1c753e887 100644
--- a/examples/other/LMCache/disagg-example.sh
+++ b/examples/other/LMCache/disagg-example.sh
@@ -25,10 +25,9 @@ if [[ $1 == "prefiller" ]]; then
     LMCACHE_USE_EXPERIMENTAL=True \
     VLLM_ENABLE_V1_MULTIPROCESSING=1 \
     VLLM_WORKER_MULTIPROC_METHOD=spawn \
-    CUDA_VISIBLE_DEVICES=0 \
+    CUDA_VISIBLE_DEVICES=6 \
     vllm serve $MODEL \
     --port 8100 \
-    --disable-log-requests \
     --enforce-eager \
     --kv-transfer-config \
     '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'
@@ -46,10 +45,9 @@ elif [[ $1 == "decoder" ]]; then
     LMCACHE_USE_EXPERIMENTAL=True \
     VLLM_ENABLE_V1_MULTIPROCESSING=1 \
     VLLM_WORKER_MULTIPROC_METHOD=spawn \
-    CUDA_VISIBLE_DEVICES=1 \
+    CUDA_VISIBLE_DEVICES=7 \
     vllm serve $MODEL \
     --port 8200 \
-    --disable-log-requests \
     --enforce-eager \
     --kv-transfer-config \
     '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'
diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py
index 8db93bc8931b..2639409a1522 100644
--- a/examples/other/LMCache/disagg_proxy_server.py
+++ b/examples/other/LMCache/disagg_proxy_server.py
@@ -88,13 +88,12 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
     Send a request to a service using a persistent client.
     """
     req_data = req_data.copy()
-    req_data['max_tokens'] = 1
-    if 'max_completion_tokens' in req_data:
-        req_data['max_completion_tokens'] = 1
-
+    req_data['do_remote_decode'] = True
+    req_data["stream"] = False
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
     response = await client.post(endpoint, json=req_data, headers=headers)
     response.raise_for_status()
+
+    return response
@@ -104,6 +103,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
     Asynchronously stream the response from a service using a persistent
     client.
     """
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+    req_data['do_remote_prefill'] = True
     async with client.stream("POST", endpoint, json=req_data,
                              headers=headers) as response:
         response.raise_for_status()
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 8d2ab29d221e..c56a15af1367 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -17,7 +17,8 @@
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
-                                  RequestOutputKind, SamplingParams)
+                                  KVTransferParams, RequestOutputKind,
+                                  SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid, resolve_obj_by_qualname
 
@@ -807,6 +808,14 @@ class CompletionRequest(OpenAIBaseModel):
         " as strings of the form 'token_id:{token_id}' so that tokens "
         "that are not JSON-encodable can be identified."))
 
+    do_remote_decode: bool = Field(
+        default=False,
+        description="Send the KV cache to a remote decode instance.")
+
+    do_remote_prefill: bool = Field(
+        default=False,
+        description="Receive the KV cache from a remote prefill instance.")
+
     # doc: end-completion-extra-params
 
     # Default sampling parameters for completion requests
@@ -904,6 +913,11 @@ def to_sampling_params(
                 whitespace_pattern=self.guided_whitespace_pattern,
             )
 
+        kv_transfer_params = KVTransferParams.from_optional(
+            do_remote_decode=self.do_remote_decode,
+            do_remote_prefill=self.do_remote_prefill,
+        )
+
         return SamplingParams.from_optional(
             n=self.n,
             best_of=self.best_of,
@@ -932,7 +946,9 @@ def to_sampling_params(
             else RequestOutputKind.FINAL_ONLY,
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
-            allowed_token_ids=self.allowed_token_ids)
+            allowed_token_ids=self.allowed_token_ids,
+            kv_transfer_params=kv_transfer_params,
+        )
 
     @model_validator(mode="before")
     @classmethod
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 014e8d5d8823..06206ea3e1ec 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -11,7 +11,7 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind
+from vllm.sampling_params import KVTransferParams, RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                            SequenceGroup, SequenceGroupBase, SequenceStatus)
 
@@ -103,6 +103,7 @@ class RequestOutput:
         encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                   None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
+        kv_transfer_params: The params for remote K/V transfer.
     """
 
     def __init__(
@@ -120,6 +121,7 @@ def __init__(
         num_cached_tokens: Optional[int] = None,
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
+        kv_transfer_params: Optional[KVTransferParams] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -133,11 +135,13 @@ def __init__(
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        self.kv_transfer_params = kv_transfer_params
 
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""
         self.finished |= next_output.finished
+        self.kv_transfer_params = next_output.kv_transfer_params
 
         for next_completion in next_output.outputs:
             for completion in self.outputs:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 68ed99664947..2504b0367b04 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -26,6 +26,33 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2
 
 
+# TODO(rob): make this per connector
+class KVTransferParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):
+    # TODO(rob): we can handle xPyD and direct KV block Xfer
+    # remote_instance_id: Optional[str] = None
+    # remote_block_ids: Optional[list[int]] = None
+    do_remote_decode: bool = False
+    do_remote_prefill: bool = False
+
+    @staticmethod
+    def from_optional(do_remote_decode: bool,
+                      do_remote_prefill: bool) -> Optional["KVTransferParams"]:
+        if do_remote_decode and do_remote_prefill:
+            raise ValueError(
+                "Cannot do both remote prefill and remote decode.")
+        elif do_remote_decode or do_remote_prefill:
+            return KVTransferParams(
+                do_remote_decode=do_remote_decode,
+                do_remote_prefill=do_remote_prefill,
+            )
+        else:
+            return None
+
+
 # maybe make msgspec?
 @dataclass
 class GuidedDecodingParams:
@@ -237,6 +264,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None
 
+    # Fields used for KVTransfer in disaggregated serving.
+    kv_transfer_params: Optional[KVTransferParams] = None
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -268,6 +298,7 @@ def from_optional(
         guided_decoding: Optional[GuidedDecodingParams] = None,
         logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
         allowed_token_ids: Optional[list[int]] = None,
+        kv_transfer_params: Optional[KVTransferParams] = None,
         extra_args: Optional[dict[str, Any]] = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
@@ -310,6 +341,7 @@ def from_optional(
             guided_decoding=guided_decoding,
             logit_bias=logit_bias,
             allowed_token_ids=allowed_token_ids,
+            kv_transfer_params=kv_transfer_params,
             extra_args=extra_args,
         )
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 6263743d9710..297a2d2a1355 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -128,4 +128,4 @@ class SchedulerOutput:
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
     sending_KV_req_ids: set[str] = field(default_factory=set)
     receiving_KV_req_ids: set[str] = field(default_factory=set)
-    new_KV_requests_to_send: list[NewRequestData] = field(default_factory=list)
+    new_KV_req_ids_to_send: list[str] = field(default_factory=list)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d3e562594aa1..000375e6a533 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -13,6 +13,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.sampling_params import KVTransferParams
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -173,18 +174,8 @@ def schedule(self) -> SchedulerOutput:
         scheduled_timestamp = time.monotonic()
 
         # Check for new remote decode requests for P/D
-        new_KV_requests_to_send: list[Request] = []
+        new_KV_req_ids_to_send: list[str] = []
         if self.connector is not None:
-            # TODO: Receive request over ZMQ
-            # self.receiving_KV_req_ids.update(
-            #     self.connector.receive_remote_decode_requests())
-
-            # Check if any P/D requests have finished sending or receiving
-            for req_id in list(self.waiting_to_send_KV_req_ids):
-                self.sending_KV_req_ids.add(req_id)
-                self.waiting_to_send_KV_req_ids.remove(req_id)
-                new_KV_requests_to_send.append(self.requests[req_id])
-
             for req_id in list(self.sending_KV_req_ids):
                 if self.connector.is_request_done_sending(req_id):
                     self.sending_KV_req_ids.remove(req_id)
@@ -193,6 +184,10 @@ def schedule(self) -> SchedulerOutput:
                 if self.connector.is_request_done_receiving(req_id):
                     self.receiving_KV_req_ids.remove(req_id)
                     self.waiting.append(self.requests[req_id])
+            for req_id in list(self.waiting_to_send_KV_req_ids):
+                self.sending_KV_req_ids.add(req_id)
+                self.waiting_to_send_KV_req_ids.remove(req_id)
+                new_KV_req_ids_to_send.append(req_id)
 
         # First, schedule the RUNNING requests.
         req_index = 0
@@ -328,6 +323,19 @@ def schedule(self) -> SchedulerOutput:
                     skipped_waiting_requests.appendleft(request)
                     continue
 
+                # TODO(rob): we should do this after we allocate the blocks if
+                # we want to write directly into the BlockTable (like Dynamo).
+                # TODO(rob): this logic is incorrect if the req was preempted.
+                if request.do_remote_decode:
+                    assert self.connector is not None
+                    if not self.connector.is_request_done_receiving(
+                            request.request_id):
+                        request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                        self.receiving_KV_req_ids.add(request.request_id)
+                        self.waiting.popleft()
+                        skipped_waiting_requests.appendleft(request)
+                        continue
+
                 # Check that adding the request still respects the max_loras
                 # constraint.
                 if self.lora_config and request.lora_request and (
@@ -506,18 +514,13 @@ def schedule(self) -> SchedulerOutput:
         # 2. Wrap up all the KV cache load / save ops into an opaque object
         # 3. Clear the internal states of the connector
         if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
-            scheduler_output.kv_connector_metadata = meta
-
             # TODO: encapsulate these in the KV connector metadata
             scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids
             scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids
-            new_KV_to_send_reqs_data = [
-                NewRequestData.from_request(
-                    req, req_to_new_block_ids[req.request_id])
-                for req in new_KV_requests_to_send
-            ]
-            scheduler_output.new_KV_requests_to_send = new_KV_to_send_reqs_data
+            scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send
+
+            meta = self.connector.build_connector_meta(scheduler_output)
+            scheduler_output.kv_connector_metadata = meta
 
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
@@ -719,7 +722,6 @@ def update_from_output(
 
             # Check for stop and update request state.
             # This must be called before we make the EngineCoreOutput.
-            # TODO: What if we detect we're done here when doing P/D disagg?
             stopped = check_stop(request, self.max_model_len)
             if stopped:
                 self._free_request(request)
@@ -742,6 +744,22 @@ def update_from_output(
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids:
+                # Stop request after the first token if doing a remote_decode.
+                # NOTE(rob): req is not freed (or preempted) in the EngineCore
+                # until the xfer is done to ensure we do not free the KV blocks.
+                kv_transfer_params = None
+                if request.do_remote_decode and not stopped:
+                    stopped = True
+                    request.status = RequestStatus.FINISHED_REMOTE_DECODE
+                    self.waiting_to_send_KV_req_ids.add(req_id)
+                    assert self.connector is not None
+                    # TODO(rob): do this on a per-Connector basis.
+                    # NOTE(rob): this KVTransferParams will be sent to the
+                    # DWorker. From the POV of the DWorker, it should be a
+                    # remote Prefill.
+                    kv_transfer_params = KVTransferParams(
+                        do_remote_prefill=True)
+
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
                     EngineCoreOutput(
@@ -751,18 +769,14 @@ def update_from_output(
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
-                        events=request.take_events()))
+                        events=request.take_events(),
+                        kv_transfer_params=kv_transfer_params,
+                    ))
+
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
-                if self.connector is not None and request.do_remote_decode:
-                    stopped = True
-                    self.waiting_to_send_KV_req_ids.add(req_id)
-                    # TODO: Add ZMQ request
-                    #self.connector.send_remote_decode_request(
-                    #    self.kv_cache_manager.req_to_blocks[req_id])
-
             self.scheduled_req_ids.remove(req_id)
 
             if not stopped:
                 new_running.append(request)
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index af4122a51077..ac6228edfc56 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -10,13 +10,13 @@
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.inputs import PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import KVTransferParams, SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
-FINISH_REASON_STRINGS = ("stop", "length", "abort")
+FINISH_REASON_STRINGS = ("stop", "length", "abort", "remote_decode")
 
 
 class FinishReason(enum.IntEnum):
@@ -28,11 +28,13 @@ class FinishReason(enum.IntEnum):
     stop - a stop string was emitted
     length - max_tokens was consumed, or max_model_len was reached
     abort - aborted for another reason
+    remote_decode - the request's decode will continue on a remote instance
     """
     STOP = 0
     LENGTH = 1
     ABORT = 2
+    REMOTE_DECODE = 3
 
     def __str__(self):
         return FINISH_REASON_STRINGS[self.value]
@@ -102,6 +104,7 @@ class EngineCoreOutput(
     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None
+    kv_transfer_params: Optional[KVTransferParams] = None
 
     @property
     def finished(self) -> bool:
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 21e2a1aee4e2..1de8e8994a86 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -6,7 +6,7 @@
 from typing import Optional, Union
 
 from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.sampling_params import RequestOutputKind
+from vllm.sampling_params import KVTransferParams, RequestOutputKind
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -148,6 +148,7 @@ def make_request_output(
         new_token_ids: list[int],
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        kv_transfer_params: Optional[KVTransferParams],
     ) -> Optional[RequestOutput]:
 
         finished = finish_reason is not None
@@ -169,13 +170,15 @@ def make_request_output(
         if not outputs:
             return None
 
-        return self._new_request_output(request_id, outputs, finished)
+        return self._new_request_output(request_id, outputs, finished,
+                                        kv_transfer_params)
 
     def _new_request_output(
         self,
         request_id: str,
         outputs: list[CompletionOutput],
         finished: bool,
+        kv_transfer_params: Optional[KVTransferParams],
     ) -> RequestOutput:
 
         if self.output_kind == RequestOutputKind.DELTA:
@@ -191,6 +194,7 @@ def _new_request_output(
             prompt_logprobs=prompt_logprobs,
             outputs=outputs,
             finished=finished,
+            kv_transfer_params=kv_transfer_params,
         )
 
     def _new_completion_output(
@@ -337,6 +341,7 @@ def process_outputs(
             new_token_ids = engine_core_output.new_token_ids
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
+            kv_transfer_params = engine_core_output.kv_transfer_params
 
             req_state.is_prefilling = False
@@ -352,7 +357,8 @@ def process_outputs(
 
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
-                    new_token_ids, finish_reason, stop_reason):
+                    new_token_ids, finish_reason, stop_reason,
+                    kv_transfer_params):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 60b4ee739fec..11722c4ccc9a 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -61,8 +61,10 @@ def __init__(
         self.num_encoder_inputs = len(self.mm_inputs)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
 
-        # P/D disagg related
-        self.do_remote_decode = False
+        # Disaggregated serving related
+        self.do_remote_decode = (
+            False if sampling_params.kv_transfer_params is None else
+            sampling_params.kv_transfer_params.do_remote_decode)
 
         # Sanity check
         assert len(self.mm_inputs) == len(self.mm_positions)
@@ -153,6 +155,7 @@ class RequestStatus(enum.IntEnum):
     """Status of a request."""
     WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
+    WAITING_FOR_REMOTE_KVS = enum.auto()
     RUNNING = enum.auto()
     SENDING_KV = enum.auto()
     PREEMPTED = enum.auto()
@@ -162,6 +165,7 @@ class RequestStatus(enum.IntEnum):
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
     FINISHED_IGNORED = enum.auto()
+    FINISHED_REMOTE_DECODE = enum.auto()
 
     @staticmethod
     def is_finished(status: "RequestStatus") -> bool:
@@ -182,4 +186,5 @@ def get_finished_reason(
     RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
     RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
     RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
+    RequestStatus.FINISHED_REMOTE_DECODE: FinishReason.REMOTE_DECODE,
 }
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c780bbe40934..e6c61ecb0254 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1016,6 +1016,11 @@ def maybe_setup_kv_connector():
                 if get_forward_context().attn_metadata is not None:
                     kv_connector.start_load_kv(get_forward_context())
 
+        def maybe_wait_for_save():
+            if has_kv_transfer_group():
+                kv_connector = get_kv_transfer_group()
+                kv_connector.wait_for_save()
+
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             # Return empty ModelRunnerOutput if there's no work to do.
@@ -1094,6 +1099,7 @@ def maybe_setup_kv_connector():
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
+        maybe_wait_for_save()
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             return hidden_states
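
Usage sketch (illustrative, not part of the diff): one way to exercise the new do_remote_decode / do_remote_prefill completion-request fields directly, assuming a prefill instance on port 8100 and a decode instance on port 8200 as started by the LMCache example script. The request bodies mirror what disagg_proxy_server.py sets for its two upstream services; the max_tokens value and the direct two-step flow here are assumptions for illustration only.

# Hypothetical end-to-end flow for disaggregated serving; endpoints, model name,
# and flags come from the examples in this diff, everything else is assumed.
import httpx

PREFILL_URL = "http://localhost:8100/v1/completions"
DECODE_URL = "http://localhost:8200/v1/completions"

req = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "The quick brown fox jumped",
    "max_tokens": 64,
}

with httpx.Client(timeout=None) as client:
    # Prefill side: with do_remote_decode=True the scheduler stops the request
    # after its first token (FINISHED_REMOTE_DECODE) and queues its KV cache
    # for sending via the connector.
    prefill_req = {**req, "do_remote_decode": True, "stream": False}
    client.post(PREFILL_URL, json=prefill_req).raise_for_status()

    # Decode side: mirrors stream_service_response() in the proxy, which sets
    # do_remote_prefill=True on the request forwarded to the decode service.
    decode_req = {**req, "do_remote_prefill": True}
    resp = client.post(DECODE_URL, json=decode_req)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["text"])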