Skip to content
16 changes: 7 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,12 @@ def get_num_new_matched_tokens(self, request: "Request",
assert num_computed_tokens % self.block_size == 0

if request.do_remote_prefill:
# NOTE: subtract 1 since we compute the last token
# here so that we can sample the first token.
num_prompt_tokens = len(request.prompt_token_ids) - 1
Comment on lines -175 to -177
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, glad this is going away since I never understood it


# Round down to a full block shape.
num_external_blocks = num_prompt_tokens // self.block_size
num_external_blocks = len(
request.prompt_token_ids) // self.block_size
rounded_num_prompt_tokens = num_external_blocks * self.block_size
return max(rounded_num_prompt_tokens - num_computed_tokens, 0)
else:
return 0

return 0

def update_state_after_alloc(self, request: "Request",
block_ids: list[int],
Expand Down Expand Up @@ -310,7 +306,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

# For debugging: the SENDER writes some data into the KV caches
# so that the RECVER can check it.
n_blocks_to_send = 4096
n_blocks_to_send = min(4096, kv_caches[first_layer_name].shape[1])
debug_xfer_gb = 2.0 * n_blocks_to_send * self.block_len / 1e9
print(f"gb {debug_xfer_gb} -- block_len {self.block_len}")
if NIXL_ROLE == "SENDER":
Expand Down Expand Up @@ -581,6 +577,8 @@ def _read_blocks(
if len(remote_block_ids) < len(local_block_ids):
local_block_ids = local_block_ids[:len(remote_block_ids)]
assert len(local_block_ids) == len(remote_block_ids)

# NOTE(rob): could returning early here leave the remote blocks unfreed?
if len(local_block_ids) == 0:
return

Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def schedule(self) -> SchedulerOutput:
[b.block_id for b in computed_blocks + new_blocks],
num_external_tokens,
)
# We should only trigger a KV transfer once per request.
request.do_remote_prefill = False
continue

# Number of tokens to be scheduled.
Expand Down