 #include "quantization/fp8/nvidia/quant_utils.cuh"
 #endif

+#include <cstdio>
 #include <algorithm>
 #include <cassert>
 #include <map>
@@ -27,11 +28,12 @@ template <typename T, typename ACC_T>
 __global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
                            ACC_T src_to_dst, const int num_pages,
                            const int num_elements_per_page) {
-  const int srcPageIdx = src_to_dst[blockIdx.x][0];
-  const int dstPageIdx = src_to_dst[blockIdx.x][1];
+  const int64_t srcPageIdx = src_to_dst[blockIdx.x][0];
+  const int64_t dstPageIdx = src_to_dst[blockIdx.x][1];

-  const int srcPageOffset = srcPageIdx * num_elements_per_page;
-  const int dstPageOffset = dstPageIdx * num_elements_per_page;
+
+  const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
+  const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;

   for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
     dst[dstPageOffset + i] = src[srcPageOffset + i];
@@ -45,6 +47,7 @@ void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
                               const torch::Tensor& block_mapping,
                               const int num_blocks,
                               const int block_size_in_bytes) {
+  c10::cuda::CUDAGuard device_guard(block_mapping.device());
   auto block_mapping_accessor =
       block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>();
@@ -125,6 +128,25 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
     // fall back to the slow implementation
     swap_blocks_slow(src, dst, block_mapping.cpu());
   } else {
+    // Check the devices
+    torch::Device src_device = src.device();
+    torch::Device dst_device = dst.device();
+    torch::Device block_mapping_device = block_mapping.device();
+    TORCH_CHECK(block_mapping_device.is_cuda(),
+                "block_mapping must be on GPU");
+    if (src_device.is_cuda() && dst_device.is_cuda()) {
+      TORCH_CHECK(src_device.index() == dst_device.index(),
+                  "src and dst must be on the same GPU");
+    }
+    if (src_device.is_cuda()) {
+      TORCH_CHECK(src_device.index() == block_mapping_device.index(),
+                  "src and block_mapping must be on the same GPU");
+    }
+    if (dst_device.is_cuda()) {
+      TORCH_CHECK(dst_device.index() == block_mapping_device.index(),
+                  "dst and block_mapping must be on the same GPU");
+    }
+
     const int64_t num_blocks = block_mapping.size(0);
     const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
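Note on the widening to `int64_t`: with the old `int` offsets, `pageIdx * num_elements_per_page` is a 32-bit multiply, so once a cache tensor holds more than 2^31 elements the product exceeds `INT_MAX`, wraps to a negative offset, and the copy reads or writes out of bounds. Below is a minimal sketch of the failure mode and the fix; it is not part of this commit, and the page index and page size are hypothetical values chosen to trigger the overflow.

```cpp
// Sketch: why the page offsets were widened to int64_t.
#include <cstdint>
#include <cstdio>

int main() {
  const int pageIdx = 40000;               // hypothetical page index
  const int num_elements_per_page = 65536; // hypothetical elements per page

  // 32-bit multiply: 40000 * 65536 = 2,621,440,000 > INT_MAX.
  // Signed overflow is formally undefined; in practice it wraps negative.
  const int offset32 = pageIdx * num_elements_per_page;

  // Widening one operand promotes the whole multiply to 64 bits,
  // which is what the kernel now gets from its int64_t page indices.
  const int64_t offset64 =
      static_cast<int64_t>(pageIdx) * num_elements_per_page;

  printf("32-bit offset: %d\n", offset32);               // wrapped, negative
  printf("64-bit offset: %lld\n", (long long)offset64);  // 2621440000
  return 0;
}
```

In the patched kernel the page indices are read from an `int64_t` accessor, so `srcPageIdx * num_elements_per_page` is computed in 64 bits even though `num_elements_per_page` stays an `int`. The added `c10::cuda::CUDAGuard` complements this on the host side: it switches the current device to `block_mapping.device()` for the duration of the launch and restores it afterwards.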