Skip to content

Commit 7a002d7

Browse files
authored
Handle new_block_ids is None. (#533)
Signed-off-by: Qiliang Cui <[email protected]>
1 parent bdefde4 commit 7a002d7

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tpu_commons/runner/jax/tpu_jax_runner.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,11 +1572,13 @@ def _update_states(self, scheduler_output: "VllmSchedulerOutput") -> bool:
15721572
# Update the cached states.
15731573
req_state.num_computed_tokens = num_computed_tokens
15741574
if not resumed_from_preemption:
1575-
# Append the new blocks to the existing block IDs.
1576-
for block_ids, new_ids in zip(req_state.block_ids,
1577-
new_block_ids):
1578-
block_ids.extend(new_ids)
1575+
if new_block_ids is not None:
1576+
# Append the new blocks to the existing block IDs.
1577+
for block_ids, new_ids in zip(req_state.block_ids,
1578+
new_block_ids):
1579+
block_ids.extend(new_ids)
15791580
else:
1581+
assert new_block_ids is not None
15801582
# The request is resumed from preemption.
15811583
# Replace the existing block IDs with the new ones.
15821584
req_state.block_ids = new_block_ids
@@ -1592,7 +1594,9 @@ def _update_states(self, scheduler_output: "VllmSchedulerOutput") -> bool:
15921594
# Update the persistent batch.
15931595
self.input_batch.num_computed_tokens_cpu[req_index] = (
15941596
num_computed_tokens)
1595-
self.input_batch.block_table.append_row(new_block_ids, req_index)
1597+
if new_block_ids is not None:
1598+
self.input_batch.block_table.append_row(
1599+
new_block_ids, req_index)
15961600

15971601
# Add spec_token_ids to token_ids_cpu.
15981602
spec_token_ids = (

0 commit comments

Comments
 (0)