Skip to content

Commit e6654f2

Browse files
committed
[Fix] the failing unit tests
Signed-off-by: ApostaC <[email protected]>
1 parent f60a8fa commit e6654f2

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

csrc/cache_kernels.cu

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "quantization/fp8/nvidia/quant_utils.cuh"
1212
#endif
1313

14+
#include <cstdio>
1415
#include <algorithm>
1516
#include <cassert>
1617
#include <map>
@@ -27,11 +28,12 @@ template <typename T, typename ACC_T>
2728
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
2829
ACC_T src_to_dst, const int num_pages,
2930
const int num_elements_per_page) {
30-
const int srcPageIdx = src_to_dst[blockIdx.x][0];
31-
const int dstPageIdx = src_to_dst[blockIdx.x][1];
31+
const int64_t srcPageIdx = src_to_dst[blockIdx.x][0];
32+
const int64_t dstPageIdx = src_to_dst[blockIdx.x][1];
3233

33-
const int srcPageOffset = srcPageIdx * num_elements_per_page;
34-
const int dstPageOffset = dstPageIdx * num_elements_per_page;
34+
35+
const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
36+
const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;
3537

3638
for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
3739
dst[dstPageOffset + i] = src[srcPageOffset + i];
@@ -45,6 +47,7 @@ void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
4547
const torch::Tensor& block_mapping,
4648
const int num_blocks,
4749
const int block_size_in_bytes) {
50+
c10::cuda::CUDAGuard device_guard(block_mapping.device());
4851
auto block_mapping_accessor =
4952
block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>();
5053

@@ -125,6 +128,25 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
125128
// fall back to the slow implementation
126129
swap_blocks_slow(src, dst, block_mapping.cpu());
127130
} else {
131+
// Check the device
132+
torch::Device src_device = src.device();
133+
torch::Device dst_device = dst.device();
134+
torch::Device block_mapping_device = block_mapping.device();
135+
TORCH_CHECK(block_mapping_device.is_cuda(),
136+
"block_mapping must be on GPU");
137+
if (src_device.is_cuda() && dst_device.is_cuda()) {
138+
TORCH_CHECK(src_device.index() == dst_device.index(),
139+
"src and dst must be on the same GPU");
140+
}
141+
if (src_device.is_cuda()) {
142+
TORCH_CHECK(src_device.index() == block_mapping_device.index(),
143+
"src and block_mapping must be on the same GPU");
144+
}
145+
if (dst_device.is_cuda()) {
146+
TORCH_CHECK(dst_device.index() == block_mapping_device.index(),
147+
"dst and block_mapping must be on the same GPU");
148+
}
149+
128150
const int64_t num_blocks = block_mapping.size(0);
129151
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
130152

tests/kernels/test_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test_swap_blocks(
362362
block_mapping = list(zip(src_blocks, dst_blocks))
363363
block_mapping_tensor = torch.tensor(block_mapping,
364364
dtype=torch.int64,
365-
device="cuda").view(-1, 2)
365+
device=device).view(-1, 2)
366366

367367
# Create the KV caches on the first device.
368368
src_key_caches, src_value_caches = kv_cache_factory(

vllm/entrypoints/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __init__(
163163
gpu_memory_utilization: float = 0.9,
164164
swap_space: float = 4,
165165
cpu_offload_gb: float = 0,
166-
block_allocator: str = "CpuOffloadingBlockAllocator",
166+
block_allocator: str = "CpuGpuBlockAllocator",
167167
enforce_eager: Optional[bool] = None,
168168
max_seq_len_to_capture: int = 8192,
169169
disable_custom_all_reduce: bool = False,

0 commit comments

Comments
 (0)