From e8ed722498795baa79e598ee165ba47a1592a7d0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 01:12:59 +0000 Subject: [PATCH 1/9] Add support for block size 1, 2, 4 --- cacheflow/master/block_manager.py | 2 +- cacheflow/master/server.py | 2 +- csrc/attention.cpp | 16 - csrc/attention_kernels.cu | 527 +----------------------------- csrc/cuda_primitives.h | 58 +++- 5 files changed, 59 insertions(+), 546 deletions(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 30dfa1e8c28e..b931742a117f 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -15,7 +15,7 @@ def __init__( block_size: int, num_blocks: int, ) -> None: - if block_size not in [8, 16, 32]: + if block_size not in [1, 2, 4, 8, 16, 32]: raise ValueError(f'Unsupported block size: {block_size}' 'The block size must be one of {8, 16, 32}.') self.device = device diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 3370058703ca..edae566c3312 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -180,7 +180,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments - parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size') + parser.add_argument('--block-size', type=int, default=8, choices=[1, 2, 4, 8, 16, 32], help='token block size') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). diff --git a/csrc/attention.cpp b/csrc/attention.cpp index 57dff9dc0b2b..bb2766c1d6b6 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -11,25 +11,9 @@ void single_query_cached_kv_attention( int block_size, int max_context_len); -void multi_query_cached_kv_attention( - torch::Tensor& cu_query_lens, - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len); - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "single_query_cached_kv_attention", &single_query_cached_kv_attention, "Compute the attention between an input query and the cached key/value tensors"); - m.def( - "multi_query_cached_kv_attention", - &multi_query_cached_kv_attention, - "Compute the attention between multiple input queries and the cached key/value tensors"); } diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index fb7ada38d935..2f911e1af080 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -42,7 +42,8 @@ __global__ void single_query_cached_kv_attention_kernel( // fetch or comput 16 bytes at a time. // For example, if the size of a thread group is 4 and the data type is half, // then the vector size is 16 / (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); + constexpr int _VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); + constexpr int VEC_SIZE = _VEC_SIZE > 0 ? _VEC_SIZE : 1; using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -161,7 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel( __syncthreads(); // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); + constexpr int _V_VEC_SIZE = 16 / sizeof(scalar_t); + constexpr int V_VEC_SIZE = BLOCK_SIZE > _V_VEC_SIZE ? _V_VEC_SIZE : BLOCK_SIZE; using V_vec = typename Vec::Type; using L_vec = typename FloatVec::Type; @@ -254,291 +256,6 @@ __global__ void single_query_cached_kv_attention_kernel( } } - -// Grid: (num_heads, num_query_tokens). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__device__ void multi_query_cached_kv_attention_kernel_unoptimized_( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const int seq_start_idx, - const int seq_len, - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const float scale, - const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] - const int context_len, - const int max_num_blocks_per_seq, - const int q_stride) { - constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int seq_idx = blockIdx.y; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or comput 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - Q_vec q_vecs[NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 logits and accumulation. - float *logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = thread_group_idx % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. - K_vec k_vecs[NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE - + physical_block_offset * x; - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= mask_boundary; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - logits[token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); - using V_vec = typename Vec::Type; - using L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec = *reinterpret_cast(logits + token_idx); - - const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - accs[i] += dot(logits_vec, cast_to_float(v_vec)); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - convert_from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - - -// Grid: (num_heads, num_query_tokens). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void multi_query_cached_kv_attention_kernel( - const int* cu_query_lens, // [num_prompts+1] - const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const float scale, - const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_prompts] - const int max_num_blocks_per_seq, - const int q_stride) { - const int seq_idx = blockIdx.y; - const int prompt_idx = seq_prompt_mapping[seq_idx]; - const int seq_start_idx = cu_query_lens[prompt_idx]; - const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; - const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; - const int context_len = context_lens[prompt_idx]; - multi_query_cached_kv_attention_kernel_unoptimized_< - scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( - out, - q, - seq_start_idx, - seq_len, - k_cache, - v_cache, - scale, - block_table, - context_len, - max_num_blocks_per_seq, - q_stride); -} - } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ @@ -574,6 +291,9 @@ void single_query_cached_kv_attention_launcher( int max_num_blocks_per_seq = block_tables.size(1); int query_stride = query.stride(0); + int thread_group_size = WARP_SIZE / BLOCK_SIZE; + assert(head_size % thread_group_size == 0); + T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); @@ -634,43 +354,8 @@ void single_query_cached_kv_attention( // TODO(woosuk): Support BF16. if (query.element_size() == 2) { // Half. - if (block_size == 8) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else if (block_size == 16) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else if (block_size == 32) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else { - assert(false); - } - } else if (query.element_size() == 4) { - // Float. - if (block_size == 8) { - single_query_cached_kv_attention_launcher( + if (block_size == 1) { + single_query_cached_kv_attention_launcher( out, query, key_cache, @@ -679,8 +364,8 @@ void single_query_cached_kv_attention( block_tables, context_lens, max_context_len); - } else if (block_size == 16) { - single_query_cached_kv_attention_launcher( + } else if (block_size == 2) { + single_query_cached_kv_attention_launcher( out, query, key_cache, @@ -689,8 +374,8 @@ void single_query_cached_kv_attention( block_tables, context_lens, max_context_len); - } else if (block_size == 32) { - single_query_cached_kv_attention_launcher( + } else if (block_size == 4) { + single_query_cached_kv_attention_launcher( out, query, key_cache, @@ -699,182 +384,8 @@ void single_query_cached_kv_attention( block_tables, context_lens, max_context_len); - } else { - assert(false); - } - } else { - assert(false); - } -} - - -#define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::multi_query_cached_kv_attention_kernel \ - <<>>( \ - cu_query_lens_ptr, \ - seq_prompt_mapping_ptr, \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - query_stride); - - -// TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128> -void multi_query_cached_kv_attention_launcher( - torch::Tensor& cu_query_lens, - torch::Tensor& seq_prompt_mapping, - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int query_stride = query.stride(0); - - int* cu_query_lens_ptr = cu_query_lens.data_ptr(); - int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - int shared_mem_size = std::max(logits_size, outputs_size); - - dim3 grid(num_heads, num_seqs); - dim3 block(NUM_THREADS); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - case 32: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - break; - case 64: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); - break; - case 80: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); - break; - case 96: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); - break; - case 128: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); - break; - case 160: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - break; - case 192: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - break; - case 256: - LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); - break; - default: - assert(false); - break; - } -} - -void multi_query_cached_kv_attention( - torch::Tensor& cu_query_lens, - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len) { - - torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); - - int num_queries = query_lens.size(0) - 1; - const int* query_lens_ptr = query_lens.data_ptr(); - int num_seqs = query.size(0); - - torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); - auto accessor = cpu_tensor.accessor(); - for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { - if (i >= query_lens_ptr[query_cursor + 1]) { - ++query_cursor; - } - accessor[i] = query_cursor; - } - - // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) - // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving - // the mapping as an input parameter. Let's do this optimization in a later PR. - torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); - - // TODO(woosuk): Support BF16. - if (query.element_size() == 2) { - // Half. - if (block_size == 8) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else if (block_size == 16) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else if (block_size == 32) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); - } else { - assert(false); - } - } else if (query.element_size() == 4) { - // Float. - if (block_size == 8) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, + } else if (block_size == 8) { + single_query_cached_kv_attention_launcher( out, query, key_cache, @@ -884,9 +395,7 @@ void multi_query_cached_kv_attention( context_lens, max_context_len); } else if (block_size == 16) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, + single_query_cached_kv_attention_launcher( out, query, key_cache, @@ -896,9 +405,7 @@ void multi_query_cached_kv_attention( context_lens, max_context_len); } else if (block_size == 32) { - multi_query_cached_kv_attention_launcher( - cu_query_lens, - seq_prompt_mapping, + single_query_cached_kv_attention_launcher( out, query, key_cache, diff --git a/csrc/cuda_primitives.h b/csrc/cuda_primitives.h index f8f137a7eb56..10e730fd7bda 100644 --- a/csrc/cuda_primitives.h +++ b/csrc/cuda_primitives.h @@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v) //////////////////////////////////////////////////////////////////////////////////////////////////// +inline __device__ float dot(float a, float b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float dot(float2 a, float2 b) +{ + float2 c = mul(a, b); + return c.x + c.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float dot(Float4_ a, Float4_ b) { float2 acc = mul(a.x, b.x); @@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u) //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float cast_to_float(float u) -{ - return u; -} +// inline __device__ float cast_to_float(float u) +// { +// return u; +// } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} +// inline __device__ float2 cast_to_float(float2 u) +// { +// return u; +// } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} +// inline __device__ float4 cast_to_float(float4 u) +// { +// return u; +// } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} +// inline __device__ Float4_ cast_to_float(Float4_ u) +// { +// return u; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// inline __device__ Float8_ cast_to_float(Float8_ u) +// { +// return u; +// } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ Float8_ cast_to_float(Float8_ u) +inline __device__ float cast_to_float(uint16_t u) { - return u; + return half_to_float(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// From 4fe90e93768f0c6b29eb307143a07ffc6066c759 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 02:49:39 +0000 Subject: [PATCH 2/9] Add block-s ze to dir name --- benchmark/benchmark_text_completion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index 742e2002ae2b..6a61c9828e7b 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -263,11 +263,12 @@ def get_sampling_dir_name( args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam) if args.output_dir is None: args.output_dir = os.path.join( - '../exp', + '../block', dataset_name, f'{model_name}-tp{args.tensor_parallel_size}', sample_dir, 'cacheflow', + f'block{args.block_size}', f'req-rate-{args.request_rate}', f'seed{args.seed}', f'duration-{args.duration}', From 510c46ea7c9d68fbe45993a135828eac035971d3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 06:06:53 +0000 Subject: [PATCH 3/9] Use max and min --- csrc/attention_kernels.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 2f911e1af080..24a5ed9409ba 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -8,6 +8,8 @@ #include #define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) namespace cacheflow { @@ -42,8 +44,7 @@ __global__ void single_query_cached_kv_attention_kernel( // fetch or comput 16 bytes at a time. // For example, if the size of a thread group is 4 and the data type is half, // then the vector size is 16 / (4 * sizeof(half)) == 2. - constexpr int _VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); - constexpr int VEC_SIZE = _VEC_SIZE > 0 ? _VEC_SIZE : 1; + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -162,8 +163,7 @@ __global__ void single_query_cached_kv_attention_kernel( __syncthreads(); // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int _V_VEC_SIZE = 16 / sizeof(scalar_t); - constexpr int V_VEC_SIZE = BLOCK_SIZE > _V_VEC_SIZE ? _V_VEC_SIZE : BLOCK_SIZE; + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename FloatVec::Type; @@ -423,3 +423,5 @@ void single_query_cached_kv_attention( } #undef WARP_SIZE +#undef MAX +#undef MIN From 10807c5c611c3bd3dd72d43cccf6e38b84f03785 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 08:16:07 +0000 Subject: [PATCH 4/9] Support block size 64, 128, 256 --- cacheflow/master/block_manager.py | 3 - cacheflow/master/server.py | 4 +- csrc/attention_kernels.cu | 134 ++++++++++++------------------ 3 files changed, 56 insertions(+), 85 deletions(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index b931742a117f..0b188508d15c 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -15,9 +15,6 @@ def __init__( block_size: int, num_blocks: int, ) -> None: - if block_size not in [1, 2, 4, 8, 16, 32]: - raise ValueError(f'Unsupported block size: {block_size}' - 'The block size must be one of {8, 16, 32}.') self.device = device self.block_size = block_size self.num_blocks = num_blocks diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index edae566c3312..9e668ef26f3d 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments - parser.add_argument('--block-size', type=int, default=8, choices=[1, 2, 4, 8, 16, 32], help='token block size') + parser.add_argument('--block-size', type=int, default=8, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. - parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') + parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 24a5ed9409ba..42a0db1cc1a7 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -29,7 +29,8 @@ __global__ void single_query_cached_kv_attention_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const int q_stride) { - constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -41,7 +42,7 @@ __global__ void single_query_cached_kv_attention_kernel( // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread group - // fetch or comput 16 bytes at a time. + // fetch or compute 16 bytes at a time. // For example, if the size of a thread group is 4 and the data type is half, // then the vector size is 16 / (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); @@ -90,37 +91,40 @@ __global__ void single_query_cached_kv_attention_kernel( // dot product with the query. for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = thread_group_idx % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; // Load a key to registers. // Each thread in a thread group has a different part of the key. // For example, if the the thread group size is 4, then the first thread in the group // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // vectors of the key, and so on. - K_vec k_vecs[NUM_VECS_PER_THREAD]; + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = thread_group_idx % BLOCK_SIZE + i * WARP_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + #pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE - + physical_block_offset * x; - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - logits[token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + const bool mask = token_idx >= context_len; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } } } @@ -291,7 +295,7 @@ void single_query_cached_kv_attention_launcher( int max_num_blocks_per_seq = block_tables.size(1); int query_stride = query.stride(0); - int thread_group_size = WARP_SIZE / BLOCK_SIZE; + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); assert(head_size % thread_group_size == 0); T* out_ptr = reinterpret_cast(out.data_ptr()); @@ -341,6 +345,17 @@ void single_query_cached_kv_attention_launcher( } } +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len); + void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] @@ -355,69 +370,28 @@ void single_query_cached_kv_attention( if (query.element_size() == 2) { // Half. if (block_size == 1) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 1); } else if (block_size == 2) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 2); } else if (block_size == 4) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 4); } else if (block_size == 8) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 8); } else if (block_size == 16) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 16); } else if (block_size == 32) { - single_query_cached_kv_attention_launcher( - out, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - max_context_len); + CALL_KERNEL_LAUNCHER(uint16_t, 32); + } else if (block_size == 64) { + CALL_KERNEL_LAUNCHER(uint16_t, 64); + } else if (block_size == 128) { + CALL_KERNEL_LAUNCHER(uint16_t, 128); + } else if (block_size == 256) { + CALL_KERNEL_LAUNCHER(uint16_t, 256); } else { assert(false); } } else { + // Float. assert(false); } } From 90d9910bfac8cdbaa1b2a58dc967e507203d2860 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 08:46:53 +0000 Subject: [PATCH 5/9] bugfix --- csrc/attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 42a0db1cc1a7..1bab3d75f325 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -98,7 +98,7 @@ __global__ void single_query_cached_kv_attention_kernel( // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = thread_group_idx % BLOCK_SIZE + i * WARP_SIZE; + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; From f5484598dca494b001ab0248a3fff2510d457e02 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 09:37:07 +0000 Subject: [PATCH 6/9] Minor --- benchmark/benchmark_text_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index 6a61c9828e7b..e6741d577c6e 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -263,7 +263,7 @@ def get_sampling_dir_name( args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam) if args.output_dir is None: args.output_dir = os.path.join( - '../block', + '../exp', dataset_name, f'{model_name}-tp{args.tensor_parallel_size}', sample_dir, From 32ff328ce697e0311556b70397b5bc83e2f1627f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 09:45:42 +0000 Subject: [PATCH 7/9] Change default block size to 16 --- cacheflow/master/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 9e668ef26f3d..5b8110a3dab4 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -180,7 +180,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments - parser.add_argument('--block-size', type=int, default=8, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size') + parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). From 7213b989599239cb5b82f406c73f4dd248693f5d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 15:48:37 +0000 Subject: [PATCH 8/9] Comment out multi-query cached kv attention --- csrc/attention_kernels.cu | 495 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 495 insertions(+) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 1bab3d75f325..c25acbb8be6f 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -396,6 +396,501 @@ void single_query_cached_kv_attention( } } +// namespace cacheflow { + +// // Grid: (num_heads, num_query_tokens). +// template< +// typename scalar_t, +// int HEAD_SIZE, +// int BLOCK_SIZE, +// int NUM_THREADS> +// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( +// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] +// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] +// const int seq_start_idx, +// const int seq_len, +// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] +// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] +// const float scale, +// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] +// const int context_len, +// const int max_num_blocks_per_seq, +// const int q_stride) { +// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; +// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; +// const int thread_idx = threadIdx.x; +// const int warp_idx = thread_idx / WARP_SIZE; +// const int lane = thread_idx % WARP_SIZE; + +// const int head_idx = blockIdx.x; +// const int num_heads = gridDim.x; +// const int seq_idx = blockIdx.y; + +// // A vector type to store a part of a key or a query. +// // The vector size is configured in such a way that the threads in a thread group +// // fetch or comput 16 bytes at a time. +// // For example, if the size of a thread group is 4 and the data type is half, +// // then the vector size is 16 / (4 * sizeof(half)) == 2. +// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); +// using K_vec = typename Vec::Type; +// using Q_vec = typename Vec::Type; + +// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; +// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + +// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; +// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + +// // Load the query to registers. +// // Each thread in a thread group has a different part of the query. +// // For example, if the the thread group size is 4, then the first thread in the group +// // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... +// // th vectors of the query, and so on. +// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. +// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; +// Q_vec q_vecs[NUM_VECS_PER_THREAD]; +// #pragma unroll +// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { +// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; +// q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); +// } + +// // Memory planning. +// extern __shared__ char shared_mem[]; +// // NOTE(woosuk): We use FP32 logits and accumulation. +// float *logits = reinterpret_cast(shared_mem); +// // Workspace for reduction. +// __shared__ float red_smem[2 * NUM_WARPS]; + +// // x == THREAD_GROUP_SIZE * VEC_SIZE +// // Each thread group fetches x elements from the key at a time. +// constexpr int x = 16 / sizeof(scalar_t); +// float qk_max = -FLT_MAX; + +// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; +// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); + +// // Iterate over the key blocks. +// // Each warp fetches a block of keys for each iteration. +// // Each thread group in a warp fetches a key from the block, and computes +// // dot product with the query. +// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { +// const int physical_block_number = block_table[block_idx]; +// const int physical_block_offset = thread_group_idx % BLOCK_SIZE; +// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + +// // Load a key to registers. +// // Each thread in a thread group has a different part of the key. +// // For example, if the the thread group size is 4, then the first thread in the group +// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th +// // vectors of the key, and so on. +// K_vec k_vecs[NUM_VECS_PER_THREAD]; +// #pragma unroll +// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { +// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE +// + head_idx * HEAD_SIZE * BLOCK_SIZE +// + physical_block_offset * x; +// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; +// const int offset1 = (vec_idx * VEC_SIZE) / x; +// const int offset2 = (vec_idx * VEC_SIZE) % x; +// k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); +// } + +// // Compute dot product. +// // This includes a reduction across the threads in the same thread group. +// const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); +// const bool mask = token_idx >= mask_boundary; + +// if (thread_group_offset == 0) { +// // Store the partial reductions to shared memory. +// // NOTE(woosuk): It is required to zero out the masked logits. +// logits[token_idx] = mask ? 0.f : qk; +// // Update the max value. +// qk_max = mask ? qk_max : fmaxf(qk_max, qk); +// } +// } + +// // Perform reduction across the threads in the same warp to get the +// // max qk value for each "warp" (not across the thread block yet). +// // The 0-th thread of each thread group already has its max qk value. +// #pragma unroll +// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +// } +// if (lane == 0) { +// red_smem[warp_idx] = qk_max; +// } +// __syncthreads(); + +// // TODO(woosuk): Refactor this part. +// // Get the max qk value for the sequence. +// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +// #pragma unroll +// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +// } +// // Broadcast the max qk value to all threads. +// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + +// // Get the sum of the exp values. +// float exp_sum = 0.f; +// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { +// float val = __expf(logits[i] - qk_max); +// logits[i] = val; +// exp_sum += val; +// } +// exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + +// // Compute softmax. +// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); +// for (int i = thread_idx; i < context_len; i += NUM_THREADS) { +// logits[i] *= inv_sum; +// } +// __syncthreads(); + +// // Each thread will fetch 16 bytes from the value cache at a time. +// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); +// using V_vec = typename Vec::Type; +// using L_vec = typename FloatVec::Type; + +// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; +// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; +// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + +// float accs[NUM_ROWS_PER_THREAD]; +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// accs[i] = 0.f; +// } + +// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { +// const int physical_block_number = block_table[block_idx]; +// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; +// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; +// L_vec logits_vec = *reinterpret_cast(logits + token_idx); + +// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE +// + head_idx * HEAD_SIZE * BLOCK_SIZE; +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; +// if (row_idx < HEAD_SIZE) { +// const int offset = row_idx * BLOCK_SIZE + physical_block_offset; +// V_vec v_vec = *reinterpret_cast(v_ptr + offset); +// accs[i] += dot(logits_vec, cast_to_float(v_vec)); +// } +// } +// } + +// // Perform reduction within each warp. +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// float acc = accs[i]; +// #pragma unroll +// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +// acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +// } +// accs[i] = acc; +// } + +// // NOTE(woosuk): A barrier is required because the shared memory space for logits +// // is reused for the output. +// __syncthreads(); + +// // Perform reduction across warps. +// float* out_smem = reinterpret_cast(shared_mem); +// #pragma unroll +// for (int i = NUM_WARPS; i > 1; i /= 2) { +// int mid = i / 2; +// // Upper warps write to shared memory. +// if (warp_idx >= mid && warp_idx < i) { +// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; +// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { +// dst[row_idx] = accs[i]; +// } +// } +// } +// __syncthreads(); + +// // Lower warps update the output. +// if (warp_idx < mid) { +// const float* src = &out_smem[warp_idx * HEAD_SIZE]; +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; +// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { +// accs[i] += src[row_idx]; +// } +// } +// } +// __syncthreads(); +// } + +// // Write the final output. +// if (warp_idx == 0) { +// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +// #pragma unroll +// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { +// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; +// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { +// convert_from_float(*(out_ptr + row_idx), accs[i]); +// } +// } +// } +// } + + +// // Grid: (num_heads, num_query_tokens). +// template< +// typename scalar_t, +// int HEAD_SIZE, +// int BLOCK_SIZE, +// int NUM_THREADS> +// __global__ void multi_query_cached_kv_attention_kernel( +// const int* cu_query_lens, // [num_prompts+1] +// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx +// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] +// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] +// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] +// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] +// const float scale, +// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] +// const int* __restrict__ context_lens, // [num_prompts] +// const int max_num_blocks_per_seq, +// const int q_stride) { +// const int seq_idx = blockIdx.y; +// const int prompt_idx = seq_prompt_mapping[seq_idx]; +// const int seq_start_idx = cu_query_lens[prompt_idx]; +// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; +// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; +// const int context_len = context_lens[prompt_idx]; +// multi_query_cached_kv_attention_kernel_unoptimized_< +// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( +// out, +// q, +// seq_start_idx, +// seq_len, +// k_cache, +// v_cache, +// scale, +// block_table, +// context_len, +// max_num_blocks_per_seq, +// q_stride); +// } + +// } // namespace cacheflow + +// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +// cacheflow::multi_query_cached_kv_attention_kernel \ +// <<>>( \ +// cu_query_lens_ptr, \ +// seq_prompt_mapping_ptr, \ +// out_ptr, \ +// query_ptr, \ +// key_cache_ptr, \ +// value_cache_ptr, \ +// scale, \ +// block_tables_ptr, \ +// context_lens_ptr, \ +// max_num_blocks_per_seq, \ +// query_stride); + + +// // TODO(woosuk): Tune NUM_THREADS. +// template< +// typename T, +// int BLOCK_SIZE, +// int NUM_THREADS = 128> +// void multi_query_cached_kv_attention_launcher( +// torch::Tensor& cu_query_lens, +// torch::Tensor& seq_prompt_mapping, +// torch::Tensor& out, +// torch::Tensor& query, +// torch::Tensor& key_cache, +// torch::Tensor& value_cache, +// float scale, +// torch::Tensor& block_tables, +// torch::Tensor& context_lens, +// int max_context_len) { +// int num_seqs = query.size(0); +// int num_heads = query.size(1); +// int head_size = query.size(2); +// int max_num_blocks_per_seq = block_tables.size(1); +// int query_stride = query.stride(0); + +// int* cu_query_lens_ptr = cu_query_lens.data_ptr(); +// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); +// T* out_ptr = reinterpret_cast(out.data_ptr()); +// T* query_ptr = reinterpret_cast(query.data_ptr()); +// T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); +// T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); +// int* block_tables_ptr = block_tables.data_ptr(); +// int* context_lens_ptr = context_lens.data_ptr(); + +// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; +// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; +// int logits_size = padded_max_context_len * sizeof(float); +// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); +// int shared_mem_size = std::max(logits_size, outputs_size); + +// dim3 grid(num_heads, num_seqs); +// dim3 block(NUM_THREADS); +// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// switch (head_size) { +// case 32: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); +// break; +// case 64: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); +// break; +// case 80: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); +// break; +// case 96: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); +// break; +// case 128: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); +// break; +// case 160: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); +// break; +// case 192: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); +// break; +// case 256: +// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); +// break; +// default: +// assert(false); +// break; +// } +// } + +// void multi_query_cached_kv_attention( +// torch::Tensor& cu_query_lens, +// torch::Tensor& out, +// torch::Tensor& query, +// torch::Tensor& key_cache, +// torch::Tensor& value_cache, +// float scale, +// torch::Tensor& block_tables, +// torch::Tensor& context_lens, +// int block_size, +// int max_context_len) { + +// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); + +// int num_queries = query_lens.size(0) - 1; +// const int* query_lens_ptr = query_lens.data_ptr(); +// int num_seqs = query.size(0); + +// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); +// auto accessor = cpu_tensor.accessor(); +// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { +// if (i >= query_lens_ptr[query_cursor + 1]) { +// ++query_cursor; +// } +// accessor[i] = query_cursor; +// } + +// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) +// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving +// // the mapping as an input parameter. Let's do this optimization in a later PR. +// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); + +// // TODO(woosuk): Support BF16. +// if (query.element_size() == 2) { +// // Half. +// if (block_size == 8) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else if (block_size == 16) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else if (block_size == 32) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else { +// assert(false); +// } +// } else if (query.element_size() == 4) { +// // Float. +// if (block_size == 8) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else if (block_size == 16) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else if (block_size == 32) { +// multi_query_cached_kv_attention_launcher( +// cu_query_lens, +// seq_prompt_mapping, +// out, +// query, +// key_cache, +// value_cache, +// scale, +// block_tables, +// context_lens, +// max_context_len); +// } else { +// assert(false); +// } +// } else { +// assert(false); +// } +// } + #undef WARP_SIZE #undef MAX #undef MIN From 6337f356cf99ed1dc72be565fd58d61c0e89509a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Apr 2023 15:59:23 +0000 Subject: [PATCH 9/9] Enforce FCFS --- cacheflow/master/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 66d1fa4e2c56..da461798bb6e 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -125,7 +125,8 @@ def _schedule( # Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) - while self.swapped: + # FCFS + while self.swapped and not blocks_to_swap_out: seq_group = self.swapped[0] # If the sequence group has been preempted in this step, stop. if seq_group in preempted: