Merged
Commits (changes shown from 31 of 33 commits)
All commits are by robertgshaw2-redhat.

891a5da  updated (Apr 19, 2025)
0fde021  updated (Apr 19, 2025)
2fb8eea  updated (Apr 19, 2025)
2f98b23  updated (Apr 19, 2025)
ad63d9a  updated (Apr 19, 2025)
54f1e64  updated (Apr 19, 2025)
a125762  updated (Apr 19, 2025)
dd3e299  updated (Apr 19, 2025)
dd8df0c  updated (Apr 19, 2025)
69c9d13  updated (Apr 19, 2025)
778389b  updated (Apr 19, 2025)
6b2fa35  stash (Apr 19, 2025)
a4ab996  updated (Apr 19, 2025)
ec3ed2e  updated (Apr 19, 2025)
5f73147  updated (Apr 19, 2025)
ab22fb8  update (Apr 19, 2025)
09ff580  pr readability (Apr 19, 2025)
ecfd8d6  updated (Apr 19, 2025)
51fac6c  updated (Apr 19, 2025)
9f6109d  updated (Apr 19, 2025)
82f001b  updated (Apr 19, 2025)
3a768c2  updated (Apr 19, 2025)
fd0e92d  updated (Apr 19, 2025)
77cf32b  update beam search (Apr 19, 2025)
319095f  updated (Apr 19, 2025)
59fab97  updated (Apr 19, 2025)
5a94164  updated (Apr 19, 2025)
8408dba  stash (Apr 19, 2025)
635c858  updated (Apr 20, 2025)
7d8b39b  prefill worker is 'working' (Apr 20, 2025)
9b4daf3  updated (Apr 20, 2025)
6bf63a3  updated (Apr 20, 2025)
c785762  updated (Apr 20, 2025)
15 changes: 7 additions & 8 deletions examples/online_serving/openai_completion_client.py
@@ -4,7 +4,7 @@

 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
-openai_api_base = "http://localhost:8000/v1"
+openai_api_base = "http://localhost:9000/v1"


 def main():
@@ -14,18 +14,17 @@ def main():
         base_url=openai_api_base,
     )

-    models = client.models.list()
-    model = models.data[0].id
+    # models = client.models.list()
+    # model = models.data[0].id

     # Completion API
     stream = False
     completion = client.completions.create(
-        model=model,
-        prompt="A robot may not injure a human being",
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        prompt=
+        "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means",  # noqa: E501
         echo=False,
-        n=2,
-        stream=stream,
-        logprobs=3)
+        stream=stream)

     print("-" * 50)
     print("Completion results:")
39 changes: 39 additions & 0 deletions examples/online_serving/openai_completion_client2.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8100/v1"
+
+
+def main():
+    client = OpenAI(
+        # defaults to os.environ.get("OPENAI_API_KEY")
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    # models = client.models.list()
+    # model = models.data[0].id
+
+    # Completion API
+    stream = False
+    completion = client.completions.create(
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        prompt="The quick brown job jumped",
+        echo=False,
+        stream=stream)
+
+    print("-" * 50)
+    print("Completion results:")
+    if stream:
+        for c in completion:
+            print(c)
+    else:
+        print(completion)
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
6 changes: 2 additions & 4 deletions examples/other/LMCache/disagg-example.sh
@@ -25,10 +25,9 @@ if [[ $1 == "prefiller" ]]; then
     LMCACHE_USE_EXPERIMENTAL=True \
     VLLM_ENABLE_V1_MULTIPROCESSING=1 \
     VLLM_WORKER_MULTIPROC_METHOD=spawn \
-    CUDA_VISIBLE_DEVICES=0 \
+    CUDA_VISIBLE_DEVICES=6 \
     vllm serve $MODEL \
         --port 8100 \
-        --disable-log-requests \
         --enforce-eager \
         --kv-transfer-config \
         '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'
@@ -46,10 +45,9 @@ elif [[ $1 == "decoder" ]]; then
     LMCACHE_USE_EXPERIMENTAL=True \
     VLLM_ENABLE_V1_MULTIPROCESSING=1 \
     VLLM_WORKER_MULTIPROC_METHOD=spawn \
-    CUDA_VISIBLE_DEVICES=1 \
+    CUDA_VISIBLE_DEVICES=7 \
     vllm serve $MODEL \
         --port 8200 \
-        --disable-log-requests \
        --enforce-eager \
         --kv-transfer-config \
         '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'
7 changes: 4 additions & 3 deletions examples/other/LMCache/disagg_proxy_server.py
@@ -88,13 +88,13 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
     Send a request to a service using a persistent client.
     """
     req_data = req_data.copy()
-    req_data['max_tokens'] = 1
     if 'max_completion_tokens' in req_data:
         req_data['max_completion_tokens'] = 1
+    req_data['do_remote_decode'] = True
+    req_data["stream"] = False
+
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
     response = await client.post(endpoint, json=req_data, headers=headers)
-    response.raise_for_status()
-
     return response


@@ -104,6 +104,7 @@ async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
     Asynchronously stream the response from a service using a persistent client.
     """
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+    # req_data['do_remote_prefill'] = True
     async with client.stream("POST", endpoint, json=req_data,
                              headers=headers) as response:
         response.raise_for_status()
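For reference (not part of the diff): a minimal sketch of the request rewrite that `send_request_to_service` now applies before forwarding to the prefill worker, isolated as a pure function so the behavior is easy to see. The helper name and sample payload are illustrative only.

```python
# Illustrative only: mirrors the rewrite in send_request_to_service above.
# The prefiller is asked for a single token, flagged so the KV connector
# knows the decode phase runs on another instance, and never streamed.
def to_prefill_request(req_data: dict) -> dict:
    req_data = req_data.copy()
    if 'max_completion_tokens' in req_data:
        req_data['max_completion_tokens'] = 1
    req_data['do_remote_decode'] = True
    req_data["stream"] = False
    return req_data


sample = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "The quick brown fox jumped",
    "max_completion_tokens": 128,
    "stream": True,
}
print(to_prefill_request(sample))
# -> max_completion_tokens=1, do_remote_decode=True, stream=False
```

The design point is that the prefill instance only produces the KV cache plus one token; the decode instance (per the commented-out `do_remote_prefill` hint in the second hunk) is the one that performs the real generation.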
22 changes: 20 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -17,7 +17,8 @@
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
-                                  RequestOutputKind, SamplingParams)
+                                  KVTransferParams, RequestOutputKind,
+                                  SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid, resolve_obj_by_qualname

@@ -807,6 +808,14 @@ class CompletionRequest(OpenAIBaseModel):
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))

+    do_remote_decode: bool = Field(
+        default=False,
+        description="KVTransfer parameters used for disaggregated serving.")
+
+    do_remote_prefill: bool = Field(
+        default=False,
+        description="KVTransfer parameters used for disaggregated serving.")
+
     # doc: end-completion-extra-params

     # Default sampling parameters for completion requests
@@ -904,6 +913,11 @@ def to_sampling_params(
             whitespace_pattern=self.guided_whitespace_pattern,
         )

+        kv_transfer_params = KVTransferParams.from_optional(
+            do_remote_decode=self.do_remote_decode,
+            do_remote_prefill=self.do_remote_prefill,
+        )
+
         return SamplingParams.from_optional(
             n=self.n,
             best_of=self.best_of,
@@ -932,7 +946,9 @@ def to_sampling_params(
             else RequestOutputKind.FINAL_ONLY,
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
-            allowed_token_ids=self.allowed_token_ids)
+            allowed_token_ids=self.allowed_token_ids,
+            kv_transfer_params=kv_transfer_params,
+        )

     @model_validator(mode="before")
     @classmethod
@@ -1182,6 +1198,8 @@ class CompletionResponse(OpenAIBaseModel):
     model: str
     choices: list[CompletionResponseChoice]
     usage: UsageInfo
+    # TODO: make this into a pydanic object
+    do_remote_prefill: Optional[bool] = Field(default=None)


 class CompletionResponseStreamChoice(OpenAIBaseModel):
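For context (not part of the diff): a minimal sketch of how a client could exercise the new `CompletionRequest` fields through the OpenAI SDK, which forwards unrecognized fields via `extra_body`. The port and model name are assumptions borrowed from the examples earlier in this PR, and a server built from this branch with a KV connector configured is assumed to be running.

```python
from openai import OpenAI

# Assumes a vLLM server from this branch is listening on localhost:8100.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8100/v1")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="The quick brown fox jumped",
    max_tokens=1,
    # extra_body fields are merged into the JSON request body, so this
    # populates CompletionRequest.do_remote_decode on the server side.
    extra_body={"do_remote_decode": True},
)
print(completion)
```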
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -476,12 +476,16 @@ def request_output_to_completion_response(

         request_metadata.final_usage_info = usage

+        do_remote_prefill = (
+            final_res_batch[0].kv_transfer_params
+            and final_res_batch[0].kv_transfer_params.do_remote_prefill)
         return CompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
+            do_remote_prefill=do_remote_prefill,
         )

     def _create_completion_logprobs(
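For context (not part of the diff): the new `do_remote_prefill` field sits at the top level of the completion response, so a proxy can read it straight from the raw JSON rather than through the typed OpenAI client. A minimal sketch; the URL, model, and prompt are placeholders taken from the examples above.

```python
import requests

resp = requests.post(
    "http://localhost:8100/v1/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "The quick brown fox jumped",
        "max_tokens": 1,
        "do_remote_decode": True,
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# Added by this PR; expected to stay None/False unless the engine attached
# kv_transfer_params with do_remote_prefill set on the first output.
print(body.get("do_remote_prefill"))
```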
6 changes: 5 additions & 1 deletion vllm/outputs.py
@@ -11,7 +11,7 @@

 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind
+from vllm.sampling_params import KVTransferParams, RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                            SequenceGroup, SequenceGroupBase, SequenceStatus)

@@ -103,6 +103,7 @@ class RequestOutput:
         encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                   None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
+        kv_transfer_params: The params for remote K/V transfer.
     """

     def __init__(
@@ -120,6 +121,7 @@ def __init__(
         num_cached_tokens: Optional[int] = None,
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
+        kv_transfer_params: Optional[KVTransferParams] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -133,11 +135,13 @@ def __init__(
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        self.kv_transfer_params = kv_transfer_params

     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""

         self.finished |= next_output.finished
+        self.kv_transfer_params = next_output.kv_transfer_params

         for next_completion in next_output.outputs:
             for completion in self.outputs:
32 changes: 32 additions & 0 deletions vllm/sampling_params.py
@@ -26,6 +26,33 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2


+# TODO(rob): make this per connector
+class KVTransferParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):
+    # TODO(rob): we can handle xPyD and direct KV block Xfer
+    # remote_instance_id: Optional[str] = None
+    # remote_block_ids: Optional[list[int]] = None
+    do_remote_decode: bool = False
+    do_remote_prefill: bool = False
+
+    @staticmethod
+    def from_optional(do_remote_decode: bool,
+                      do_remote_prefill: bool) -> Optional["KVTransferParams"]:
+        if do_remote_decode and do_remote_prefill:
+            raise ValueError(
+                "Cannot do both remote prefill and remote decode.")
+        elif do_remote_decode or do_remote_prefill:
+            return KVTransferParams(
+                do_remote_decode=do_remote_decode,
+                do_remote_prefill=do_remote_prefill,
+            )
+        else:
+            return None
+
+
 # maybe make msgspec?
 @dataclass
 class GuidedDecodingParams:
@@ -237,6 +264,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None

+    # Fields used for KVTransfer in disaggregated serving.
+    kv_transfer_params: Optional[KVTransferParams] = None
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -268,6 +298,7 @@ def from_optional(
         guided_decoding: Optional[GuidedDecodingParams] = None,
         logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
         allowed_token_ids: Optional[list[int]] = None,
+        kv_transfer_params: Optional[KVTransferParams] = None,
         extra_args: Optional[dict[str, Any]] = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
@@ -310,6 +341,7 @@ def from_optional(
             guided_decoding=guided_decoding,
             logit_bias=logit_bias,
             allowed_token_ids=allowed_token_ids,
+            kv_transfer_params=kv_transfer_params,
             extra_args=extra_args,
         )

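For context (not part of the diff): a minimal sketch of the `KVTransferParams.from_optional` semantics, assuming the corrected validation above (requesting remote prefill and remote decode together is rejected). It only runs against a build of this branch, since upstream vLLM does not export `KVTransferParams`.

```python
from vllm.sampling_params import KVTransferParams, SamplingParams

# Plain requests keep kv_transfer_params=None, so the common path is untouched.
assert KVTransferParams.from_optional(do_remote_decode=False,
                                      do_remote_prefill=False) is None

# A prefill-side request marks its sampling params for remote decode.
params = SamplingParams.from_optional(
    kv_transfer_params=KVTransferParams.from_optional(
        do_remote_decode=True, do_remote_prefill=False))
print(params.kv_transfer_params)

# Asking for both roles at once is rejected.
try:
    KVTransferParams.from_optional(do_remote_decode=True,
                                   do_remote_prefill=True)
except ValueError as err:
    print(err)  # Cannot do both remote prefill and remote decode.
```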
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/output.py
@@ -128,4 +128,4 @@ class SchedulerOutput:
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
     sending_KV_req_ids: set[str] = field(default_factory=set)
     receiving_KV_req_ids: set[str] = field(default_factory=set)
-    new_KV_requests_to_send: list[NewRequestData] = field(default_factory=list)
+    new_KV_req_ids_to_send: list[str] = field(default_factory=list)