From 891a5da46d1adcd76f70efd26df9cf07f2644e07 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 18:44:34 +0000 Subject: [PATCH 01/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/lm_cache/__init__.py | 0 vllm/entrypoints/openai/protocol.py | 11 ++++++-- vllm/entrypoints/openai/serving_completion.py | 6 ++++- vllm/outputs.py | 6 ++++- vllm/sampling_params.py | 15 +++++++++++ vllm/v1/core/sched/scheduler.py | 27 ++++++++++++------- vllm/v1/engine/__init__.py | 9 +++++-- vllm/v1/request.py | 3 +++ 8 files changed, 62 insertions(+), 15 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8d2ab29d221e..3193bb78d73d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -17,7 +17,8 @@ from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, - RequestOutputKind, SamplingParams) + KVTransferParams, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid, resolve_obj_by_qualname @@ -807,6 +808,9 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + kv_transfer_params: Optional[KVTransferParams] = Field( + default=None, description=("TODO")) + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -932,7 +936,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + kv_transfer_params=self.kv_transfer_params, + ) @model_validator(mode="before") @classmethod @@ -1182,6 +1188,7 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + kv_transfer_params: Optional[KVTransferParams] = Field(default=None) class CompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..c87883fa54e3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -184,7 +184,8 @@ async def create_completion( # we do not stream the results when use beam search. stream = (request.stream and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search) + and not request.use_beam_search + and not request.kv_transfer_params.do_remote_decode) # Streaming response if stream: @@ -476,12 +477,15 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage + # TODO(rob): assert somewhere that we dont have a batch req. 
+ assert (len(final_res_batch) == 1) return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + kv_transfer_params=final_res_batch[0].kv_transfer_params, ) def _create_completion_logprobs( diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823..06206ea3e1ec 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) @@ -103,6 +103,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -120,6 +121,7 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -133,11 +135,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for completion in self.outputs: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ed99664947..2c8e5e185bdd 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -26,6 +26,16 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 +class KVTransferParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + request_id: str + remote_id: Optional[str] = None + remote_block_ids: Optional[list[int]] = None + + # maybe make msgspec? @dataclass class GuidedDecodingParams: @@ -237,6 +247,9 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None + # Fields used for kv cache transfer + kv_transfer_params: Optional[KVTransferParams] = None + @staticmethod def from_optional( n: Optional[int] = 1, @@ -268,6 +281,7 @@ def from_optional( guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, + kv_transfer_params: Optional[KVTransferParams] = None, extra_args: Optional[dict[str, Any]] = None, ) -> "SamplingParams": if logit_bias is not None: @@ -310,6 +324,7 @@ def from_optional( guided_decoding=guided_decoding, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, + kv_transfer_params=kv_transfer_params, extra_args=extra_args, ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d3e562594aa1..16c9330ae92c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -719,7 +719,6 @@ def update_from_output( # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. 
- # TODO: What if we detect we're done here when doing P/D disagg? stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) @@ -741,7 +740,21 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + + # NOTE: new_token_ids is None if we have a partial prefill. if new_token_ids: + # If remote_decode, stop the request in the engine and add + # it to the sending KVs state. We hold onto this request in the + # engine until the sending is done. + kv_transfer_params = None + if request.do_remote_decode and not stopped: + assert self.connector is not None + stopped = True + request.status = RequestStatus.FINISHED_REMOTE_DECODE + self.sending_KV_req_ids.add(req_id) + kv_transfer_params = self.connector.make_transfer_params( + request=request, remote_decode=True) + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -751,18 +764,14 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - if self.connector is not None and request.do_remote_decode: - stopped = True - self.waiting_to_send_KV_req_ids.add(req_id) - # TODO: Add ZMQ request - #self.connector.send_remote_decode_request( - # self.kv_cache_manager.req_to_blocks[req_id]) - self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index af4122a51077..939ba808a242 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -10,13 +10,13 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import KVTransferParams, SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors # These are possible values of RequestOutput.finish_reason, # so form part of the external API. 
-FINISH_REASON_STRINGS = ("stop", "length", "abort") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "remote_decode") class FinishReason(enum.IntEnum): @@ -28,11 +28,13 @@ class FinishReason(enum.IntEnum): stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached abort - aborted for another reason + remote_decode - request will be processed as a remote_decode """ STOP = 0 LENGTH = 1 ABORT = 2 + REMOTE_DECODE = 3 def __str__(self): return FINISH_REASON_STRINGS[self.value] @@ -103,6 +105,9 @@ class EngineCoreOutput( stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + # In P/D case, used to trigger remote decode + kv_transfer_params: Optional[KVTransferParams] = None + @property def finished(self) -> bool: return self.finish_reason is not None diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 60b4ee739fec..9c2f60ec6952 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -153,6 +153,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + RUNNING = enum.auto() SENDING_KV = enum.auto() PREEMPTED = enum.auto() @@ -162,6 +163,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_REMOTE_DECODE = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: @@ -182,4 +184,5 @@ def get_finished_reason( RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, + RequestStatus.FINISHED_REMOTE_DECODE: FinishReason.REMOTE_DECODE } From 0fde0215d1f2f6394fa3f5f06954c247817f94ae Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:12:38 +0000 Subject: [PATCH 02/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9c2f60ec6952..54fdf1eae72b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -153,7 +153,6 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() - RUNNING = enum.auto() SENDING_KV = enum.auto() PREEMPTED = enum.auto() From 2fb8eeaecea6b8bac23534759fabf21f710c71a3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:20:02 +0000 Subject: [PATCH 03/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/output_processor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 21e2a1aee4e2..1de8e8994a86 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -6,7 +6,7 @@ from typing import Optional, Union from vllm.outputs import CompletionOutput, RequestOutput -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import KVTransferParams, RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason @@ -148,6 +148,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: KVTransferParams, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -169,13 
+170,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: KVTransferParams, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -191,6 +194,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, ) def _new_completion_output( @@ -337,6 +341,7 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason + kv_transfer_params = engine_core_output.kv_transfer_params req_state.is_prefilling = False @@ -352,7 +357,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) From 2f98b2330f8bf03f91ad1f1b986738493e2b2dc3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 18:45:07 +0000 Subject: [PATCH 04/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../distributed/kv_transfer/kv_connector/v1/base.py | 8 ++++++++ .../kv_connector/v1/shared_storage_connector.py | 13 +++++++++++++ vllm/sampling_params.py | 5 ++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index fc67d118070e..7b71a67a2ef9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -34,6 +34,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.forward_context import ForwardContext + from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -216,3 +217,10 @@ def is_request_done_sending(self, req_id: str) -> bool: def is_request_done_receiving(self, req_id: str) -> bool: raise NotImplementedError + + @abstractmethod + def build_transfer_params(self, request: "Request") -> "KVTransferParams": + """ + Build the KV transfer parameters for this step. + """ + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 9e4ce253b618..a5323852b9fe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -336,6 +337,18 @@ def is_request_done_sending(self, req_id: str) -> bool: def is_request_done_receiving(self, req_id: str) -> bool: return True + def build_transfer_params(self, request: "Request") -> "KVTransferParams": + """ + Build the KVTransferParams for the request. 
+ """ + + return KVTransferParams( + request_id=request.request_id, + remote_instance_id=self.remote_instance_id, + remote_block_ids=request.block_ids, + do_remote_prefill=True, + ) + # ============================== # Helper functions # ============================== diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 2c8e5e185bdd..ce3eeee36e34 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -26,14 +26,17 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 +# TODO(rob): make this per connector class KVTransferParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] # required for @cached_property. dict=True): request_id: str - remote_id: Optional[str] = None + remote_instance_id: Optional[str] = None remote_block_ids: Optional[list[int]] = None + do_remote_decode: bool = False + do_remote_prefill: bool = False # maybe make msgspec? From ad63d9af92f3bb2fb6b8a35c43e4f699babde620 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:35:40 +0000 Subject: [PATCH 05/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/shared_storage_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index a5323852b9fe..42dd3e0eb898 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -344,7 +344,7 @@ def build_transfer_params(self, request: "Request") -> "KVTransferParams": return KVTransferParams( request_id=request.request_id, - remote_instance_id=self.remote_instance_id, + remote_instance_id=self.instance_id, remote_block_ids=request.block_ids, do_remote_prefill=True, ) From 54f1e64eb2066a1a96a54bfc7a0611c1c2900ecb Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:37:44 +0000 Subject: [PATCH 06/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 939ba808a242..ac6228edfc56 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -104,8 +104,6 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None - - # In P/D case, used to trigger remote decode kv_transfer_params: Optional[KVTransferParams] = None @property From a1257626c0aae29ff9fda5f171859aca653e1692 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:51:25 +0000 Subject: [PATCH 07/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/protocol.py | 3 ++- vllm/entrypoints/openai/serving_completion.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3193bb78d73d..bc1ff71eba1c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -809,7 +809,8 @@ class CompletionRequest(OpenAIBaseModel): "that are not JSON-encodable can be identified.")) kv_transfer_params: Optional[KVTransferParams] = Field( - default=None, description=("TODO")) + default=None, + description=("KVTransfer parameters used for P/D disaggregation.")) # doc: end-completion-extra-params diff --git 
a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c87883fa54e3..eea14a6f53c1 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -477,8 +477,6 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage - # TODO(rob): assert somewhere that we dont have a batch req. - assert (len(final_res_batch) == 1) return CompletionResponse( id=request_id, created=created_time, From dd3e2997efa089e6b807bb9b44ace1e01d924822 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:54:57 +0000 Subject: [PATCH 08/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/request.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 54fdf1eae72b..f52790ae8910 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -62,7 +62,13 @@ def __init__( self.has_encoder_inputs = self.num_encoder_inputs > 0 # P/D disagg related - self.do_remote_decode = False + self.do_remote_decode = ( + False if sampling_params.kv_transfer_params is None else + sampling_params.kv_transfer_params.do_remote_decode) + self.do_remote_prefill = ( + False if sampling_params.kv_transfer_params is None else + sampling_params.kv_transfer_params.do_remote_decode) + assert not (self.do_remote_decode and self.do_remote_prefill) # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) From dd8df0c808fda63a3f01c582b15bb98bd84d57e3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:58:36 +0000 Subject: [PATCH 09/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/serving_completion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index eea14a6f53c1..f64e03fb0ab0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -181,11 +181,13 @@ async def create_completion( # Similar to the OpenAI API, when n != best_of, we do not stream the # results. Noting that best_of is only supported in V0. In addition, - # we do not stream the results when use beam search. + # we do not stream the results when use beam search or if the request + # should do the decode phase remotely. 
+ do_remote_decode = (request.kv_transfer_params + and request.kv_transfer_params.do_remote_decode) stream = (request.stream and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search - and not request.kv_transfer_params.do_remote_decode) + and not request.use_beam_search and not do_remote_decode) # Streaming response if stream: From 69c9d1368f75e2e5e3d1cc14dc787aaa8ca97233 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:59:39 +0000 Subject: [PATCH 10/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index bc1ff71eba1c..8f7204602cca 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -810,7 +810,7 @@ class CompletionRequest(OpenAIBaseModel): kv_transfer_params: Optional[KVTransferParams] = Field( default=None, - description=("KVTransfer parameters used for P/D disaggregation.")) + description="KVTransfer parameters used for P/D disaggregation.") # doc: end-completion-extra-params From 778389be5a5c9ff36c192e11636a37e41e445fcb Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 17:59:50 +0000 Subject: [PATCH 11/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8f7204602cca..f2b10cee3fbd 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -810,7 +810,7 @@ class CompletionRequest(OpenAIBaseModel): kv_transfer_params: Optional[KVTransferParams] = Field( default=None, - description="KVTransfer parameters used for P/D disaggregation.") + description="KVTransfer parameters used for disaggregated serving.") # doc: end-completion-extra-params From 6b2fa35cf5a8061fb429793a9d44c2ca8ac382ac Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 18:33:26 +0000 Subject: [PATCH 12/33] stash Signed-off-by: rshaw@neuralmagic.com --- .../disagg_examples/disagg_proxy_demo_v1.py | 450 ++++++++++++++++++ 1 file changed, 450 insertions(+) create mode 100644 examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py b/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py new file mode 100644 index 000000000000..a701636f357a --- /dev/null +++ b/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py @@ -0,0 +1,450 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file provides a disaggregated prefilling proxy demo to demonstrate an +example usage of XpYd disaggregated prefilling. +We can launch multiple vllm instances (2 for prefill and 2 for decode), and +launch this proxy demo through: + python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ + --model $model_name \ + --prefill localhost:8100 localhost:8101 \ + --decode localhost:8200 localhost:8201 \ + --port 8000 + +Note: This demo will be removed once the PDController implemented in PR 15343 +(https://github.com/vllm-project/vllm/pull/15343) supports XpYd. 
+""" +import argparse +import ipaddress +import itertools +import json +import logging +import os +import sys +from abc import ABC, abstractmethod +from typing import Callable, Optional + +import aiohttp +import requests +import uvicorn +from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, + Request, status) +from fastapi.responses import JSONResponse, StreamingResponse + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO) + + +class SchedulingPolicy(ABC): + + @abstractmethod + def schedule(self, cycler: itertools.cycle): + raise NotImplementedError("Scheduling Proxy is not set.") + + +class Proxy: + + def __init__( + self, + prefill_instances: list[str], + decode_instances: list[str], + model: str, + scheduling_policy: SchedulingPolicy, + custom_create_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + custom_create_chat_completion: Optional[Callable[ + [Request], StreamingResponse]] = None, + ): + self.prefill_instances = prefill_instances + self.decode_instances = decode_instances + self.prefill_cycler = itertools.cycle(prefill_instances) + self.decode_cycler = itertools.cycle(decode_instances) + self.model = model + self.scheduling_policy = scheduling_policy + self.custom_create_completion = custom_create_completion + self.custom_create_chat_completion = custom_create_chat_completion + self.router = APIRouter() + self.setup_routes() + + def setup_routes(self): + self.router.post( + "/v1/completions", + dependencies=[ + Depends(self.validate_json_request) + ])(self.custom_create_completion if self. + custom_create_completion else self.create_completion) + self.router.post( + "/v1/chat/completions", + dependencies=[ + Depends(self.validate_json_request) + ])(self.custom_create_chat_completion if self. 
+ custom_create_chat_completion else self.create_chat_completion) + self.router.get("/status", + response_class=JSONResponse)(self.get_status) + self.router.post("/instances/add", + dependencies=[Depends(self.api_key_authenticate) + ])(self.add_instance_endpoint) + + async def validate_json_request(self, raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=415, + detail= + "Unsupported Media Type: Only 'application/json' is allowed", + ) + + def api_key_authenticate(self, x_api_key: str = Header(...)): + expected_api_key = os.environ.get("ADMIN_API_KEY") + if not expected_api_key: + logger.error("ADMIN_API_KEY is not set in the environment.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Server configuration error.", + ) + if x_api_key != expected_api_key: + logger.warning("Unauthorized access attempt with API Key: %s", + x_api_key) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Forbidden: Invalid API Key.", + ) + + async def validate_instance(self, instance: str) -> bool: + url = f"http://{instance}/v1/models" + try: + async with aiohttp.ClientSession( + timeout=AIOHTTP_TIMEOUT) as client: + logger.info("Verifying %s ...", instance) + async with client.get(url) as response: + if response.status == 200: + data = await response.json() + if "data" in data and len(data["data"]) > 0: + model_cur = data["data"][0].get("id", "") + if model_cur == self.model: + logger.info("Instance: %s could be added.", + instance) + return True + else: + logger.warning("Mismatch model %s : %s != %s", + instance, model_cur, self.model) + return False + else: + return False + else: + return False + except aiohttp.ClientError as e: + logger.error(str(e)) + return False + except Exception as e: + logger.error(str(e)) + return False + + async def add_instance_endpoint(self, request: Request): + try: + data = await request.json() + logger.warning(str(data)) + instance_type = data.get("type") + instance = data.get("instance") + if instance_type not in ["prefill", "decode"]: + raise HTTPException(status_code=400, + detail="Invalid instance type.") + if not instance or ":" not in instance: + raise HTTPException(status_code=400, + detail="Invalid instance format.") + host, port_str = instance.split(":") + try: + if host != "localhost": + ipaddress.ip_address(host) + port = int(port_str) + if not (0 < port < 65536): + raise HTTPException(status_code=400, + detail="Invalid port number.") + except Exception as e: + raise HTTPException(status_code=400, + detail="Invalid instance address.") from e + + is_valid = await self.validate_instance(instance) + if not is_valid: + raise HTTPException(status_code=400, + detail="Instance validation failed.") + + if instance_type == "prefill": + if instance not in self.prefill_instances: + self.prefill_instances.append(instance) + self.prefill_cycler = itertools.cycle( + self.prefill_instances) + else: + raise HTTPException(status_code=400, + detail="Instance already exists.") + else: + if instance not in self.decode_instances: + self.decode_instances.append(instance) + self.decode_cycler = itertools.cycle(self.decode_instances) + else: + raise HTTPException(status_code=400, + detail="Instance already exists.") + + return JSONResponse(content={ + "message": + f"Added {instance} to {instance_type}_instances." 
+ }) + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.error("Error in add_instance_endpoint: %s", str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e + + async def forward_request(self, url, data, use_chunked=True): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + try: + async with session.post(url=url, json=data, + headers=headers) as response: + if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 + if use_chunked: + async for chunk_bytes in response.content.iter_chunked( # noqa: E501 + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + else: + error_content = await response.text() + try: + error_content = json.loads(error_content) + except json.JSONDecodeError: + error_content = error_content + logger.error("Request failed with status %s: %s", + response.status, error_content) + raise HTTPException( + status_code=response.status, + detail= + f"Request failed with status {response.status}: " + f"{error_content}", + ) + except aiohttp.ClientError as e: + logger.error("ClientError occurred: %s", str(e)) + raise HTTPException( + status_code=502, + detail= + "Bad Gateway: Error communicating with upstream server.", + ) from e + except Exception as e: + logger.error("Unexpected error: %s", str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e + + def schedule(self, cycler: itertools.cycle) -> str: + return self.scheduling_policy.schedule(cycler) + + async def get_status(self): + status = { + "prefill_node_count": len(self.prefill_instances), + "decode_node_count": len(self.decode_instances), + "prefill_nodes": self.prefill_instances, + "decode_nodes": self.decode_instances, + } + return status + + async def create_completion(self, raw_request: Request): + try: + request = await raw_request.json() + + kv_prepare_request = request.copy() + kv_prepare_request["max_tokens"] = 1 + + prefill_instance = self.schedule(self.prefill_cycler) + try: + async for _ in self.forward_request( + f"http://{prefill_instance}/v1/completions", + kv_prepare_request): + continue + except HTTPException as http_exc: + self.remove_instance_endpoint("prefill", prefill_instance) + raise http_exc + + # Perform kv recv and decoding stage + decode_instance = self.schedule(self.decode_cycler) + + try: + generator = self.forward_request( + f"http://{decode_instance}/v1/completions", request) + except HTTPException as http_exc: + self.remove_instance_endpoint("decode", decode_instance) + raise http_exc + response = StreamingResponse(generator) + return response + except Exception: + import sys + + exc_info = sys.exc_info() + print("Error occurred in disagg proxy server") + print(exc_info) + + async def create_chat_completion(self, raw_request: Request): + try: + request = await raw_request.json() + + # add params to request + kv_prepare_request = request.copy() + kv_prepare_request["max_tokens"] = 1 + + # prefill stage + prefill_instance = self.schedule(self.prefill_cycler) + try: + async for _ in self.forward_request( + f"http://{prefill_instance}/v1/chat/completions", + kv_prepare_request): + continue + except HTTPException as http_exc: + self.remove_instance_endpoint("prefill", prefill_instance) + raise http_exc + # Perform kv recv and decoding stage + decode_instance = self.schedule(self.decode_cycler) + + try: + generator = self.forward_request( + "http://" + decode_instance + 
"/v1/chat/completions", + request) + except HTTPException as http_exc: + self.remove_instance_endpoint("decode", decode_instance) + raise http_exc + response = StreamingResponse(content=generator) + return response + except Exception: + exc_info = sys.exc_info() + error_messages = [str(e) for e in exc_info if e] + print("Error occurred in disagg proxy server") + print(error_messages) + return StreamingResponse(content=iter(error_messages), + media_type="text/event-stream") + + def remove_instance_endpoint(self, instance_type, instance): + if (instance_type == "decode" and instance in self.decode_instances): + self.decode_instances.remove(instance) + self.decode_cycler = itertools.cycle(self.decode_instances) + if (instance_type == "prefill" and instance in self.decode_instances): + self.prefill_instances.remove(instance) + self.prefill_cycler = itertools.cycle(self.decode_instances) + + +class RoundRobinSchedulingPolicy(SchedulingPolicy): + + def __init__(self): + super().__init__() + + def schedule(self, cycler: itertools.cycle) -> str: + return next(cycler) + + +class ProxyServer: + + def __init__( + self, + args: argparse.Namespace, + scheduling_policy: Optional[SchedulingPolicy] = None, + create_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + create_chat_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + ): + self.validate_parsed_serve_args(args) + self.port = args.port + self.proxy_instance = Proxy( + prefill_instances=[] if args.prefill is None else args.prefill, + decode_instances=[] if args.decode is None else args.decode, + model=args.model, + scheduling_policy=(scheduling_policy if scheduling_policy + is not None else RoundRobinSchedulingPolicy()), + custom_create_completion=create_completion, + custom_create_chat_completion=create_chat_completion, + ) + + def validate_parsed_serve_args(self, args: argparse.Namespace): + if not args.prefill: + raise ValueError("Please specify at least one prefill node.") + if not args.decode: + raise ValueError("Please specify at least one decode node.") + self.validate_instances(args.prefill) + self.validate_instances(args.decode) + self.verify_model_config(args.prefill, args.model) + self.verify_model_config(args.decode, args.model) + + def validate_instances(self, instances: list): + for instance in instances: + if len(instance.split(":")) != 2: + raise ValueError(f"Invalid instance format: {instance}") + host, port = instance.split(":") + try: + if host != "localhost": + ipaddress.ip_address(host) + port = int(port) + if not (0 < port < 65536): + raise ValueError( + f"Invalid port number in instance: {instance}") + except Exception as e: + raise ValueError( + f"Invalid instance {instance}: {str(e)}") from e + + def verify_model_config(self, instances: list, model: str) -> None: + model_suffix = model.split("/")[-1] + for instance in instances: + try: + response = requests.get(f"http://{instance}/v1/models") + if response.status_code == 200: + model_cur = response.json()["data"][0]["id"] + model_cur_suffix = model_cur.split("/")[-1] + if model_cur_suffix != model_suffix: + raise ValueError( + f"{instance} serves a different model: " + f"{model_cur} != {model}") + else: + raise ValueError(f"Cannot get model id from {instance}!") + except requests.RequestException as e: + raise ValueError( + f"Error communicating with {instance}: {str(e)}") from e + + def run_server(self): + app = FastAPI() + app.include_router(self.proxy_instance.router) + config = uvicorn.Config(app, port=self.port, loop="uvloop") 
+ server = uvicorn.Server(config) + server.run() + + +if __name__ == "__main__": + # Todo: allow more config + parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") + parser.add_argument("--model", + "-m", + type=str, + required=True, + help="Model name") + + parser.add_argument( + "--prefill", + "-p", + type=str, + nargs="+", + help="List of prefill node URLs (host:port)", + ) + + parser.add_argument( + "--decode", + "-d", + type=str, + nargs="+", + help="List of decode node URLs (host:port)", + ) + + parser.add_argument( + "--port", + type=int, + default=8000, + help="Server port number", + ) + args = parser.parse_args() + proxy_server = ProxyServer(args=args) + proxy_server.run_server() From a4ab99673d77cd36c389677db5ea36382b2de685 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:06:11 +0000 Subject: [PATCH 13/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../disagg_examples/disagg_proxy_demo_v1.py | 450 ------------------ 1 file changed, 450 deletions(-) delete mode 100644 examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py b/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py deleted file mode 100644 index a701636f357a..000000000000 --- a/examples/online_serving/disagg_examples/disagg_proxy_demo_v1.py +++ /dev/null @@ -1,450 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This file provides a disaggregated prefilling proxy demo to demonstrate an -example usage of XpYd disaggregated prefilling. -We can launch multiple vllm instances (2 for prefill and 2 for decode), and -launch this proxy demo through: - python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ - --model $model_name \ - --prefill localhost:8100 localhost:8101 \ - --decode localhost:8200 localhost:8201 \ - --port 8000 - -Note: This demo will be removed once the PDController implemented in PR 15343 -(https://github.com/vllm-project/vllm/pull/15343) supports XpYd. 
-""" -import argparse -import ipaddress -import itertools -import json -import logging -import os -import sys -from abc import ABC, abstractmethod -from typing import Callable, Optional - -import aiohttp -import requests -import uvicorn -from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, - Request, status) -from fastapi.responses import JSONResponse, StreamingResponse - -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) -logger = logging.getLogger() -logging.basicConfig(level=logging.INFO) - - -class SchedulingPolicy(ABC): - - @abstractmethod - def schedule(self, cycler: itertools.cycle): - raise NotImplementedError("Scheduling Proxy is not set.") - - -class Proxy: - - def __init__( - self, - prefill_instances: list[str], - decode_instances: list[str], - model: str, - scheduling_policy: SchedulingPolicy, - custom_create_completion: Optional[Callable[[Request], - StreamingResponse]] = None, - custom_create_chat_completion: Optional[Callable[ - [Request], StreamingResponse]] = None, - ): - self.prefill_instances = prefill_instances - self.decode_instances = decode_instances - self.prefill_cycler = itertools.cycle(prefill_instances) - self.decode_cycler = itertools.cycle(decode_instances) - self.model = model - self.scheduling_policy = scheduling_policy - self.custom_create_completion = custom_create_completion - self.custom_create_chat_completion = custom_create_chat_completion - self.router = APIRouter() - self.setup_routes() - - def setup_routes(self): - self.router.post( - "/v1/completions", - dependencies=[ - Depends(self.validate_json_request) - ])(self.custom_create_completion if self. - custom_create_completion else self.create_completion) - self.router.post( - "/v1/chat/completions", - dependencies=[ - Depends(self.validate_json_request) - ])(self.custom_create_chat_completion if self. 
- custom_create_chat_completion else self.create_chat_completion) - self.router.get("/status", - response_class=JSONResponse)(self.get_status) - self.router.post("/instances/add", - dependencies=[Depends(self.api_key_authenticate) - ])(self.add_instance_endpoint) - - async def validate_json_request(self, raw_request: Request): - content_type = raw_request.headers.get("content-type", "").lower() - if content_type != "application/json": - raise HTTPException( - status_code=415, - detail= - "Unsupported Media Type: Only 'application/json' is allowed", - ) - - def api_key_authenticate(self, x_api_key: str = Header(...)): - expected_api_key = os.environ.get("ADMIN_API_KEY") - if not expected_api_key: - logger.error("ADMIN_API_KEY is not set in the environment.") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Server configuration error.", - ) - if x_api_key != expected_api_key: - logger.warning("Unauthorized access attempt with API Key: %s", - x_api_key) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Forbidden: Invalid API Key.", - ) - - async def validate_instance(self, instance: str) -> bool: - url = f"http://{instance}/v1/models" - try: - async with aiohttp.ClientSession( - timeout=AIOHTTP_TIMEOUT) as client: - logger.info("Verifying %s ...", instance) - async with client.get(url) as response: - if response.status == 200: - data = await response.json() - if "data" in data and len(data["data"]) > 0: - model_cur = data["data"][0].get("id", "") - if model_cur == self.model: - logger.info("Instance: %s could be added.", - instance) - return True - else: - logger.warning("Mismatch model %s : %s != %s", - instance, model_cur, self.model) - return False - else: - return False - else: - return False - except aiohttp.ClientError as e: - logger.error(str(e)) - return False - except Exception as e: - logger.error(str(e)) - return False - - async def add_instance_endpoint(self, request: Request): - try: - data = await request.json() - logger.warning(str(data)) - instance_type = data.get("type") - instance = data.get("instance") - if instance_type not in ["prefill", "decode"]: - raise HTTPException(status_code=400, - detail="Invalid instance type.") - if not instance or ":" not in instance: - raise HTTPException(status_code=400, - detail="Invalid instance format.") - host, port_str = instance.split(":") - try: - if host != "localhost": - ipaddress.ip_address(host) - port = int(port_str) - if not (0 < port < 65536): - raise HTTPException(status_code=400, - detail="Invalid port number.") - except Exception as e: - raise HTTPException(status_code=400, - detail="Invalid instance address.") from e - - is_valid = await self.validate_instance(instance) - if not is_valid: - raise HTTPException(status_code=400, - detail="Instance validation failed.") - - if instance_type == "prefill": - if instance not in self.prefill_instances: - self.prefill_instances.append(instance) - self.prefill_cycler = itertools.cycle( - self.prefill_instances) - else: - raise HTTPException(status_code=400, - detail="Instance already exists.") - else: - if instance not in self.decode_instances: - self.decode_instances.append(instance) - self.decode_cycler = itertools.cycle(self.decode_instances) - else: - raise HTTPException(status_code=400, - detail="Instance already exists.") - - return JSONResponse(content={ - "message": - f"Added {instance} to {instance_type}_instances." 
- }) - except HTTPException as http_exc: - raise http_exc - except Exception as e: - logger.error("Error in add_instance_endpoint: %s", str(e)) - raise HTTPException(status_code=500, detail=str(e)) from e - - async def forward_request(self, url, data, use_chunked=True): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } - try: - async with session.post(url=url, json=data, - headers=headers) as response: - if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 - if use_chunked: - async for chunk_bytes in response.content.iter_chunked( # noqa: E501 - 1024): - yield chunk_bytes - else: - content = await response.read() - yield content - else: - error_content = await response.text() - try: - error_content = json.loads(error_content) - except json.JSONDecodeError: - error_content = error_content - logger.error("Request failed with status %s: %s", - response.status, error_content) - raise HTTPException( - status_code=response.status, - detail= - f"Request failed with status {response.status}: " - f"{error_content}", - ) - except aiohttp.ClientError as e: - logger.error("ClientError occurred: %s", str(e)) - raise HTTPException( - status_code=502, - detail= - "Bad Gateway: Error communicating with upstream server.", - ) from e - except Exception as e: - logger.error("Unexpected error: %s", str(e)) - raise HTTPException(status_code=500, detail=str(e)) from e - - def schedule(self, cycler: itertools.cycle) -> str: - return self.scheduling_policy.schedule(cycler) - - async def get_status(self): - status = { - "prefill_node_count": len(self.prefill_instances), - "decode_node_count": len(self.decode_instances), - "prefill_nodes": self.prefill_instances, - "decode_nodes": self.decode_instances, - } - return status - - async def create_completion(self, raw_request: Request): - try: - request = await raw_request.json() - - kv_prepare_request = request.copy() - kv_prepare_request["max_tokens"] = 1 - - prefill_instance = self.schedule(self.prefill_cycler) - try: - async for _ in self.forward_request( - f"http://{prefill_instance}/v1/completions", - kv_prepare_request): - continue - except HTTPException as http_exc: - self.remove_instance_endpoint("prefill", prefill_instance) - raise http_exc - - # Perform kv recv and decoding stage - decode_instance = self.schedule(self.decode_cycler) - - try: - generator = self.forward_request( - f"http://{decode_instance}/v1/completions", request) - except HTTPException as http_exc: - self.remove_instance_endpoint("decode", decode_instance) - raise http_exc - response = StreamingResponse(generator) - return response - except Exception: - import sys - - exc_info = sys.exc_info() - print("Error occurred in disagg proxy server") - print(exc_info) - - async def create_chat_completion(self, raw_request: Request): - try: - request = await raw_request.json() - - # add params to request - kv_prepare_request = request.copy() - kv_prepare_request["max_tokens"] = 1 - - # prefill stage - prefill_instance = self.schedule(self.prefill_cycler) - try: - async for _ in self.forward_request( - f"http://{prefill_instance}/v1/chat/completions", - kv_prepare_request): - continue - except HTTPException as http_exc: - self.remove_instance_endpoint("prefill", prefill_instance) - raise http_exc - # Perform kv recv and decoding stage - decode_instance = self.schedule(self.decode_cycler) - - try: - generator = self.forward_request( - "http://" + decode_instance + 
"/v1/chat/completions", - request) - except HTTPException as http_exc: - self.remove_instance_endpoint("decode", decode_instance) - raise http_exc - response = StreamingResponse(content=generator) - return response - except Exception: - exc_info = sys.exc_info() - error_messages = [str(e) for e in exc_info if e] - print("Error occurred in disagg proxy server") - print(error_messages) - return StreamingResponse(content=iter(error_messages), - media_type="text/event-stream") - - def remove_instance_endpoint(self, instance_type, instance): - if (instance_type == "decode" and instance in self.decode_instances): - self.decode_instances.remove(instance) - self.decode_cycler = itertools.cycle(self.decode_instances) - if (instance_type == "prefill" and instance in self.decode_instances): - self.prefill_instances.remove(instance) - self.prefill_cycler = itertools.cycle(self.decode_instances) - - -class RoundRobinSchedulingPolicy(SchedulingPolicy): - - def __init__(self): - super().__init__() - - def schedule(self, cycler: itertools.cycle) -> str: - return next(cycler) - - -class ProxyServer: - - def __init__( - self, - args: argparse.Namespace, - scheduling_policy: Optional[SchedulingPolicy] = None, - create_completion: Optional[Callable[[Request], - StreamingResponse]] = None, - create_chat_completion: Optional[Callable[[Request], - StreamingResponse]] = None, - ): - self.validate_parsed_serve_args(args) - self.port = args.port - self.proxy_instance = Proxy( - prefill_instances=[] if args.prefill is None else args.prefill, - decode_instances=[] if args.decode is None else args.decode, - model=args.model, - scheduling_policy=(scheduling_policy if scheduling_policy - is not None else RoundRobinSchedulingPolicy()), - custom_create_completion=create_completion, - custom_create_chat_completion=create_chat_completion, - ) - - def validate_parsed_serve_args(self, args: argparse.Namespace): - if not args.prefill: - raise ValueError("Please specify at least one prefill node.") - if not args.decode: - raise ValueError("Please specify at least one decode node.") - self.validate_instances(args.prefill) - self.validate_instances(args.decode) - self.verify_model_config(args.prefill, args.model) - self.verify_model_config(args.decode, args.model) - - def validate_instances(self, instances: list): - for instance in instances: - if len(instance.split(":")) != 2: - raise ValueError(f"Invalid instance format: {instance}") - host, port = instance.split(":") - try: - if host != "localhost": - ipaddress.ip_address(host) - port = int(port) - if not (0 < port < 65536): - raise ValueError( - f"Invalid port number in instance: {instance}") - except Exception as e: - raise ValueError( - f"Invalid instance {instance}: {str(e)}") from e - - def verify_model_config(self, instances: list, model: str) -> None: - model_suffix = model.split("/")[-1] - for instance in instances: - try: - response = requests.get(f"http://{instance}/v1/models") - if response.status_code == 200: - model_cur = response.json()["data"][0]["id"] - model_cur_suffix = model_cur.split("/")[-1] - if model_cur_suffix != model_suffix: - raise ValueError( - f"{instance} serves a different model: " - f"{model_cur} != {model}") - else: - raise ValueError(f"Cannot get model id from {instance}!") - except requests.RequestException as e: - raise ValueError( - f"Error communicating with {instance}: {str(e)}") from e - - def run_server(self): - app = FastAPI() - app.include_router(self.proxy_instance.router) - config = uvicorn.Config(app, port=self.port, loop="uvloop") 
- server = uvicorn.Server(config) - server.run() - - -if __name__ == "__main__": - # Todo: allow more config - parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") - parser.add_argument("--model", - "-m", - type=str, - required=True, - help="Model name") - - parser.add_argument( - "--prefill", - "-p", - type=str, - nargs="+", - help="List of prefill node URLs (host:port)", - ) - - parser.add_argument( - "--decode", - "-d", - type=str, - nargs="+", - help="List of decode node URLs (host:port)", - ) - - parser.add_argument( - "--port", - type=int, - default=8000, - help="Server port number", - ) - args = parser.parse_args() - proxy_server = ProxyServer(args=args) - proxy_server.run_server() From ec3ed2ef73fe7e2828c9c5f572a44eb5384c3f99 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:07:04 +0000 Subject: [PATCH 14/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/lm_cache/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 From 5f73147e1b8bcc5b2c4e2490085eab3c92f012d0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:07:43 +0000 Subject: [PATCH 15/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/sampling_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ce3eeee36e34..a25ce394fb20 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -250,7 +250,7 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None - # Fields used for kv cache transfer + # Fields used for kv cache transfer in P/D setup kv_transfer_params: Optional[KVTransferParams] = None @staticmethod From ab22fb8d0f55a13211a47c3d45d9b11c932e811d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:08:17 +0000 Subject: [PATCH 16/33] update Signed-off-by: rshaw@neuralmagic.com --- vllm/sampling_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a25ce394fb20..ce7d09a89fc4 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -250,7 +250,7 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None - # Fields used for kv cache transfer in P/D setup + # Fields used for KVTransfer in disaggregated serving. kv_transfer_params: Optional[KVTransferParams] = None @staticmethod From 09ff5809a958ea1a7df74df88a29b285031dc668 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:10:11 +0000 Subject: [PATCH 17/33] pr readability Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 16c9330ae92c..56ba10abfa6b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -740,8 +740,6 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - - # NOTE: new_token_ids is None if we have a partial prefill. 
if new_token_ids: # If remote_decode, stop the request in the engine and add # it to the sending KVs state. We hold onto this request in the From ecfd8d6e3aeccb1201ef5f61485f5f2781dfe516 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:18:47 +0000 Subject: [PATCH 18/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/v1/lmcache_connector.py | 13 +++++++++++++ vllm/v1/core/sched/scheduler.py | 9 ++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 3b64c14361a4..8491ba3ebff8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -130,6 +131,18 @@ def build_connector_meta( """ return self._lmcache_engine.build_connector_meta(scheduler_output) + def build_transfer_params(self, request: "Request") -> "KVTransferParams": + """ + Build the KVTransferParams for the request. + """ + + return KVTransferParams( + request_id=request.request_id, + remote_instance_id=self.instance_id, + remote_block_ids=request.block_ids, + do_remote_prefill=True, + ) + # These return true for now since they are not async def is_request_done_sending(self, req_id: str) -> bool: return True diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 56ba10abfa6b..8ce52ed0f1a3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -741,16 +741,15 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: - # If remote_decode, stop the request in the engine and add - # it to the sending KVs state. We hold onto this request in the - # engine until the sending is done. + # If remote_decode, stop the request. Note that the request + # is not freed until the sending is complete. kv_transfer_params = None if request.do_remote_decode and not stopped: - assert self.connector is not None stopped = True request.status = RequestStatus.FINISHED_REMOTE_DECODE self.sending_KV_req_ids.add(req_id) - kv_transfer_params = self.connector.make_transfer_params( + assert self.connector is not None + kv_transfer_params = self.connector.build_transfer_params( request=request, remote_decode=True) # Add EngineCoreOutput for this Request. 
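Aside (editorial illustration, not part of the patch series): the patches above add a kv_transfer_params field to the completions API and a remote_decode finish path in the scheduler. The sketch below shows how a client or proxy might drive a prefill/decode pair with those fields, assuming vLLM servers at localhost:8100 (prefill) and localhost:8200 (decode) and a placeholder model name; the field names follow the KVTransferParams struct introduced in vllm/sampling_params.py, and the two-step flow mirrors the (since-removed) disagg proxy demo. Treat it purely as a hedged sketch of the intended usage, not a tested API.

import requests

PREFILL_URL = "http://localhost:8100/v1/completions"  # assumed prefill server
DECODE_URL = "http://localhost:8200/v1/completions"   # assumed decode server
MODEL = "meta-llama/Llama-3.1-8B-Instruct"            # placeholder model name

# 1) Prefill request: ask for a single token and set do_remote_decode so the
#    scheduler finishes the request with FINISHED_REMOTE_DECODE and returns
#    transfer metadata in CompletionResponse.kv_transfer_params.
prefill_payload = {
    "model": MODEL,
    "prompt": "The capital of France is",
    "max_tokens": 1,
    "kv_transfer_params": {
        "request_id": "req-0",
        "do_remote_decode": True,
    },
}
prefill_resp = requests.post(PREFILL_URL, json=prefill_payload).json()

# 2) The response is expected to carry remote_instance_id / remote_block_ids
#    describing where the prefilled KV blocks live.
transfer = prefill_resp.get("kv_transfer_params") or {}

# 3) Decode request: forward the transfer metadata, flipping the direction so
#    the decode instance pulls the remote KVs before generating.
decode_payload = dict(prefill_payload, max_tokens=64)
decode_payload["kv_transfer_params"] = dict(transfer,
                                            do_remote_decode=False,
                                            do_remote_prefill=True)
decode_resp = requests.post(DECODE_URL, json=decode_payload).json()
print(decode_resp["choices"][0]["text"])
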
From 51fac6c81430941fbfb447eb25e03b0f226576d9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:19:43 +0000 Subject: [PATCH 19/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index f52790ae8910..54a86ba00c63 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -61,7 +61,7 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # P/D disagg related + # Disaggregated serving related self.do_remote_decode = ( False if sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_decode) From 9f6109d8fbb5f2cc1de875d5e3cebe1bcd8b03a7 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 19:47:57 +0000 Subject: [PATCH 20/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 16 ++++++++++++---- vllm/v1/request.py | 1 + 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8ce52ed0f1a3..ff3b08c6afdb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -175,10 +175,6 @@ def schedule(self) -> SchedulerOutput: # Check for new remote decode requests for P/D new_KV_requests_to_send: list[Request] = [] if self.connector is not None: - # TODO: Receive request over ZMQ - # self.receiving_KV_req_ids.update( - # self.connector.receive_remote_decode_requests()) - # Check if any P/D requests have finished sending or receiving for req_id in list(self.waiting_to_send_KV_req_ids): self.sending_KV_req_ids.add(req_id) @@ -328,6 +324,18 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + # Check whether the + # TODO(rob): we should do this after we allocate the blocks. + # TODO(rob): this logic is incorrect if the req was preempted. + if request.do_remote_decode: + assert self.connector is not None + if not self.connector.is_request_done_receiving(req_id): + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + self.receiving_KV_req_ids.add(request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request and ( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 54a86ba00c63..9be39c8d1e46 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -159,6 +159,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() SENDING_KV = enum.auto() PREEMPTED = enum.auto() From 82f001b2712222db7a1bb0083788d1bceeba5722 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:00:14 +0000 Subject: [PATCH 21/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ff3b08c6afdb..921ac80cb31e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -324,8 +324,8 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue - # Check whether the - # TODO(rob): we should do this after we allocate the blocks. 
+ # TODO(rob): we should do this after we allocate the blocks if + # we want to write directly into the BlockTable (like Dynamo). # TODO(rob): this logic is incorrect if the req was preempted. if request.do_remote_decode: assert self.connector is not None @@ -749,8 +749,9 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: - # If remote_decode, stop the request. Note that the request - # is not freed until the sending is complete. + # Stop request after the first token if doing a remote_decode. + # NOTE(rob): req is not freed (or preempted) in the EngineCore + # until the xfer is done to ensure we do not free the KV blocks. kv_transfer_params = None if request.do_remote_decode and not stopped: stopped = True @@ -758,7 +759,7 @@ def update_from_output( self.sending_KV_req_ids.add(req_id) assert self.connector is not None kv_transfer_params = self.connector.build_transfer_params( - request=request, remote_decode=True) + request) # Add EngineCoreOutput for this Request. outputs.append( From 3a768c20c9a63f5e2497085b24ff2cb283bc2769 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:04:03 +0000 Subject: [PATCH 22/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../distributed/kv_transfer/kv_connector/v1/base.py | 8 -------- .../kv_connector/v1/lmcache_connector.py | 13 ------------- .../kv_connector/v1/shared_storage_connector.py | 13 ------------- vllm/sampling_params.py | 5 +++-- vllm/v1/core/sched/scheduler.py | 8 ++++++-- 5 files changed, 9 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7b71a67a2ef9..fc67d118070e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -34,7 +34,6 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.forward_context import ForwardContext - from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -217,10 +216,3 @@ def is_request_done_sending(self, req_id: str) -> bool: def is_request_done_receiving(self, req_id: str) -> bool: raise NotImplementedError - - @abstractmethod - def build_transfer_params(self, request: "Request") -> "KVTransferParams": - """ - Build the KV transfer parameters for this step. - """ - pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 8491ba3ebff8..3b64c14361a4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,7 +13,6 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext - from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -131,18 +130,6 @@ def build_connector_meta( """ return self._lmcache_engine.build_connector_meta(scheduler_output) - def build_transfer_params(self, request: "Request") -> "KVTransferParams": - """ - Build the KVTransferParams for the request. 
- """ - - return KVTransferParams( - request_id=request.request_id, - remote_instance_id=self.instance_id, - remote_block_ids=request.block_ids, - do_remote_prefill=True, - ) - # These return true for now since they are not async def is_request_done_sending(self, req_id: str) -> bool: return True diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 42dd3e0eb898..9e4ce253b618 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext - from vllm.sampling_params import KVTransferParams from vllm.v1.request import Request logger = init_logger(__name__) @@ -337,18 +336,6 @@ def is_request_done_sending(self, req_id: str) -> bool: def is_request_done_receiving(self, req_id: str) -> bool: return True - def build_transfer_params(self, request: "Request") -> "KVTransferParams": - """ - Build the KVTransferParams for the request. - """ - - return KVTransferParams( - request_id=request.request_id, - remote_instance_id=self.instance_id, - remote_block_ids=request.block_ids, - do_remote_prefill=True, - ) - # ============================== # Helper functions # ============================== diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ce7d09a89fc4..dbcd117d7b34 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -33,8 +33,9 @@ class KVTransferParams( # required for @cached_property. dict=True): request_id: str - remote_instance_id: Optional[str] = None - remote_block_ids: Optional[list[int]] = None + # TODO(rob): we can handle xPyD and direct KV block Xfer. + # remote_instance_id: Optional[str] = None + # remote_block_ids: Optional[list[int]] = None do_remote_decode: bool = False do_remote_prefill: bool = False diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 921ac80cb31e..1f92ea96f79f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.sampling_params import KVTransferParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -758,8 +759,11 @@ def update_from_output( request.status = RequestStatus.FINISHED_REMOTE_DECODE self.sending_KV_req_ids.add(req_id) assert self.connector is not None - kv_transfer_params = self.connector.build_transfer_params( - request) + # TODO(rob): do this on a per-Connector basis. + kv_transfer_params = KVTransferParams( + request_id=request.request_id, + do_remote_prefill=True, + ) # Add EngineCoreOutput for this Request. 
outputs.append( From fd0e92d409cc991e37f016f353d0e88f3717754e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:04:56 +0000 Subject: [PATCH 23/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/sampling_params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index dbcd117d7b34..a58d48788d63 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -33,7 +33,8 @@ class KVTransferParams( # required for @cached_property. dict=True): request_id: str - # TODO(rob): we can handle xPyD and direct KV block Xfer. + # TODO(rob): we can handle xPyD and direct KV block Xfer + # by passing these data. # remote_instance_id: Optional[str] = None # remote_block_ids: Optional[list[int]] = None do_remote_decode: bool = False From 77cf32b56cfb1ad00f0c0eddaaf2c9eca2396c27 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:05:54 +0000 Subject: [PATCH 24/33] update beam search Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/serving_completion.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index f64e03fb0ab0..d08b0059e2cf 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -181,13 +181,10 @@ async def create_completion( # Similar to the OpenAI API, when n != best_of, we do not stream the # results. Noting that best_of is only supported in V0. In addition, - # we do not stream the results when use beam search or if the request - # should do the decode phase remotely. - do_remote_decode = (request.kv_transfer_params - and request.kv_transfer_params.do_remote_decode) + # we do not stream the results when use beam search. stream = (request.stream and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search and not do_remote_decode) + and not request.use_beam_search) # Streaming response if stream: From 319095f371389f16f588213a9225a5977ac1d6ff Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:31:46 +0000 Subject: [PATCH 25/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1f92ea96f79f..8ac5d9d1fc8d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -757,7 +757,7 @@ def update_from_output( if request.do_remote_decode and not stopped: stopped = True request.status = RequestStatus.FINISHED_REMOTE_DECODE - self.sending_KV_req_ids.add(req_id) + self.waiting_to_send_KV_req_ids.add(req_id) assert self.connector is not None # TODO(rob): do this on a per-Connector basis. 
kv_transfer_params = KVTransferParams( From 59fab97535bcedc29a23e22ee4a1d1267ea13166 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:33:56 +0000 Subject: [PATCH 26/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/request.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9be39c8d1e46..11722c4ccc9a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -65,10 +65,6 @@ def __init__( self.do_remote_decode = ( False if sampling_params.kv_transfer_params is None else sampling_params.kv_transfer_params.do_remote_decode) - self.do_remote_prefill = ( - False if sampling_params.kv_transfer_params is None else - sampling_params.kv_transfer_params.do_remote_decode) - assert not (self.do_remote_decode and self.do_remote_prefill) # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) From 5a9416408673a7fcbc800352f0d50101691634c4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 20:36:51 +0000 Subject: [PATCH 27/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8ac5d9d1fc8d..deb468c06049 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -760,9 +760,13 @@ def update_from_output( self.waiting_to_send_KV_req_ids.add(req_id) assert self.connector is not None # TODO(rob): do this on a per-Connector basis. + # NOTE(rob): this KVTransferParams will be sent to the + # DWorker. From the POV of the DWorker, it should be a + # remote Prefill. kv_transfer_params = KVTransferParams( request_id=request.request_id, do_remote_prefill=True, + do_remote_decode=False, ) # Add EngineCoreOutput for this Request. From 8408dbac7deb1f6a660bb816cf45fed85fd44f5e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 19 Apr 2025 21:45:27 +0000 Subject: [PATCH 28/33] stash Signed-off-by: rshaw@neuralmagic.com --- examples/other/LMCache/disagg_proxy_server.py | 13 ++++++++----- vllm/entrypoints/openai/protocol.py | 18 ++++++++++++++---- vllm/entrypoints/openai/serving_completion.py | 5 ++++- vllm/sampling_params.py | 16 ++++++++++++++-- vllm/v1/core/sched/scheduler.py | 5 +---- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index 8db93bc8931b..1ce62f0c74e4 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import json import os import time from contextlib import asynccontextmanager @@ -88,9 +89,10 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, Send a request to a service using a persistent client. 
""" req_data = req_data.copy() - req_data['max_tokens'] = 1 - if 'max_completion_tokens' in req_data: - req_data['max_completion_tokens'] = 1 + req_data['kv_transfer_params'] = json.dumps({ + "do_remote_decode": True, + }) + req_data["stream"] = False headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) @@ -121,8 +123,9 @@ async def handle_completions(request: Request): req_data = await request.json() # Send request to prefill service, ignore the response - await send_request_to_service(app.state.prefill_client, "/completions", - req_data) + response = await send_request_to_service(app.state.prefill_client, + "/completions", req_data) + print(response) et = time.time() stats_calculator.add(et - st) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f2b10cee3fbd..c372dfd8a383 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -808,8 +808,12 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) - kv_transfer_params: Optional[KVTransferParams] = Field( - default=None, + do_remote_decode: bool = Field( + default=False, + description="KVTransfer parameters used for disaggregated serving.") + + do_remote_prefill: bool = Field( + default=False, description="KVTransfer parameters used for disaggregated serving.") # doc: end-completion-extra-params @@ -909,6 +913,11 @@ def to_sampling_params( whitespace_pattern=self.guided_whitespace_pattern, ) + kv_transfer_params = KVTransferParams.from_optional( + do_remote_decode=self.do_remote_decode, + do_remote_prefill=self.do_remote_prefill, + ) + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -938,7 +947,7 @@ def to_sampling_params( guided_decoding=guided_decoding, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, - kv_transfer_params=self.kv_transfer_params, + kv_transfer_params=kv_transfer_params, ) @model_validator(mode="before") @@ -1189,7 +1198,8 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo - kv_transfer_params: Optional[KVTransferParams] = Field(default=None) + # TODO: make this into a pydanic object + do_remote_prefill: Optional[bool] = Field(default=None) class CompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d08b0059e2cf..e704cfb2b775 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -476,13 +476,16 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage + do_remote_prefill = ( + final_res_batch[0].kv_transfer_params + and final_res_batch[0].kv_transfer_params.do_remote_prefill) return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, - kv_transfer_params=final_res_batch[0].kv_transfer_params, + do_remote_prefill=do_remote_prefill, ) def _create_completion_logprobs( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a58d48788d63..2504b0367b04 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -32,14 +32,26 @@ class KVTransferParams( omit_defaults=True, # type: ignore[call-arg] # required for @cached_property. 
         dict=True):
-    request_id: str
     # TODO(rob): we can handle xPyD and direct KV block Xfer
-    # by passing these data.
     # remote_instance_id: Optional[str] = None
     # remote_block_ids: Optional[list[int]] = None
     do_remote_decode: bool = False
     do_remote_prefill: bool = False
 
+    @staticmethod
+    def from_optional(do_remote_decode: bool,
+                      do_remote_prefill: bool) -> Optional["KVTransferParams"]:
+        if do_remote_decode and do_remote_prefill:
+            raise ValueError(
+                "Cannot do both remote prefill and remote decode.")
+        elif do_remote_decode or do_remote_prefill:
+            return KVTransferParams(
+                do_remote_decode=do_remote_decode,
+                do_remote_prefill=do_remote_prefill,
+            )
+        else:
+            return None
+
 
 # maybe make msgspec?
 @dataclass
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index deb468c06049..c1bcf99c0577 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -764,10 +764,7 @@ def update_from_output(
                     # DWorker. From the POV of the DWorker, it should be a
                     # remote Prefill.
                     kv_transfer_params = KVTransferParams(
-                        request_id=request.request_id,
-                        do_remote_prefill=True,
-                        do_remote_decode=False,
-                    )
+                        do_remote_prefill=True, )
 
             # Add EngineCoreOutput for this Request.
             outputs.append(

From 635c858bec6bcfd2552b0d2cc01c6ec30e880666 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Sun, 20 Apr 2025 12:58:21 +0000
Subject: [PATCH 29/33] updated

Signed-off-by: rshaw@neuralmagic.com
---
 .../online_serving/openai_completion_client.py | 12 +++++-------
 examples/other/LMCache/disagg-example.sh       |  6 ++----
 examples/other/LMCache/disagg_proxy_server.py  |  5 +----
 vllm/v1/core/sched/output.py                   |  2 +-
 vllm/v1/core/sched/scheduler.py                | 18 ++++++++----------
 5 files changed, 17 insertions(+), 26 deletions(-)

diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py
index 6ab7619bff19..7e26882192e3 100644
--- a/examples/online_serving/openai_completion_client.py
+++ b/examples/online_serving/openai_completion_client.py
@@ -4,7 +4,7 @@
 # Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:9000/v1" def main(): @@ -14,18 +14,16 @@ def main(): base_url=openai_api_base, ) - models = client.models.list() - model = models.data[0].id + # models = client.models.list() + # model = models.data[0].id # Completion API stream = False completion = client.completions.create( - model=model, + model="meta-llama/Llama-3.1-8B-Instruct", prompt="A robot may not injure a human being", echo=False, - n=2, - stream=stream, - logprobs=3) + stream=stream) print("-" * 50) print("Completion results:") diff --git a/examples/other/LMCache/disagg-example.sh b/examples/other/LMCache/disagg-example.sh index 43b0b59c88f8..89f1c753e887 100644 --- a/examples/other/LMCache/disagg-example.sh +++ b/examples/other/LMCache/disagg-example.sh @@ -25,10 +25,9 @@ if [[ $1 == "prefiller" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=0 \ + CUDA_VISIBLE_DEVICES=6 \ vllm serve $MODEL \ --port 8100 \ - --disable-log-requests \ --enforce-eager \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' @@ -46,10 +45,9 @@ elif [[ $1 == "decoder" ]]; then LMCACHE_USE_EXPERIMENTAL=True \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ - CUDA_VISIBLE_DEVICES=1 \ + CUDA_VISIBLE_DEVICES=7 \ vllm serve $MODEL \ --port 8200 \ - --disable-log-requests \ --enforce-eager \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index 1ce62f0c74e4..2cf75de27a7e 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import json import os import time from contextlib import asynccontextmanager @@ -89,9 +88,7 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, Send a request to a service using a persistent client. """ req_data = req_data.copy() - req_data['kv_transfer_params'] = json.dumps({ - "do_remote_decode": True, - }) + req_data['do_remote_decode'] = True req_data["stream"] = False headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6263743d9710..297a2d2a1355 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -128,4 +128,4 @@ class SchedulerOutput: kv_connector_metadata: Optional[KVConnectorMetadata] = None sending_KV_req_ids: set[str] = field(default_factory=set) receiving_KV_req_ids: set[str] = field(default_factory=set) - new_KV_requests_to_send: list[NewRequestData] = field(default_factory=list) + new_KV_req_ids_to_send: list[str] = field(default_factory=list) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c1bcf99c0577..726bf8cc1d18 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -68,9 +68,11 @@ def __init__( # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. 
self.connector = None + print("================== START CREATE CONNECTOR ==================") if self.vllm_config.kv_transfer_config is not None: self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + print("================== END CREATE CONNECTOR ==================") num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 @@ -174,13 +176,13 @@ def schedule(self) -> SchedulerOutput: scheduled_timestamp = time.monotonic() # Check for new remote decode requests for P/D - new_KV_requests_to_send: list[Request] = [] + new_KV_req_ids_to_send: list[str] = [] if self.connector is not None: # Check if any P/D requests have finished sending or receiving for req_id in list(self.waiting_to_send_KV_req_ids): self.sending_KV_req_ids.add(req_id) self.waiting_to_send_KV_req_ids.remove(req_id) - new_KV_requests_to_send.append(self.requests[req_id]) + new_KV_req_ids_to_send.append(req_id) for req_id in list(self.sending_KV_req_ids): if self.connector.is_request_done_sending(req_id): @@ -330,7 +332,8 @@ def schedule(self) -> SchedulerOutput: # TODO(rob): this logic is incorrect if the req was preempted. if request.do_remote_decode: assert self.connector is not None - if not self.connector.is_request_done_receiving(req_id): + if not self.connector.is_request_done_receiving( + request.request_id): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS self.receiving_KV_req_ids.add(request.request_id) self.waiting.popleft() @@ -521,12 +524,7 @@ def schedule(self) -> SchedulerOutput: # TODO: encapsulate these in the KV connector metadata scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids - new_KV_to_send_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) - for req in new_KV_requests_to_send - ] - scheduler_output.new_KV_requests_to_send = new_KV_to_send_reqs_data + scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send # Advance the number of computed tokens for the request AFTER # the request is scheduled. @@ -764,7 +762,7 @@ def update_from_output( # DWorker. From the POV of the DWorker, it should be a # remote Prefill. kv_transfer_params = KVTransferParams( - do_remote_prefill=True, ) + do_remote_prefill=True) # Add EngineCoreOutput for this Request. outputs.append( From 7d8b39b1f871854fb186454261eeae9a93873ce7 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 20 Apr 2025 13:13:57 +0000 Subject: [PATCH 30/33] prefill worker is 'working' Signed-off-by: rshaw@neuralmagic.com --- examples/online_serving/openai_completion_client.py | 3 ++- examples/other/LMCache/disagg_proxy_server.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 7e26882192e3..f1e6175268de 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -21,7 +21,8 @@ def main(): stream = False completion = client.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", - prompt="A robot may not injure a human being", + prompt= + "The absolute best part about working for Red Hat is that we get to work on open source software. 
Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501 echo=False, stream=stream) diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index 2cf75de27a7e..5e052400e100 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -94,6 +94,7 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() + return response @@ -103,6 +104,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str, Asynchronously stream the response from a service using a persistent client. """ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + req_data['do_remote_prefill'] = True async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() @@ -120,9 +122,8 @@ async def handle_completions(request: Request): req_data = await request.json() # Send request to prefill service, ignore the response - response = await send_request_to_service(app.state.prefill_client, - "/completions", req_data) - print(response) + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) et = time.time() stats_calculator.add(et - st) From 9b4daf394b758c00c931c141149a4d05733c0575 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 20 Apr 2025 14:35:15 +0000 Subject: [PATCH 31/33] updated Signed-off-by: rshaw@neuralmagic.com --- .../openai_completion_client2.py | 39 +++++++++++++++++++ examples/other/LMCache/disagg_proxy_server.py | 2 +- vllm/v1/core/sched/scheduler.py | 16 ++++---- vllm/v1/worker/gpu_model_runner.py | 6 +++ 4 files changed, 53 insertions(+), 10 deletions(-) create mode 100644 examples/online_serving/openai_completion_client2.py diff --git a/examples/online_serving/openai_completion_client2.py b/examples/online_serving/openai_completion_client2.py new file mode 100644 index 000000000000..fb6d0d120b5e --- /dev/null +++ b/examples/online_serving/openai_completion_client2.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8100/v1" + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + # models = client.models.list() + # model = models.data[0].id + + # Completion API + stream = False + completion = client.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + prompt="The quick brown job jumped", + echo=False, + stream=stream) + + print("-" * 50) + print("Completion results:") + if stream: + for c in completion: + print(c) + else: + print(completion) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index 5e052400e100..e81c1ace7873 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -104,7 +104,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str, Asynchronously stream the response from a service using a persistent client. """ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - req_data['do_remote_prefill'] = True + # req_data['do_remote_prefill'] = True async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 726bf8cc1d18..9fa04ab93e13 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -178,12 +178,6 @@ def schedule(self) -> SchedulerOutput: # Check for new remote decode requests for P/D new_KV_req_ids_to_send: list[str] = [] if self.connector is not None: - # Check if any P/D requests have finished sending or receiving - for req_id in list(self.waiting_to_send_KV_req_ids): - self.sending_KV_req_ids.add(req_id) - self.waiting_to_send_KV_req_ids.remove(req_id) - new_KV_req_ids_to_send.append(req_id) - for req_id in list(self.sending_KV_req_ids): if self.connector.is_request_done_sending(req_id): self.sending_KV_req_ids.remove(req_id) @@ -192,6 +186,10 @@ def schedule(self) -> SchedulerOutput: if self.connector.is_request_done_receiving(req_id): self.receiving_KV_req_ids.remove(req_id) self.waiting.append(self.requests[req_id]) + for req_id in list(self.waiting_to_send_KV_req_ids): + self.sending_KV_req_ids.add(req_id) + self.waiting_to_send_KV_req_ids.remove(req_id) + new_KV_req_ids_to_send.append(req_id) # First, schedule the RUNNING requests. req_index = 0 @@ -518,14 +516,14 @@ def schedule(self) -> SchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) - scheduler_output.kv_connector_metadata = meta - # TODO: encapsulate these in the KV connector metadata scheduler_output.sending_KV_req_ids = self.sending_KV_req_ids scheduler_output.receiving_KV_req_ids = self.receiving_KV_req_ids scheduler_output.new_KV_req_ids_to_send = new_KV_req_ids_to_send + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. 
The scheduler_output of the current step has to include the diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c780bbe40934..e6c61ecb0254 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1016,6 +1016,11 @@ def maybe_setup_kv_connector(): if get_forward_context().attn_metadata is not None: kv_connector.start_load_kv(get_forward_context()) + def maybe_wait_for_save(): + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + kv_connector.wait_for_save() + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOutput if there's no work to do. @@ -1094,6 +1099,7 @@ def maybe_setup_kv_connector(): intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + maybe_wait_for_save() if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. return hidden_states From 6bf63a3eb88e4fc2b5c8ff4e0b40719e76fa55e4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 20 Apr 2025 14:38:58 +0000 Subject: [PATCH 32/33] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/openai/protocol.py | 2 -- vllm/entrypoints/openai/serving_completion.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c372dfd8a383..c56a15af1367 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1198,8 +1198,6 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo - # TODO: make this into a pydanic object - do_remote_prefill: Optional[bool] = Field(default=None) class CompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e704cfb2b775..1067f35ce240 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -476,16 +476,12 @@ def request_output_to_completion_response( request_metadata.final_usage_info = usage - do_remote_prefill = ( - final_res_batch[0].kv_transfer_params - and final_res_batch[0].kv_transfer_params.do_remote_prefill) return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, - do_remote_prefill=do_remote_prefill, ) def _create_completion_logprobs( From c785762a74cf7b84dff69d19e6243a05d8235f5f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 20 Apr 2025 14:40:19 +0000 Subject: [PATCH 33/33] updated Signed-off-by: rshaw@neuralmagic.com --- examples/other/LMCache/disagg_proxy_server.py | 3 +-- vllm/v1/core/sched/scheduler.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/other/LMCache/disagg_proxy_server.py b/examples/other/LMCache/disagg_proxy_server.py index e81c1ace7873..2639409a1522 100644 --- a/examples/other/LMCache/disagg_proxy_server.py +++ b/examples/other/LMCache/disagg_proxy_server.py @@ -90,7 +90,6 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, req_data = req_data.copy() req_data['do_remote_decode'] = True req_data["stream"] = False - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() @@ -104,7 +103,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str, Asynchronously 
stream the response from a service using a persistent client. """ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - # req_data['do_remote_prefill'] = True + req_data['do_remote_prefill'] = True async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9fa04ab93e13..000375e6a533 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -68,11 +68,9 @@ def __init__( # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None - print("================== START CREATE CONNECTOR ==================") if self.vllm_config.kv_transfer_config is not None: self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) - print("================== END CREATE CONNECTOR ==================") num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0
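
Taken together, the final state of this series gives the proxy a simple two-phase contract: first POST the request to the prefill server with do_remote_decode=True and stream=False so it only populates the KV cache (and produces the first token), then replay the request to the decode server with do_remote_prefill=True and return that response to the client. Below is a condensed sketch of that flow using the 8100/8200 ports from disagg-example.sh; the helper name and hard-coded URLs are illustrative and not part of the diff.

```python
# Condensed sketch of the proxy's two-phase disaggregated flow; assumes the
# prefill and decode vLLM servers from disagg-example.sh listen on 8100/8200.
import httpx

PREFILL_URL = "http://localhost:8100/v1/completions"
DECODE_URL = "http://localhost:8200/v1/completions"

async def disagg_completion(req_data: dict) -> str:
    async with httpx.AsyncClient(timeout=None) as client:
        # Phase 1: prefill only. The request stops after its first token and
        # the connector pushes the KV cache toward the decode instance.
        prefill_req = dict(req_data, do_remote_decode=True, stream=False)
        resp = await client.post(PREFILL_URL, json=prefill_req)
        resp.raise_for_status()

        # Phase 2: decode. The same request is marked do_remote_prefill so the
        # decode instance reuses the transferred KV cache instead of prefilling.
        decode_req = dict(req_data, do_remote_prefill=True)
        resp = await client.post(DECODE_URL, json=decode_req)
        resp.raise_for_status()
        return resp.text
```
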