@@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
373373 seq_group , num_lookahead_slots ) == AllocStatus .NEVER
374374
375375
376+ @pytest .mark .parametrize ("num_lookahead_slots" , [0 , 2 , 10 ])
377+ @pytest .mark .parametrize ("enable_caching" , [False , True ])
378+ def test_swap_in_infeasible (num_lookahead_slots , enable_caching ):
379+ """Verifies that swapping fails if there is not enough free blocks
380+ to account for unseen tokens and lookahead_slots.
381+ """
382+ block_size = 8
383+ num_cpu_blocks = 1
384+ num_gpu_blocks = 1
385+ block_manager = BlockSpaceManagerV2 (block_size ,
386+ num_cpu_blocks ,
387+ num_gpu_blocks ,
388+ watermark = 0 ,
389+ enable_caching = enable_caching )
390+ prompt_length = block_size - 3
391+ assert prompt_length > 0
392+ prompt , seq_group = create_dummy_prompt ("1" , prompt_length = prompt_length )
393+ prompt .status = SequenceStatus .WAITING
394+ block_manager .allocate (seq_group )
395+ # Emulate a forward pass by appending a single token.
396+ # The block manager then knows how many unprocessed
397+ # tokens will be written in the next forward pass.
398+ token_id = 0
399+ prompt .status = SequenceStatus .RUNNING
400+ prompt .append_token_id (token_id , {token_id : Logprob (0.0 )})
401+
402+ # Swap seq group from GPU -> CPU.
403+ assert block_manager .can_swap_out (seq_group )
404+ block_manager .swap_out (seq_group )
405+ prompt .status = SequenceStatus .SWAPPED
406+
407+ # Swap seq group from CPU -> GPU.
408+ # The number of unseen tokens is 1. If the number of existing
409+ # tokens plus the unseen ones and number of lookahead slots exceeds
410+ # the total number of available GPU blocks then the swap
411+ # should fail.
412+ num_unseen_tokens = 1
413+ if (num_lookahead_slots + num_unseen_tokens +
414+ prompt_length ) <= (block_size * num_gpu_blocks ):
415+ assert block_manager .can_swap_in (seq_group ,
416+ num_lookahead_slots ) == AllocStatus .OK
417+ else :
418+ assert block_manager .can_swap_in (
419+ seq_group , num_lookahead_slots ) == AllocStatus .NEVER
420+
421+
376422# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
377423
378424
@@ -400,7 +446,6 @@ def check_used(min_n, max_n=None):
400446 if max_n is None :
401447 max_n = min_n
402448 used = num_gpu_blocks - block_manager .get_num_free_gpu_blocks ()
403- #print("check", min_n, used, max_n)
404449 assert min_n <= used
405450 assert used <= max_n
406451
0 commit comments