From e12a642d49fe3c8b87d14af48aa1e9fca1d5b17a Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 22 Jul 2024 19:01:53 -0700
Subject: [PATCH 1/7] Add fp8 support to `reshape_and_cache_flash`

---
 csrc/cache.h                          | 14 ++--
 csrc/cache_kernels.cu                 | 93 +++++++++++++++++----------
 csrc/torch_bindings.cpp               |  3 +-
 tests/kernels/test_cache.py           | 34 ++++++++--
 vllm/_custom_ops.py                   |  4 +-
 vllm/attention/backends/flash_attn.py |  2 +
 vllm/attention/backends/flashinfer.py |  2 +
 7 files changed, 105 insertions(+), 47 deletions(-)

diff --git a/csrc/cache.h b/csrc/cache.h
index 52177e8901a8..41964ebf14b0 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -21,11 +21,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        const std::string& kv_cache_dtype,
                        const double k_scale, const double v_scale);
 
-void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
-                             torch::Tensor& key_cache,
-                             torch::Tensor& value_cache,
-                             torch::Tensor& slot_mapping,
-                             const std::string& kv_cache_dtype);
+void reshape_and_cache_flash(
+    torch::Tensor& key,
+    torch::Tensor& value,
+    torch::Tensor& key_cache,
+    torch::Tensor& value_cache,
+    torch::Tensor& slot_mapping,
+    const std::string& kv_cache_dtype,
+    const double k_scale,
+    const double v_scale);
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index caef7f5e1863..591359c8825b 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -203,17 +203,21 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
-    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
-    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
-                                         // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
-                                         // head_size]
-    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
+    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
+    cache_t* __restrict__ key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads, head_size]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int block_stride,
+    const int key_stride,
+    const int value_stride,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const float k_scale,
+    const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +232,19 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+        block_offset * num_heads * head_size +
+        head_idx * head_size +
+        head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 }  // namespace vllm
@@ -278,13 +290,35 @@ void reshape_and_cache(
                              CALL_RESHAPE_AND_CACHE)
 }
 
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<KV_T*>(key.data_ptr()), \
+    reinterpret_cast<KV_T*>(value.data_ptr()), \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
+    slot_mapping.data_ptr<int64_t>(), \
+    block_stride, \
+    key_stride, \
+    value_stride, \
+    num_heads, \
+    head_size, \
+    block_size, \
+    k_scale, \
+    v_scale);
+
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
+    torch::Tensor& key,  // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& value_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    const std::string& kv_cache_dtype,
+    const double k_scale,
+    const double v_scale)
+{
   // FIXME: only support auto datatype, does not support fp8
   if (kv_cache_dtype != "auto") {
     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@@ -292,26 +326,19 @@ void reshape_and_cache_flash(
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
 namespace vllm {
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 0df9bdb75018..3027b63ba2b3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "                        Tensor! key_cache,"
       "                        Tensor! value_cache,"
       "                        Tensor slot_mapping,"
-      "                        str kv_cache_dtype) -> ()");
+      "                        str kv_cache_dtype,"
+      "                        float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 70ae3d0c6e0c..f9e060b90426 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    if kv_cache_dtype == "fp8":
-        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -251,12 +249,24 @@ def test_reshape_and_cache_flash(
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
 
     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -269,8 +279,18 @@ def test_reshape_and_cache_flash(
         cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]
 
-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+    if kv_cache_dtype == "fp8":
+        assert torch.allclose(result_key_cache,
+                              cloned_key_cache,
+                              atol=0.001,
+                              rtol=0.1)
+        assert torch.allclose(result_value_cache,
+                              cloned_value_cache,
+                              atol=0.001,
+                              rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index e5151c070f2f..458ccac94984 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -426,10 +426,12 @@ def reshape_and_cache_flash(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype)
+                                                   kv_cache_dtype, k_scale, v_scale)
 
 
 def copy_blocks(key_caches: List[torch.Tensor],
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index b16a204c8f44..949bd973cf3c 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -478,6 +478,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 9dac12d3b906..2a4900489df3 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -489,6 +489,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         query = query.contiguous(

From 53fc6d7c2bb1c41ed3b7f45ede12957c3648fdd7 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 22 Jul 2024 19:07:47 -0700
Subject: [PATCH 2/7] Lint

---
 tests/kernels/test_cache.py | 3 +++
 vllm/_custom_ops.py         | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index f9e060b90426..40b61d56cd4c 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -258,6 +258,9 @@ def test_reshape_and_cache_flash(
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
+    # Using default kv_scale
+    k_scale = v_scale = 1.0
+
     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                 slot_mapping, kv_cache_dtype, k_scale, v_scale)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 458ccac94984..0186594656cc 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -431,7 +431,8 @@ def reshape_and_cache_flash(
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype, k_scale, v_scale)
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)
 
 
 def copy_blocks(key_caches: List[torch.Tensor],

From 6c0eeb0ff90d571b55b507ddbfdf78d5bf5b0088 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 23 Jul 2024 17:12:12 +0000
Subject: [PATCH 3/7] Lint

---
 csrc/cache.h          | 15 ++++----
 csrc/cache_kernels.cu | 82 +++++++++++++++++++------------------------
 2 files changed, 43 insertions(+), 54 deletions(-)

diff --git a/csrc/cache.h b/csrc/cache.h
index 41964ebf14b0..11c4c5001daa 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -21,15 +21,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        const std::string& kv_cache_dtype,
                        const double k_scale, const double v_scale);
 
-void reshape_and_cache_flash(
-    torch::Tensor& key,
-    torch::Tensor& value,
-    torch::Tensor& key_cache,
-    torch::Tensor& value_cache,
-    torch::Tensor& slot_mapping,
-    const std::string& kv_cache_dtype,
-    const double k_scale,
-    const double v_scale);
+void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
+                             torch::Tensor& key_cache,
+                             torch::Tensor& value_cache,
+                             torch::Tensor& slot_mapping,
+                             const std::string& kv_cache_dtype,
+                             const double k_scale, const double v_scale);
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 591359c8825b..45972457b3ab 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -203,21 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
-    const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
-    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,  // [num_blocks, block_size, num_heads, head_size]
-    cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads, head_size]
-    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride,
-    const int key_stride,
-    const int value_stride,
-    const int num_heads,
-    const int head_size,
-    const int block_size,
-    const float k_scale,
-    const float v_scale) {
+    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
+    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
+                                         // head_size]
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
+                                         // head_size]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int block_stride, const int key_stride, const int value_stride,
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -232,18 +229,19 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-        block_offset * num_heads * head_size +
-        head_idx * head_size +
-        head_offset;
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
       key_cache[tgt_key_value_idx] = tgt_key;
       value_cache[tgt_key_value_idx] = tgt_value;
     } else {
-      key_cache[tgt_key_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
-      value_cache[tgt_key_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
     }
   }
 }
@@ -293,32 +291,25 @@ void reshape_and_cache(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
-  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
-    reinterpret_cast<KV_T*>(key.data_ptr()), \
-    reinterpret_cast<KV_T*>(value.data_ptr()), \
-    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
-    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
-    slot_mapping.data_ptr<int64_t>(), \
-    block_stride, \
-    key_stride, \
-    value_stride, \
-    num_heads, \
-    head_size, \
-    block_size, \
-    k_scale, \
-    v_scale);
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
+      <<<grid, block, 0, stream>>>(                                   \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
 
 void reshape_and_cache_flash(
-    torch::Tensor& key,  // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& value_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype,
-    const double k_scale,
-    const double v_scale)
-{
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   // FIXME: only support auto datatype, does not support fp8
   if (kv_cache_dtype != "auto") {
     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@@ -338,7 +329,8 @@ void reshape_and_cache_flash(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE_FLASH);
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
 namespace vllm {

From 8d528f74bcd349829cfee6e5c8e0e97012776172 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 23 Jul 2024 13:25:15 -0700
Subject: [PATCH 4/7] Fix test

---
 vllm/utils.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 83605631b5bd..876c3bf90b02 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    assert cache_dtype != "fp8"
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
         key_value_cache = torch.empty(size=key_value_cache_shape,
                                       dtype=torch_dtype,
                                       device=device)
-        key_value_cache.uniform_(-scale, scale)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_value_cache[:, 0])
         value_caches.append(key_value_cache[:, 1])
     return key_caches, value_caches

From d1d859c1dfc01af470b9d8df168960fa930e9685 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 23 Jul 2024 23:46:43 +0000
Subject: [PATCH 5/7] Fix

---
 csrc/cache_kernels.cu       | 4 ----
 tests/kernels/test_cache.py | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 45972457b3ab..1be806bbfa43 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -310,10 +310,6 @@ void reshape_and_cache_flash(
     torch::Tensor& slot_mapping,  // [num_tokens]
     const std::string& kv_cache_dtype, const double k_scale,
     const double v_scale) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 40b61d56cd4c..1d5c75ce16f8 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -246,7 +246,7 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
 
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":

From 6d2672f98e88e607e8eb5d201da538b2b168703c Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 24 Jul 2024 00:02:21 +0000
Subject: [PATCH 6/7] Lint

---
 tests/kernels/test_cache.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 1d5c75ce16f8..89ad48ec9806 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -246,7 +246,8 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
+    key_cache, value_cache = key_caches[0].contiguous(
+    ), value_caches[0].contiguous()
 
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":

From 691f7fd117103aa925d2bec679bf85aa473e1719 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 24 Jul 2024 10:10:17 -0700
Subject: [PATCH 7/7] Del in test

---
 tests/kernels/test_cache.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 89ad48ec9806..f9a609464abf 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -248,6 +248,8 @@ def test_reshape_and_cache_flash(
     )
     key_cache, value_cache = key_caches[0].contiguous(
     ), value_caches[0].contiguous()
+    del key_caches
+    del value_caches
 
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":