Skip to content

Commit 0f2035c

Browse files
r-barnesfacebook-github-bot
authored andcommitted
Fix CUDA kernel index data type in faiss/gpu/impl/DistanceUtils.cuh +10 (#4246)
Summary: Pull Request resolved: #4246 CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables). Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples. The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items. While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them. Reviewed By: dtolnay Differential Revision: D71355340 fbshipit-source-id: 77dac270e1d3415bfe7d5cc214006d5176508474
1 parent 1dcbb4a commit 0f2035c

10 files changed

Lines changed: 26 additions & 26 deletions

faiss/gpu/impl/DistanceUtils.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ __global__ void incrementIndex(
303303
int k,
304304
idx_t increment) {
305305
for (idx_t i = blockIdx.y; i < indices.getSize(0); i += gridDim.y) {
306-
for (int j = threadIdx.x; j < k; j += blockDim.x) {
306+
for (auto j = threadIdx.x; j < k; j += blockDim.x) {
307307
indices[i][idx_t(blockIdx.x) * k + j] += blockIdx.x * increment;
308308
}
309309
}

faiss/gpu/impl/GpuScalarQuantizer.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit, DimMultiple> {
377377
smemVmin = smem;
378378
smemVdiff = smem + dim;
379379

380-
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
380+
for (auto i = threadIdx.x; i < dim; i += blockDim.x) {
381381
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
382382
// This can be simplified to vmin' + vdiff' * v where:
383383
// vdiff' = vdiff / (2^bits - 1)
@@ -587,7 +587,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_6bit, 1> {
587587
smemVmin = smem;
588588
smemVdiff = smem + dim;
589589

590-
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
590+
for (auto i = threadIdx.x; i < dim; i += blockDim.x) {
591591
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
592592
// This can be simplified to vmin' + vdiff' * v where:
593593
// vdiff' = vdiff / (2^bits - 1)
@@ -753,7 +753,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> {
753753
smemVmin = smem;
754754
smemVdiff = smem + dim;
755755

756-
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
756+
for (auto i = threadIdx.x; i < dim; i += blockDim.x) {
757757
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
758758
// This can be simplified to vmin' + vdiff' * v where:
759759
// vdiff' = vdiff / (2^bits - 1)

faiss/gpu/impl/IVFAppend.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,9 @@ __global__ void ivfInterleavedAppend(
368368
// The set of addresses for each of the lists
369369
void** listData) {
370370
// FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs?
371-
int laneId = threadIdx.x % kWarpSize;
372-
int warpId = threadIdx.x / kWarpSize;
373-
int warpsPerBlock = blockDim.x / kWarpSize;
371+
auto laneId = threadIdx.x % kWarpSize;
372+
auto warpId = threadIdx.x / kWarpSize;
373+
auto warpsPerBlock = blockDim.x / kWarpSize;
374374

375375
// Each block is dedicated to a separate list
376376
idx_t listId = uniqueLists[blockIdx.x];

faiss/gpu/impl/IVFFlatScan.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ struct IVFFlatScan {
6565
int limit = utils::divDown(dim, Codec::kDimPerIter);
6666

6767
// Each warp handles a separate chunk of vectors
68-
int warpId = threadIdx.x / kWarpSize;
68+
auto warpId = threadIdx.x / kWarpSize;
6969
// FIXME: why does getLaneId() not work when we write out below!?!?!
70-
int laneId = threadIdx.x % kWarpSize; // getLaneId();
70+
auto laneId = threadIdx.x % kWarpSize; // getLaneId();
7171

7272
// Divide the set of vectors among the warps
7373
idx_t vecsPerWarp = utils::divUp(numVecs, kIVFFlatScanWarps);

faiss/gpu/impl/IVFInterleaved.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ __global__ void ivfInterleavedScan2(
2727
Tensor<float, 2, true> distanceOut,
2828
Tensor<idx_t, 2, true> indicesOut) {
2929
if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) {
30-
int queryId = blockIdx.x;
30+
auto queryId = blockIdx.x;
3131

3232
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
3333

@@ -99,7 +99,7 @@ __global__ void ivfInterleavedScan2(
9999
// Merge all final results
100100
heap.reduce();
101101

102-
for (int i = threadIdx.x; i < k; i += blockDim.x) {
102+
for (auto i = threadIdx.x; i < k; i += blockDim.x) {
103103
// Re-adjust the value we are selecting based on the sorting order
104104
distanceOut[queryId][i] = smemK[i] * adj;
105105
auto packedIndex = smemV[i];

faiss/gpu/impl/IVFInterleaved.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ __global__ void ivfInterleavedScan(
5656

5757
for (idx_t queryId = blockIdx.y; queryId < queries.getSize(0);
5858
queryId += gridDim.y) {
59-
int probeId = blockIdx.x;
59+
auto probeId = blockIdx.x;
6060
idx_t listId = listIds[queryId][probeId];
6161

6262
// Safety guard in case NaNs in input cause no list ID to be
@@ -69,8 +69,8 @@ __global__ void ivfInterleavedScan(
6969
int dim = queries.getSize(1);
7070

7171
// FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs?
72-
int laneId = threadIdx.x % kWarpSize;
73-
int warpId = threadIdx.x / kWarpSize;
72+
auto laneId = threadIdx.x % kWarpSize;
73+
auto warpId = threadIdx.x / kWarpSize;
7474

7575
using EncodeT = typename Codec::EncodeT;
7676

@@ -215,7 +215,7 @@ __global__ void ivfInterleavedScan(
215215
auto distanceOutBase = distanceOut[queryId][probeId].data();
216216
auto indicesOutBase = indicesOut[queryId][probeId].data();
217217

218-
for (int i = threadIdx.x; i < k; i += blockDim.x) {
218+
for (auto i = threadIdx.x; i < k; i += blockDim.x) {
219219
distanceOutBase[i] = smemK[i];
220220
indicesOutBase[i] = smemV[i];
221221
}

faiss/gpu/impl/IVFUtilsSelect1.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ __global__ void pass1SelectLists(
9090

9191
// Write out the final k-selected values; they should be all
9292
// together
93-
for (int i = threadIdx.x; i < k; i += blockDim.x) {
93+
for (auto i = threadIdx.x; i < k; i += blockDim.x) {
9494
heapDistances[queryId][sliceId][i] = smemK[i];
9595
heapIndices[queryId][sliceId][i] = idx_t(smemV[i]);
9696
}

faiss/gpu/impl/IVFUtilsSelect2.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ __global__ void pass2SelectLists(
100100
// Merge all final results
101101
heap.reduce();
102102

103-
for (int i = threadIdx.x; i < k; i += blockDim.x) {
103+
for (auto i = threadIdx.x; i < k; i += blockDim.x) {
104104
outDistances[queryId][i] = smemK[i];
105105

106106
// `v` is the index in `heapIndices`

faiss/gpu/impl/IcmEncoder.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ __global__ void runIcmEncodeStep(
4646
int m) {
4747
using KVPair = Pair<float, int>;
4848

49-
int id = blockIdx.x; // each block takes care of one vector
50-
int code = threadIdx.x; // each thread takes care of one possible code
49+
auto id = blockIdx.x; // each block takes care of one vector
50+
auto code = threadIdx.x; // each thread takes care of one possible code
5151

5252
// compute the objective value by look-up tables
5353
KVPair obj(0.0f, code);
@@ -94,8 +94,8 @@ __global__ void runEvaluation(
9494
int M,
9595
int K,
9696
int dims) {
97-
int id = blockIdx.x; // each block takes care of one vector
98-
int d = threadIdx.x; // each thread takes care of one dimension
97+
auto id = blockIdx.x; // each block takes care of one vector
98+
auto d = threadIdx.x; // each thread takes care of one dimension
9999
float acc = 0.0f;
100100

101101
#pragma unroll
@@ -136,7 +136,7 @@ __global__ void runCodesPerturbation(
136136
int K,
137137
int nperts) {
138138
// each thread takes care of one vector
139-
int id = blockIdx.x * blockDim.x + threadIdx.x;
139+
auto id = blockIdx.x * blockDim.x + threadIdx.x;
140140

141141
if (id >= n) {
142142
return;
@@ -173,7 +173,7 @@ __global__ void runCodesSelection(
173173
int n,
174174
int M) {
175175
// each thread takes care of one vector
176-
int id = blockIdx.x * blockDim.x + threadIdx.x;
176+
auto id = blockIdx.x * blockDim.x + threadIdx.x;
177177

178178
if (id >= n || objs[id] >= bestObjs[id]) {
179179
return;
@@ -195,8 +195,8 @@ __global__ void runCodesSelection(
195195
* @param K number of codewords in a codebook
196196
*/
197197
__global__ void runNormAddition(float* uterm, const float* norm, int K) {
198-
int id = blockIdx.x;
199-
int code = threadIdx.x;
198+
auto id = blockIdx.x;
199+
auto code = threadIdx.x;
200200

201201
uterm[id * K + code] += norm[code];
202202
}

faiss/gpu/impl/L2Norm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ __global__ void l2NormRowMajor(
4040
// these are fine to be int (just based on block dimensions)
4141
int numWarps = utils::divUp(blockDim.x, kWarpSize);
4242
int laneId = getLaneId();
43-
int warpId = threadIdx.x / kWarpSize;
43+
auto warpId = threadIdx.x / kWarpSize;
4444

4545
bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
4646
idx_t rowStart = idx_t(blockIdx.x) * RowTileSize;

0 commit comments

Comments
 (0)