Merged
Changes from 24 commits
Commits (33 total)
891a5da  updated (robertgshaw2-redhat, Apr 19, 2025)
0fde021  updated (robertgshaw2-redhat, Apr 19, 2025)
2fb8eea  updated (robertgshaw2-redhat, Apr 19, 2025)
2f98b23  updated (robertgshaw2-redhat, Apr 19, 2025)
ad63d9a  updated (robertgshaw2-redhat, Apr 19, 2025)
54f1e64  updated (robertgshaw2-redhat, Apr 19, 2025)
a125762  updated (robertgshaw2-redhat, Apr 19, 2025)
dd3e299  updated (robertgshaw2-redhat, Apr 19, 2025)
dd8df0c  updated (robertgshaw2-redhat, Apr 19, 2025)
69c9d13  updated (robertgshaw2-redhat, Apr 19, 2025)
778389b  updated (robertgshaw2-redhat, Apr 19, 2025)
6b2fa35  stash (robertgshaw2-redhat, Apr 19, 2025)
a4ab996  updated (robertgshaw2-redhat, Apr 19, 2025)
ec3ed2e  updated (robertgshaw2-redhat, Apr 19, 2025)
5f73147  updated (robertgshaw2-redhat, Apr 19, 2025)
ab22fb8  update (robertgshaw2-redhat, Apr 19, 2025)
09ff580  pr readability (robertgshaw2-redhat, Apr 19, 2025)
ecfd8d6  updated (robertgshaw2-redhat, Apr 19, 2025)
51fac6c  updated (robertgshaw2-redhat, Apr 19, 2025)
9f6109d  updated (robertgshaw2-redhat, Apr 19, 2025)
82f001b  updated (robertgshaw2-redhat, Apr 19, 2025)
3a768c2  updated (robertgshaw2-redhat, Apr 19, 2025)
fd0e92d  updated (robertgshaw2-redhat, Apr 19, 2025)
77cf32b  update beam search (robertgshaw2-redhat, Apr 19, 2025)
319095f  updated (robertgshaw2-redhat, Apr 19, 2025)
59fab97  updated (robertgshaw2-redhat, Apr 19, 2025)
5a94164  updated (robertgshaw2-redhat, Apr 19, 2025)
8408dba  stash (robertgshaw2-redhat, Apr 19, 2025)
635c858  updated (robertgshaw2-redhat, Apr 20, 2025)
7d8b39b  prefill worker is 'working' (robertgshaw2-redhat, Apr 20, 2025)
9b4daf3  updated (robertgshaw2-redhat, Apr 20, 2025)
6bf63a3  updated (robertgshaw2-redhat, Apr 20, 2025)
c785762  updated (robertgshaw2-redhat, Apr 20, 2025)
12 changes: 10 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -17,7 +17,8 @@
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
KVTransferParams, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid, resolve_obj_by_qualname

@@ -807,6 +808,10 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))

kv_transfer_params: Optional[KVTransferParams] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")

# doc: end-completion-extra-params

# Default sampling parameters for completion requests
@@ -932,7 +937,9 @@ def to_sampling_params(
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
allowed_token_ids=self.allowed_token_ids,
kv_transfer_params=self.kv_transfer_params,
)

@model_validator(mode="before")
@classmethod
@@ -1182,6 +1189,7 @@ class CompletionResponse(OpenAIBaseModel):
model: str
choices: list[CompletionResponseChoice]
usage: UsageInfo
kv_transfer_params: Optional[KVTransferParams] = Field(default=None)


class CompletionResponseStreamChoice(OpenAIBaseModel):
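With these protocol changes, a /v1/completions request can carry kv_transfer_params and the response echoes the connector's parameters back to the caller. A minimal client-side sketch, assuming an OpenAI-compatible vLLM server and assuming the field is accepted as a plain JSON object; the URL, port, model name, and request id are illustrative, and only the kv_transfer_params field itself comes from this PR:

```python
# Hypothetical client sketch; endpoint and model are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        # Ask this (prefill) instance to hand decode off to a remote instance.
        "kv_transfer_params": {
            "request_id": "req-123",
            "do_remote_decode": True,
        },
    },
)
out = resp.json()
# CompletionResponse now carries kv_transfer_params for the decode instance.
print(out.get("kv_transfer_params"))
```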
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -482,6 +482,7 @@ def request_output_to_completion_response(
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=final_res_batch[0].kv_transfer_params,
)

def _create_completion_logprobs(
6 changes: 5 additions & 1 deletion vllm/outputs.py
@@ -11,7 +11,7 @@

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sampling_params import KVTransferParams, RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus)

@@ -103,6 +103,7 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
"""

def __init__(
@@ -120,6 +121,7 @@ def __init__(
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -133,11 +135,13 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params

def add(self, next_output: "RequestOutput") -> None:
"""Merge subsequent RequestOutput into this one"""

self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params

for next_completion in next_output.outputs:
for completion in self.outputs:
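Because RequestOutput now carries kv_transfer_params, a caller holding the final output can read back whatever the connector attached. A small sketch, assuming the offline LLM entrypoint threads the new attribute through unchanged; the model name is illustrative:

```python
# Sketch: reads the RequestOutput.kv_transfer_params attribute added here.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))

for output in outputs:
    # None unless a KV connector populated it (e.g. disaggregated serving).
    print(output.request_id, output.kv_transfer_params)
```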
20 changes: 20 additions & 0 deletions vllm/sampling_params.py
@@ -26,6 +26,21 @@ class SamplingType(IntEnum):
RANDOM_SEED = 2


# TODO(rob): make this per connector
class KVTransferParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
request_id: str
# TODO(rob): we can handle xPyD and direct KV block Xfer
# by passing these data.
# remote_instance_id: Optional[str] = None
# remote_block_ids: Optional[list[int]] = None
do_remote_decode: bool = False
do_remote_prefill: bool = False


# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
@@ -237,6 +252,9 @@ class SamplingParams(
bad_words: Optional[list[str]] = None
_bad_words_token_ids: Optional[list[list[int]]] = None

# Fields used for KVTransfer in disaggregated serving.
kv_transfer_params: Optional[KVTransferParams] = None

@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -268,6 +286,7 @@ def from_optional(
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
extra_args: Optional[dict[str, Any]] = None,
) -> "SamplingParams":
if logit_bias is not None:
@@ -310,6 +329,7 @@ def from_optional(
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
kv_transfer_params=kv_transfer_params,
extra_args=extra_args,
)

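For in-process callers, the same information can be attached directly to SamplingParams. A sketch of constructing the new KVTransferParams and passing it through from_optional, using only names introduced in this diff (the request id is illustrative):

```python
from vllm.sampling_params import KVTransferParams, SamplingParams

# Mark this request as one whose decode phase happens on a remote instance.
kv_params = KVTransferParams(
    request_id="req-123",
    do_remote_decode=True,
)

sampling_params = SamplingParams.from_optional(
    max_tokens=1,  # the prefill side only needs the first token
    kv_transfer_params=kv_params,
)
assert sampling_params.kv_transfer_params.do_remote_decode
```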
45 changes: 32 additions & 13 deletions vllm/v1/core/sched/scheduler.py
@@ -13,6 +13,7 @@
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import KVTransferParams
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -175,10 +176,6 @@ def schedule(self) -> SchedulerOutput:
# Check for new remote decode requests for P/D
new_KV_requests_to_send: list[Request] = []
if self.connector is not None:
# TODO: Receive request over ZMQ
# self.receiving_KV_req_ids.update(
# self.connector.receive_remote_decode_requests())

# Check if any P/D requests have finished sending or receiving
for req_id in list(self.waiting_to_send_KV_req_ids):
self.sending_KV_req_ids.add(req_id)
@@ -328,6 +325,18 @@ def schedule(self) -> SchedulerOutput:
skipped_waiting_requests.appendleft(request)
continue

# TODO(rob): we should do this after we allocate the blocks if
# we want to write directly into the BlockTable (like Dynamo).
# TODO(rob): this logic is incorrect if the req was preempted.
if request.do_remote_decode:
assert self.connector is not None
if not self.connector.is_request_done_receiving(req_id):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
self.receiving_KV_req_ids.add(request.request_id)
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue

# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request and (
@@ -719,7 +728,6 @@ def update_from_output(

# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
# TODO: What if we detect we're done here when doing P/D disagg?
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
@@ -742,6 +750,21 @@ def update_from_output(
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
# Stop request after the first token if doing a remote_decode.
# NOTE(rob): req is not freed (or preempted) in the EngineCore
# until the xfer is done to ensure we do not free the KV blocks.
kv_transfer_params = None
if request.do_remote_decode and not stopped:
stopped = True
request.status = RequestStatus.FINISHED_REMOTE_DECODE
self.sending_KV_req_ids.add(req_id)
assert self.connector is not None
# TODO(rob): do this on a per-Connector basis.
kv_transfer_params = KVTransferParams(
request_id=request.request_id,
do_remote_prefill=True,
)

# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
@@ -751,18 +774,14 @@ def update_from_output(
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))

else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors

if self.connector is not None and request.do_remote_decode:
stopped = True
self.waiting_to_send_KV_req_ids.add(req_id)
# TODO: Add ZMQ request
#self.connector.send_remote_decode_request(
# self.kv_cache_manager.req_to_blocks[req_id])

self.scheduled_req_ids.remove(req_id)
if not stopped:
new_running.append(request)
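Taken together, the scheduler changes implement the prefill side of the handoff: a do_remote_decode request is stopped after its first token, marked FINISHED_REMOTE_DECODE, kept in sending_KV_req_ids (so its KV blocks are not freed) until the connector finishes sending, and a KVTransferParams with do_remote_prefill=True is attached to its EngineCoreOutput. A rough sketch of how a proxy might drive two instances against this behavior; the URLs, model name, and orchestration are entirely assumed, and only the kv_transfer_params round-trip comes from this PR:

```python
# Hypothetical proxy flow, not part of this PR.
import requests

PREFILL_URL = "http://prefill-instance:8000/v1/completions"  # assumed
DECODE_URL = "http://decode-instance:8000/v1/completions"    # assumed

payload = {"model": "some-model", "prompt": "Hello", "max_tokens": 64}

# 1) Prefill instance: stops after the first token and returns connector params.
prefill_out = requests.post(
    PREFILL_URL,
    json={**payload,
          "kv_transfer_params": {"request_id": "req-123",
                                 "do_remote_decode": True}},
).json()

# 2) Decode instance: reuses the returned params (now do_remote_prefill=True)
#    so it can pull the prefilled KV cache instead of recomputing it.
decode_out = requests.post(
    DECODE_URL,
    json={**payload,
          "kv_transfer_params": prefill_out["kv_transfer_params"]},
).json()

print(decode_out["choices"][0]["text"])
```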
7 changes: 5 additions & 2 deletions vllm/v1/engine/__init__.py
@@ -10,13 +10,13 @@
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import KVTransferParams, SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "remote_decode")


class FinishReason(enum.IntEnum):
@@ -28,11 +28,13 @@ class FinishReason(enum.IntEnum):
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
remote_decode - request will be processed as a remote_decode

"""
STOP = 0
LENGTH = 1
ABORT = 2
REMOTE_DECODE = 3

def __str__(self):
return FINISH_REASON_STRINGS[self.value]
@@ -102,6 +104,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[KVTransferParams] = None

@property
def finished(self) -> bool:
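The new REMOTE_DECODE member keeps the int-to-string mapping positional, so its value has to line up with its slot in FINISH_REASON_STRINGS. A quick sanity check (not part of the diff):

```python
from vllm.v1.engine import FINISH_REASON_STRINGS, FinishReason

# REMOTE_DECODE = 3 indexes the fourth entry, "remote_decode".
assert str(FinishReason.REMOTE_DECODE) == "remote_decode"
assert len(FINISH_REASON_STRINGS) == len(FinishReason)
```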
12 changes: 9 additions & 3 deletions vllm/v1/engine/output_processor.py
@@ -6,7 +6,7 @@
from typing import Optional, Union

from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.sampling_params import KVTransferParams, RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -148,6 +148,7 @@ def make_request_output(
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: KVTransferParams,
) -> Optional[RequestOutput]:

finished = finish_reason is not None
@@ -169,13 +170,15 @@
if not outputs:
return None

return self._new_request_output(request_id, outputs, finished)
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)

def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: KVTransferParams,
) -> RequestOutput:

if self.output_kind == RequestOutputKind.DELTA:
@@ -191,6 +194,7 @@ def _new_request_output(
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
)

def _new_completion_output(
@@ -337,6 +341,7 @@ def process_outputs(
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params

req_state.is_prefilling = False

@@ -352,7 +357,8 @@

# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
13 changes: 11 additions & 2 deletions vllm/v1/request.py
@@ -61,8 +61,14 @@ def __init__(
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0

# P/D disagg related
self.do_remote_decode = False
# Disaggregated serving related
self.do_remote_decode = (
False if sampling_params.kv_transfer_params is None else
sampling_params.kv_transfer_params.do_remote_decode)
self.do_remote_prefill = (
False if sampling_params.kv_transfer_params is None else
sampling_params.kv_transfer_params.do_remote_prefill)
assert not (self.do_remote_decode and self.do_remote_prefill)

# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
@@ -153,6 +159,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto()
SENDING_KV = enum.auto()
PREEMPTED = enum.auto()
@@ -162,6 +169,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_REMOTE_DECODE = enum.auto()

@staticmethod
def is_finished(status: "RequestStatus") -> bool:
@@ -182,4 +190,5 @@ def get_finished_reason(
RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_REMOTE_DECODE: FinishReason.REMOTE_DECODE
}
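A short sketch exercising the new request states and their mapping to finish reasons (not part of the diff):

```python
from vllm.v1.engine import FinishReason
from vllm.v1.request import RequestStatus

# New terminal state for requests whose decode happens remotely.
assert RequestStatus.is_finished(RequestStatus.FINISHED_REMOTE_DECODE)
assert (RequestStatus.get_finished_reason(RequestStatus.FINISHED_REMOTE_DECODE)
        == FinishReason.REMOTE_DECODE)

# Non-terminal state while a decode instance waits for remote KVs to arrive.
assert not RequestStatus.is_finished(RequestStatus.WAITING_FOR_REMOTE_KVS)
```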