diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py
index 1616b7c78517..97ce313b7815 100644
--- a/cacheflow/master/block_manager.py
+++ b/cacheflow/master/block_manager.py
@@ -15,9 +15,9 @@ def __init__(
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16]:
+        if block_size not in [8, 16, 32]:
             raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be either 8 or 16.')
+                             '. The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index ff8b549eb4c4..0c31cc898652 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
+    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
     parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu
index fc3a2717436e..fb7ada38d935 100644
--- a/csrc/attention_kernels.cu
+++ b/csrc/attention_kernels.cu
@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
         block_tables,
         context_lens,
         max_context_len);
+    } else if (block_size == 32) {
+      single_query_cached_kv_attention_launcher<uint16_t, 32>(
+        out,
+        query,
+        key_cache,
+        value_cache,
+        scale,
+        block_tables,
+        context_lens,
+        max_context_len);
     } else {
       assert(false);
     }
@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
         block_tables,
         context_lens,
         max_context_len);
+    } else if (block_size == 32) {
+      single_query_cached_kv_attention_launcher<float, 32>(
+        out,
+        query,
+        key_cache,
+        value_cache,
+        scale,
+        block_tables,
+        context_lens,
+        max_context_len);
     } else {
       assert(false);
     }
@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
         block_tables,
         context_lens,
         max_context_len);
+    } else if (block_size == 32) {
+      multi_query_cached_kv_attention_launcher<uint16_t, 32>(
+        cu_query_lens,
+        seq_prompt_mapping,
+        out,
+        query,
+        key_cache,
+        value_cache,
+        scale,
+        block_tables,
+        context_lens,
+        max_context_len);
     } else {
       assert(false);
     }
@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
         block_tables,
         context_lens,
         max_context_len);
+    } else if (block_size == 32) {
+      multi_query_cached_kv_attention_launcher<float, 32>(
+        cu_query_lens,
+        seq_prompt_mapping,
+        out,
+        query,
+        key_cache,
+        value_cache,
+        scale,
+        block_tables,
+        context_lens,
+        max_context_len);
     } else {
       assert(false);
     }
diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py
index a66f2c3daca7..7c2f350f1140 100644
--- a/tests/kernels/attention.py
+++ b/tests/kernels/attention.py
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # note that the test is also more likely to fail due to the much
     # larger amount of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
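
All four C++ hunks follow the same pattern: the runtime block_size value is mapped onto a compile-time template parameter of the launcher, so each supported size gets its own kernel instantiation, and any other value falls through to assert(false). Below is a minimal, self-contained sketch of that dispatch pattern; the launch and dispatch functions are hypothetical stand-ins, not the actual cacheflow launchers.

#include <cassert>
#include <cstdio>

// Each supported block size becomes a separate template instantiation,
// mirroring the if/else-if ladders in attention_kernels.cu.
template <int BLOCK_SIZE>
void launch() {
  // In the real kernels, BLOCK_SIZE fixes per-block token layout and
  // loop bounds at compile time.
  std::printf("launched kernel with BLOCK_SIZE=%d\n", BLOCK_SIZE);
}

// Runtime-to-compile-time dispatch: supporting a new block size means
// adding one more branch, which is what this diff does for 32.
void dispatch(int block_size) {
  if (block_size == 8) {
    launch<8>();
  } else if (block_size == 16) {
    launch<16>();
  } else if (block_size == 32) {
    launch<32>();
  } else {
    assert(false);  // unsupported block size
  }
}

int main() {
  dispatch(32);
  return 0;
}

Making the block size a template parameter rather than a runtime argument is presumably what lets the kernels unroll their per-block loops, at the cost of one explicit branch (and one compiled instantiation) per supported size.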