@@ -26,11 +26,10 @@ def test_first_block_has_correct_content_hash(seed: int, block_size: int,
2626 token_ids = list (range (num_to_fill ))
2727 mock_allocator = MagicMock (spec = PrefixCachingBlockAllocator )
2828
29- block_with_prev = PrefixCachingBlock (
30- prev_block = None ,
31- token_ids = token_ids ,
32- block_size = block_size ,
33- prefix_caching_allocator = mock_allocator )
29+ block_with_prev = PrefixCachingBlock (prev_block = None ,
30+ token_ids = token_ids ,
31+ block_size = block_size ,
32+ allocator = mock_allocator )
3433
3534 if is_curr_block_full :
3635 # Expect hash since block is full.
@@ -71,7 +70,7 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
7170 prev_block = previous_block ,
7271 token_ids = token_ids ,
7372 block_size = block_size ,
74- prefix_caching_allocator = mock_allocator ,
73+ allocator = mock_allocator ,
7574 )
7675
7776 if is_curr_block_full and prev_block_has_hash :
@@ -138,7 +137,7 @@ def create_chain(block_size: int,
138137 prev_block = prev_block ,
139138 token_ids = [],
140139 block_size = block_size ,
141- prefix_caching_allocator = allocator ,
140+ allocator = allocator ,
142141 )
143142
144143 tokens_to_append = token_ids [block_number *
@@ -159,11 +158,11 @@ def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
159158 prev_block : Optional [Block ],
160159 token_ids : List [int ]):
161160 if allocate_type == "immutable" :
162- allocate_block = lambda : allocator .allocate_immutable (
161+ allocate_block = lambda : allocator .allocate_immutable_block (
163162 prev_block = prev_block , token_ids = token_ids )
164163 elif allocate_type == "mutable" :
165- allocate_block = lambda : allocator .allocate_mutable ( prev_block =
166- prev_block )
164+ allocate_block = lambda : allocator .allocate_mutable_block (
165+ prev_block = prev_block )
167166 else :
168167 raise ValueError ()
169168
@@ -233,12 +232,13 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int,
233232
234233 # Expect allocation with unseen hash to fail.
235234 with pytest .raises (BlockAllocator .NoFreeBlocksError ):
236- allocator .allocate_immutable (prev_block = chain [- 1 ],
237- token_ids = list (range (block_size )))
235+ allocator .allocate_immutable_block (prev_block = chain [- 1 ],
236+ token_ids = list (
237+ range (block_size )))
238238
239239 # Expect mutable allocation to fail.
240240 with pytest .raises (BlockAllocator .NoFreeBlocksError ):
241- allocator .allocate_mutable (prev_block = chain [- 1 ])
241+ allocator .allocate_mutable_block (prev_block = chain [- 1 ])
242242
243243 # Expect allocation of exact same chain to pass.
244244 second_chain = TestPrefixCachingBlockAllocator .create_immutable_chain (
@@ -270,7 +270,7 @@ def test_free_prevents_oom(num_blocks: int, block_size: int):
270270
271271 # Expect mutable allocation to fail.
272272 with pytest .raises (BlockAllocator .NoFreeBlocksError ):
273- allocator .allocate_mutable (prev_block = None )
273+ allocator .allocate_mutable_block (prev_block = None )
274274
275275 block_to_free = chain [- 1 ]
276276
@@ -280,11 +280,11 @@ def test_free_prevents_oom(num_blocks: int, block_size: int):
280280 allocator .free (block_to_free )
281281 assert block_to_free .block_id is None , i
282282
283- new_block = allocator .allocate_mutable (prev_block = None )
283+ new_block = allocator .allocate_mutable_block (prev_block = None )
284284 assert new_block .block_id == block_id , i
285285
286286 with pytest .raises (BlockAllocator .NoFreeBlocksError ):
287- allocator .allocate_mutable (prev_block = None )
287+ allocator .allocate_mutable_block (prev_block = None )
288288
289289 block_to_free = new_block
290290
@@ -376,17 +376,13 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
376376
377377 # Create token ids that will exhaust all blocks.
378378 token_ids = list (range (num_blocks_to_consume * block_size ))
379- blocks = list (range (num_blocks_to_consume ))
380379
381380 first_chain = TestPrefixCachingBlockAllocator .create_immutable_chain (
382381 block_size = block_size ,
383382 token_ids = token_ids ,
384383 allocator = allocator ,
385384 )
386385
387- # mark all blocks in first chain as computed
388- allocator .mark_blocks_as_computed (blocks )
389-
390386 # After zero_point, second_chain's token_ids would be set -1, which
391387 # make it different from here comparing with first_chain
392388 zero_point = random .randint (1 , len (token_ids ) - 1 )
@@ -424,15 +420,16 @@ def test_alloc_promotion(num_blocks: int, block_size: int, seed: int):
424420 block_size = block_size )
425421 token_ids = list (range (block_size ))
426422
427- block = allocator .allocate_immutable (prev_block = None ,
428- token_ids = token_ids )
423+ block = allocator .allocate_immutable_block (prev_block = None ,
424+ token_ids = token_ids )
429425
430426 assert allocator ._refcounter .get (block .block_id ) == 1
431- m = allocator .allocate_mutable (prev_block = None )
427+ m = allocator .allocate_mutable_block (prev_block = None )
432428
433429 block_id = m .block_id
434430 for i in range (block_size ):
435431 m .append_token_ids ([i ])
432+
436433 # After block get promoted to immutable from mutable, if there is
437434 # already same content hash block, then it shall be released into
438435 # hashless_allocator
@@ -452,48 +449,79 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int):
452449
453450 all_blocks_list = [i for i in range (num_blocks )]
454451 zero_ref = {i : 0 for i in range (num_blocks )}
452+ one_ref = {i : 1 for i in range (num_blocks )}
455453 allocator = PrefixCachingBlockAllocator (num_blocks = num_blocks ,
456454 block_size = block_size )
457455 token_ids = list (range (num_blocks * block_size ))
458456
459- # now we have num_blocks free blocks in hashless allocator
460- # with internal tracking list _blocks _cached_blocks and evictor
461- # empty and block's ref shall be 0
457+ # Verify initial/pre-alloc state
458+
459+ # Ensure all blocks are free inside hashless allocator
462460 assert list (allocator ._hashless_allocator ._free_block_indices
463461 ) == all_blocks_list
464- assert len (allocator ._blocks .keys ()) == 0
462+ # Ensure no tracked blocks
463+ assert len (allocator ._block_tracker .keys ()) == num_blocks
464+ for block_id in range (num_blocks ):
465+ assert not allocator ._block_tracker [block_id ].active
466+ # Ensure no cached blocks
465467 assert len (allocator ._cached_blocks .values ()) == 0
468+ # Ensure no evicted blocks
466469 assert len (allocator .evictor .free_table .keys ()) == 0
470+ # Ensure 0s ref counts for all blocks
467471 assert allocator ._refcounter ._refcounts == zero_ref
468472
469473 # Allocate immutable chains with only one block residuled in
470474 new_block = []
471475 for i in range (num_blocks ):
472- block = allocator .allocate_immutable (
476+ block = allocator .allocate_immutable_block (
473477 prev_block = None ,
474478 token_ids = token_ids [block_size * i :block_size * (i + 1 )])
475479 new_block .append (block )
476480
481+ # Verify post-alloc state
482+
483+ # Ensure no blocks are free inside hashless allocator
484+ assert (len (allocator ._hashless_allocator ._free_block_indices ) == 0 )
485+ # Ensure all blocks are tracked
486+ assert len (allocator ._block_tracker .keys ()) == num_blocks
487+ for block_id in range (num_blocks ):
488+ assert allocator ._block_tracker [block_id ].active
489+ # Ensure all blocks are cached (all promoted)
490+ assert len (allocator ._cached_blocks .values ()) == num_blocks
491+ # Ensure no evicted blocks
492+ assert len (allocator .evictor .free_table .keys ()) == 0
493+ # Ensure 1s ref counts for all blocks
494+ assert allocator ._refcounter ._refcounts == one_ref
495+
477496 # Free all blocks, and now all blocks shall be in the evictor
478- # there shall be no tracking data left in _blocks
497+ # there shall be no tracking data left in _block_tracker
479498 # all blocks shall be tracked in _cached_blocks
480499 # all blocks' ref shall be zero
481500 for block in new_block :
482501 allocator .free (block )
483502
484- assert len (allocator ._blocks .keys ()) == 0
503+ # Verify post-free state
504+
505+ # Ensure no tracked blocks
506+ assert len (allocator ._block_tracker .keys ()) == num_blocks
507+ for block_id in range (num_blocks ):
508+ assert not allocator ._block_tracker [block_id ].active
509+ # Ensure no blocks in hashless allocator (all promoted)
485510 assert len (allocator ._hashless_allocator ._free_block_indices ) == 0
511+ # Ensure all blocks are cached
486512 assert list (allocator ._cached_blocks .values ()) == all_blocks_list
513+ # Ensure all blocks are inside the evictor
487514 assert list (allocator .evictor .free_table .keys ()) == all_blocks_list
515+ # Ensure 0s refcounts
488516 assert allocator ._refcounter ._refcounts == zero_ref
489517
490518 # Allocate a mutable block, and the first block shall be evicted
491519 # and set its content hash into None, ref to 1
492- mutable = allocator .allocate_mutable (prev_block = None )
520+ mutable = allocator .allocate_mutable_block (prev_block = None )
493521
494522 assert mutable .block_id == 0
495523 assert mutable .content_hash is None
496- assert 0 in allocator ._blocks
524+ assert allocator ._block_tracker [ 0 ]. active
497525 assert allocator ._refcounter .get (0 ) == 1
498526 assert 0 not in allocator ._cached_blocks
499527 assert 0 not in allocator .evictor
@@ -502,27 +530,27 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int):
502530 # hashless allocator
503531 allocator .free (mutable )
504532
505- assert len ( allocator ._blocks . keys ()) == 0
533+ assert not allocator ._block_tracker [ 0 ]. active
506534 assert allocator ._refcounter ._refcounts == zero_ref
507535 assert 0 not in allocator ._cached_blocks
508536 assert 0 not in allocator .evictor
509537 assert 0 in allocator ._hashless_allocator ._free_block_indices
510538
511- # when allocate immutable with first block_size tokens, we
539+ # When allocate immutable with first block_size tokens, we
512540 # shall get free block from hashless allocator, thus no block left
513541 # in hashless
514- block = allocator .allocate_immutable ( prev_block = None ,
515- token_ids = token_ids [:block_size ])
542+ block = allocator .allocate_immutable_block (
543+ prev_block = None , token_ids = token_ids [:block_size ])
516544
517545 assert block .block_id == 0
518546 assert len (allocator ._hashless_allocator ._free_block_indices ) == 0
519- assert 0 in allocator ._blocks
547+ assert allocator ._block_tracker [ 0 ]. active
520548 assert 0 in allocator ._cached_blocks .values ()
521549 assert allocator ._refcounter .get (0 ) == 1
522550 assert 0 not in allocator .evictor
523551
524552 # allocate mutable block again, it shall be popped from evictor
525- mutable = allocator .allocate_mutable (prev_block = None )
553+ mutable = allocator .allocate_mutable_block (prev_block = None )
526554 assert len (allocator ._hashless_allocator ._free_block_indices ) == 0
527555 assert mutable .block_id not in allocator .evictor .free_table
528556 assert allocator ._refcounter .get (mutable .block_id ) == 1
@@ -619,7 +647,7 @@ def create_immutable_chain(
619647 block_token_ids = token_ids [block_number *
620648 block_size :(block_number + 1 ) *
621649 block_size ]
622- prev_block = allocator .allocate_immutable (
650+ prev_block = allocator .allocate_immutable_block (
623651 prev_block = prev_block , token_ids = block_token_ids )
624652 blocks .append (prev_block )
625653
0 commit comments