 from vllm.v1.core.kv_cache_utils import (
     FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
     estimate_max_model_len, generate_block_hash_extra_keys,
-    get_max_concurrency_for_kv_cache_config, hash_block_tokens,
-    hash_request_tokens, unify_kv_cache_configs)
+    get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
+    hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         SlidingWindowSpec)
@@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
                              sliding_window=sliding_window)


+def new_sliding_window_spec(block_size=16,
+                            num_kv_heads=2,
+                            head_size=64,
+                            dtype=torch.float32,
+                            use_mla=False,
+                            sliding_window=1):
+    return SlidingWindowSpec(block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=dtype,
+                             use_mla=use_mla,
+                             sliding_window=sliding_window)
+
+
 def test_none_hash(monkeypatch):
     import vllm.v1.core.kv_cache_utils

@@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
     same_kv_cache_config = [
         KVCacheConfig(
             num_blocks=10,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
                 KVCacheGroupSpec(["layer2"],
@@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
         ),
         KVCacheConfig(
             num_blocks=20,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
                 KVCacheGroupSpec(["layer2"],
@@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
     need_sort_kv_cache_config = [
         KVCacheConfig(
             num_blocks=10,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
                 KVCacheGroupSpec(["layer2"],
@@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
         ),
         KVCacheConfig(
             num_blocks=20,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer2"],
                                  new_kv_cache_spec(num_kv_heads=4)),
@@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
     diff_kv_cache_config = [
         KVCacheConfig(
             num_blocks=10,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
                 KVCacheGroupSpec(["layer2"],
@@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
         ),
         KVCacheConfig(
             num_blocks=20,
-            tensors={
-                "layer1": KVCacheTensor(100),
-                "layer2": KVCacheTensor(100),
-            },
+            kv_cache_tensors=[
+                KVCacheTensor(size=100, shared_by=["layer1"]),
+                KVCacheTensor(size=100, shared_by=["layer2"]),
+            ],
             kv_cache_groups=[
                 KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
                 KVCacheGroupSpec(["layer2"],
@@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():

     kv_cache_config_full_attention = KVCacheConfig(
         num_blocks=int(1024 * 1.5),
-        tensors={},
+        kv_cache_tensors=[],
         kv_cache_groups=[
             KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
                              full_attention_spec),
@@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():

     kv_cache_config_sliding_window = KVCacheConfig(
         num_blocks=129 * 3,
-        tensors={},
+        kv_cache_tensors=[],
         kv_cache_groups=[
             KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
                              sliding_window_spec),
@@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():

     kv_cache_config_hybrid_model = KVCacheConfig(
         num_blocks=(1024 + 129) * 3,
-        tensors={},
+        kv_cache_tensors=[],
         kv_cache_groups=[
             KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
                              full_attention_spec),
@@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
     block_size = 4
     config = KVCacheConfig(
         num_blocks=10,
-        tensors={
-            "layer1": KVCacheTensor(100),
-        },
+        kv_cache_tensors=[
+            KVCacheTensor(size=100, shared_by=["layer1"]),
+        ],
         kv_cache_groups=[
             KVCacheGroupSpec(["layer1"],
                              new_kv_cache_spec(block_size=block_size)),
@@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
         num_new_tokens=3,
         num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
     )
-    assert len(blocks.blocks) == 2  # ceil(5/4)=2 blocks
+    assert len(blocks.get_block_ids()[0]) == 2  # ceil(5/4)=2 blocks
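+    # get_block_ids() returns one list of block ids per KV cache group;
+    # this config has a single group, hence index [0].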

     # Test case 2: With precomputed blocks
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
@@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
         num_new_tokens=3,
         num_lookahead_tokens=2,
     )
-    assert len(blocks.blocks) == 2
+    assert len(blocks.get_block_ids()[0]) == 2

     # Test case 3: With precomputed blocks
     # required_blocks = ceil((3 + 4) / 4) = 2
@@ -724,4 +738,165 @@ def test_allocate_with_lookahead():
         num_new_tokens=3,
         num_lookahead_tokens=4,
     )
-    assert len(blocks.blocks) == 2
+    assert len(blocks.get_block_ids()[0]) == 2
+
+
+def test_get_kv_cache_config():
+    # Use a small max_model_len so that check_enough_kv_cache_memory passes.
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config)
+
+    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
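+    # = block_size (16) * num_kv_heads (2) * head_size (64)
+    #   * 4 bytes (float32) * 2 (K and V) = 16384 bytes per block per layer.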
+    # all layers are full attention -> single group
+    kv_cache_specs_full = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(),
+    }
+    kv_cache_config_full = get_kv_cache_config(
+        vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
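+    # The budget covers 32 blocks for each of the 2 layers, and every
+    # full-attention layer gets its own tensor, so num_blocks is 32.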
+    assert kv_cache_config_full == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_1"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
+        ])
+
+    # all layers are sliding window -> single group
+    kv_cache_specs_sliding = {
+        'layer_1': new_sliding_window_spec(),
+        'layer_2': new_sliding_window_spec(),
+    }
+    kv_cache_config_sliding = get_kv_cache_config(
+        vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
+    assert kv_cache_config_sliding == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_1"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec())
+        ])
+
+    # full + sliding, but disable_hybrid_kv_cache_manager
+    vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True
+    kv_cache_specs_hybrid = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_sliding_window_spec(),
+    }
+    kv_cache_config_hybrid = get_kv_cache_config(
+        vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
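+    # With the hybrid manager disabled, both layers are unified into a
+    # single full-attention group; the window is kept on the merged spec.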
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_1"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"],
+                             new_kv_cache_spec(sliding_window=1)),
+        ],
+    )
+    vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
+
+    # full + sliding, with hybrid_kv_cache_manager
+    kv_cache_specs_hybrid = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_sliding_window_spec(),
+    }
+    kv_cache_config_hybrid = get_kv_cache_config(
+        vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
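+    # The two layers now land in different groups and share one tensor, so
+    # the same memory budget yields twice the blocks (64 instead of 32).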
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=64,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 64,
+                          shared_by=["layer_1", "layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()),
+            KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()),
+        ],
+    )
+
+    # 2 full + 4 sliding, 2 layers per group
+    kv_cache_specs_hybrid = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(),
+        'layer_3': new_sliding_window_spec(),
+        'layer_4': new_sliding_window_spec(),
+        'layer_5': new_sliding_window_spec(),
+        'layer_6': new_sliding_window_spec(),
+    }
+    kv_cache_config_hybrid = get_kv_cache_config(
+        vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
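+    # The sliding layers are split into groups of the same size as the
+    # full-attention group (2 layers each); each tensor is shared by one
+    # layer from every group.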
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_1", "layer_3", "layer_5"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2", "layer_4", "layer_6"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
+            KVCacheGroupSpec(["layer_3", "layer_4"],
+                             new_sliding_window_spec()),
+            KVCacheGroupSpec(["layer_5", "layer_6"],
+                             new_sliding_window_spec()),
+        ],
+    )
+
+    # 3 full + 7 sliding, pad to 3 full + 9 sliding
+    kv_cache_specs_hybrid = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(),
+        'layer_3': new_kv_cache_spec(),
+        'layer_4': new_sliding_window_spec(),
+        'layer_5': new_sliding_window_spec(),
+        'layer_6': new_sliding_window_spec(),
+        'layer_7': new_sliding_window_spec(),
+        'layer_8': new_sliding_window_spec(),
+        'layer_9': new_sliding_window_spec(),
+        'layer_10': new_sliding_window_spec(),
+    }
+    kv_cache_config_hybrid = get_kv_cache_config(
+        vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
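+    # 7 sliding layers do not divide evenly into groups of 3, so they are
+    # padded to 9: two full groups plus a remainder group holding only
+    # "layer_10".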
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(
+                size=mem_per_block_per_layer * 32,
+                shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2", "layer_5", "layer_8"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_3", "layer_6", "layer_9"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
+                             new_kv_cache_spec()),
+            KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
+                             new_sliding_window_spec()),
+            KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
+                             new_sliding_window_spec()),
+            KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
+        ],
+    )
+
+    # different hidden size, unimplemented
+    kv_cache_specs_hybrid = {
+        'layer_1': new_kv_cache_spec(head_size=128),
+        'layer_2': new_kv_cache_spec(),
+    }
+    with pytest.raises(NotImplementedError):
+        get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
+                            mem_per_block_per_layer * 2 * 32)