examples/online_serving/openai_completion_client.py (4 additions & 2 deletions)
@@ -6,6 +6,9 @@
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8192/v1"

PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501
PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, " # noqa: E501


def main():
client = OpenAI(
@@ -21,8 +24,7 @@ def main():
stream = True
completion = client.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
prompt=
"The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, which means that there is a lot of opportunity to work with community and customers on key infrastructure projects. This means", # noqa: E501
prompt=PROMPT,
echo=False,
stream=stream)

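After this change the example reads roughly as follows. This is a condensed sketch: the client setup, PROMPT value, model name, and port come from the hunks above, while the streamed-chunk printing and the __main__ guard are assumptions about the unshown tail of the file.

from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8192/v1"

PROMPT = "The absolute best part about working for Red Hat is that we get to work on open source software. Red Hat is a leader in many key open source infrastructure technologies like Linux, Kubernetes, and recently vLLM, "  # noqa: E501


def main():
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    stream = True
    completion = client.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt=PROMPT,
        echo=False,
        stream=stream)

    # Assumed tail: print the streamed chunks as they arrive.
    if stream:
        for chunk in completion:
            print(chunk.choices[0].text, end="")
        print()
    else:
        print(completion.choices[0].text)


if __name__ == "__main__":
    main()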
vllm/v1/core/block_pool.py (0 additions & 2 deletions)
@@ -117,8 +117,6 @@ def cache_full_blocks(
prev_block_hash_value = prev_block.block_hash.hash_value

for i, blk in enumerate(new_full_blocks):
if blk.block_hash is not None:
continue
assert blk.block_hash is None

if i < len(new_block_hashes):
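The deleted guard silently skipped blocks that already carried a hash; the surviving assert turns that case into a hard invariant instead. A minimal sketch of the difference, using a simplified stand-in Block rather than vLLM's actual class:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Block:
    block_hash: Optional[int] = None


def cache_full_blocks_sketch(new_full_blocks, new_block_hashes):
    for i, blk in enumerate(new_full_blocks):
        # Previously `if blk.block_hash is not None: continue` quietly
        # tolerated already-hashed blocks; now the caller contract is that
        # every block passed here is still unhashed, so a violation fails
        # loudly instead of being masked.
        assert blk.block_hash is None
        if i < len(new_block_hashes):
            blk.block_hash = new_block_hashes[i]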
vllm/v1/core/kv_cache_manager.py (13 additions & 6 deletions)
@@ -279,12 +279,19 @@ def allocate_slots(
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)

if not self.enable_caching or skip_cache_blocks:
# If self.enable_caching, this is true since we can only
# get to this codepath when we have never been scheduled.
assert request.request_id not in self.num_cached_block
if not self.enable_caching:
return new_blocks

if skip_cache_blocks:
# NOTE(rob): this assert is valid because we only call
# skip_cache_blocks=True the first time a request is in
# WAITING during a P/D setup.
assert request.request_id not in self.num_cached_block
# NOTE(rob): this is necessary so we don't double
# cache a block after it has finished recving.
self.num_cached_block[request.request_id] = len(
new_computed_blocks)

self.cache_blocks(
request=request,
num_tokens=num_tokens,
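To make the double-caching concern in the NOTE above concrete, here is a rough numeric trace of the bookkeeping. The variable names and the exact accounting inside cache_blocks are assumptions for illustration, not the real vLLM code:

# Illustrative trace only, not vLLM's actual bookkeeping.
num_cached_block = {}

# First WAITING pass of a request in a P/D (prefill/decode disaggregation)
# setup: its KV blocks are still being received from the prefill side, so
# nothing is hashed yet, but the received blocks are recorded up front.
num_received_blocks = 4            # len(new_computed_blocks)
num_cached_block["req-0"] = num_received_blocks

# Once the transfer finishes and the request is scheduled again, caching
# resumes from the recorded count, so the 4 received blocks are not hashed
# a second time; only genuinely new full blocks (here, 2) get cached.
num_full_blocks = 6
num_new_blocks_to_cache = num_full_blocks - num_cached_block["req-0"]
assert num_new_blocks_to_cache == 2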
@@ -313,8 +320,8 @@ def cache_blocks(
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
num_full_blocks_after_append = (
num_computed_tokens + num_tokens - len(request.spec_token_ids)) // self.block_size
num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
request.spec_token_ids)) // self.block_size

self.block_pool.cache_full_blocks(
request=request,
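For reference, a quick worked example of the full-block arithmetic above, with illustrative numbers only: speculated tokens are subtracted before dividing by the block size, so a block whose tail is still speculative is never cached.

# Illustrative numbers only.
block_size = 16
num_computed_tokens = 32
num_tokens = 24        # newly appended tokens
num_spec_tokens = 4    # speculated tokens may still be rejected

num_full_blocks_after_append = (
    num_computed_tokens + num_tokens - num_spec_tokens) // block_size
assert num_full_blocks_after_append == 3   # (32 + 24 - 4) // 16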