
Commit 467f70e

Jeff Johnson authored and facebook-github-bot committed
Consolidate GPU IVF query tile calculation + special handling for large query memory requirements
Summary:
In the GPU IVF (Flat, SQ and PQ) code, list scanning requires temporary memory for storing the unfiltered (or partially filtered) vector distances calculated during the scan, which are then k-selected by separate kernels. When a batch of queries is presented to an IVF index, the temporary memory needed to hold all of these unfiltered distances prior to filtering can be very large, depending upon IVF characteristics (such as the maximum number of vectors encoded in any of the IVF lists). In that case we cannot process the entire batch of queries at once, and must instead tile over the batch, reusing the temporary memory we make available for these distances.

The old code duplicated this roughly equivalent logic in 3 different places (the IVFFlat/SQ code, IVFPQ with precomputed codes, and IVFPQ without precomputed codes). Furthermore, when little or no temporary memory was available, or when the available temporary memory was (vastly) exceeded by the amount needed for a particular query, the old code enforced a minimum of 8 queries processed at once. In certain cases (huge IVF list imbalance), this memory request could exceed the amount of memory that can be safely allocated on a GPU.

This diff consolidates the original 3 copies of this calculation into 1 place in IVFUtils. The logic proceeds roughly as before, figuring out how many queries can be processed in the available temporary memory, except we add a new heuristic for the case where the number of queries that can be processed concurrently falls below 8. This can happen either because little temporary memory is available, or because the per-query memory requirement is huge. In this case, we ignore the amount of temporary memory available and instead see how many queries' memory requirements would fit into a single 512 MiB allocation, capping the request at a reasonable amount. If a query still cannot be satisfied by this allocation, we proceed executing 1 query at a time (which, note, could still potentially exhaust GPU memory, but that error is unavoidable).

While a different heuristic using the amount of memory actually allocatable on the device could be used instead of this fixed 512 MiB amount, to my knowledge there is no guarantee that a single cudaMalloc up to that limit would succeed (e.g., the GPU reports 3 GiB available, and you attempt to allocate all of it in a single allocation), so we simply pick an amount that is a reasonable balance between efficiency (parallelism) and memory consumption. Note that if not enough temporary memory is available and a single 512 MiB allocation fails, then there is likely too little memory to proceed efficiently under any scenario, as Faiss does require some headroom in terms of memory available for scratch space.

Reviewed By: mdouze

Differential Revision: D45574455

fbshipit-source-id: 08f5204e3e9656627c9134d7409b9b0960f07b2d
1 parent 411c172 commit 467f70e
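To make the consolidated heuristic concrete, here is a minimal standalone sketch of the calculation described above, mirroring the new getIVFPerQueryTempMemory/getIVFQueryTileSize helpers in the diff below. The k/nprobe/maxListLength/batch values are hypothetical, chosen so the 512 MiB fallback engages:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Standalone sketch; idx_t stands in for Faiss's 64-bit index type.
using idx_t = int64_t;

int main() {
    // Hypothetical query against a badly imbalanced index: one IVF list
    // holds 1M vectors, nprobe = 64, k = 100, 256 queries in the batch.
    size_t k = 100, nprobe = 64, maxListLength = 1000000, numQueries = 256;
    size_t tempMemoryAvailable = 256ULL * 1024 * 1024; // 256 MiB of scratch

    // Per-query temp memory, following getIVFPerQueryTempMemory:
    size_t pass2Chunks = std::min(nprobe, size_t(8));
    size_t firstSelectPass = pass2Chunks * k * (sizeof(float) + sizeof(idx_t));
    size_t prefixSumOffsets = (nprobe + 1) * sizeof(idx_t);
    size_t allDistances = nprobe * maxListLength * sizeof(float);
    // 2x for the two compute streams; ~512 MB here, dominated by allDistances
    size_t sizePerQuery =
            2 * (prefixSumOffsets + allDistances + firstSelectPass);

    // Tile size, following getIVFQueryTileSize: zero queries fit in the
    // 256 MiB of temp memory (the old code would still have demanded 8 at
    // once, ~4 GB), so fall back to a single 512 MiB allocation budget,
    // never going below 1 query per tile.
    size_t tile = std::min(tempMemoryAvailable / sizePerQuery, numQueries);
    tile = std::min(tile, size_t(65536)); // max blockIdx.y dimension
    if (tile < numQueries && tile < 8) {
        constexpr size_t kMinMemoryAllocation = 512ULL * 1024 * 1024;
        tile = std::max(
                std::min(kMinMemoryAllocation / sizePerQuery, numQueries),
                size_t(1));
    }
    printf("sizePerQuery=%zu tile=%zu\n", sizePerQuery, tile); // tile=1
    return 0;
}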

File tree

5 files changed: +185 −88


faiss/gpu/impl/IVFFlatScan.cu

Lines changed: 10 additions & 28 deletions

@@ -345,10 +345,6 @@ void runIVFFlatScan(
         GpuResources* res) {
     auto stream = res->getDefaultStreamCurrentDevice();
 
-    constexpr idx_t kMinQueryTileSize = 8;
-    constexpr idx_t kMaxQueryTileSize = 65536; // used as blockIdx.y dimension
-    constexpr idx_t kThrustMemSize = 16384;
-
     auto nprobe = listIds.getSize(1);
 
     // If the maximum list length (in terms of number of vectors) times nprobe
@@ -359,37 +355,22 @@ void runIVFFlatScan(
 
     // Make a reservation for Thrust to do its dirty work (global memory
     // cross-block reduction space); hopefully this is large enough.
+    constexpr idx_t kThrustMemSize = 16384;
+
     DeviceTensor<char, 1, true> thrustMem1(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true> thrustMem2(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true>* thrustMem[2] = {&thrustMem1, &thrustMem2};
 
-    // How much temporary storage is available?
-    // If possible, we'd like to fit within the space available.
-    size_t sizeAvailable = res->getTempMemoryAvailableCurrentDevice();
-
-    // We run two passes of heap selection
-    // This is the size of the first-level heap passes
-    constexpr idx_t kNProbeSplit = 8;
-    idx_t pass2Chunks = std::min(nprobe, kNProbeSplit);
+    // How much temporary memory would we need to handle a single query?
+    size_t sizePerQuery = getIVFPerQueryTempMemory(k, nprobe, maxListLength);
 
-    idx_t sizeForFirstSelectPass =
-            pass2Chunks * k * (sizeof(float) + sizeof(idx_t));
-
-    // How much temporary storage we need per each query
-    idx_t sizePerQuery = 2 * // # streams
-            ((nprobe * sizeof(idx_t) + sizeof(idx_t)) + // prefixSumOffsets
-             nprobe * maxListLength * sizeof(float) + // allDistances
-             sizeForFirstSelectPass);
-
-    idx_t queryTileSize = sizeAvailable / sizePerQuery;
-
-    if (queryTileSize < kMinQueryTileSize) {
-        queryTileSize = kMinQueryTileSize;
-    } else if (queryTileSize > kMaxQueryTileSize) {
-        queryTileSize = kMaxQueryTileSize;
-    }
+    // How many queries do we wish to run at once?
+    idx_t queryTileSize = getIVFQueryTileSize(
+            queries.getSize(0),
+            res->getTempMemoryAvailableCurrentDevice(),
+            sizePerQuery);
 
     // Temporary memory buffers
     // Make sure there is space prior to the start which will be 0, and
@@ -428,6 +409,7 @@ void runIVFFlatScan(
     DeviceTensor<float, 1, true>* allDistances[2] = {
             &allDistances1, &allDistances2};
 
+    idx_t pass2Chunks = getIVFKSelectionPass2Chunks(nprobe);
    DeviceTensor<float, 3, true> heapDistances1(
             res,
             makeTempAlloc(AllocType::Other, stream),

faiss/gpu/impl/IVFUtils.cu

Lines changed: 108 additions & 0 deletions

@@ -18,6 +18,114 @@
 namespace faiss {
 namespace gpu {
 
+size_t getIVFKSelectionPass2Chunks(size_t nprobe) {
+    // We run two passes of heap selection
+    // This is the size of the second-level heap passes
+    constexpr size_t kNProbeSplit = 8;
+    return std::min(nprobe, kNProbeSplit);
+}
+
+size_t getIVFPerQueryTempMemory(size_t k, size_t nprobe, size_t maxListLength) {
+    size_t pass2Chunks = getIVFKSelectionPass2Chunks(nprobe);
+
+    size_t sizeForFirstSelectPass =
+            pass2Chunks * k * (sizeof(float) + sizeof(idx_t));
+
+    // Each IVF list being scanned concurrently needs a separate array to
+    // indicate where the per-IVF list distances are being stored via prefix
+    // sum. There is one per each nprobe, plus 1 more entry at the end
+    size_t prefixSumOffsets = nprobe * sizeof(idx_t) + sizeof(idx_t);
+
+    // Storage for all distances from all the IVF lists we are processing
+    size_t allDistances = nprobe * maxListLength * sizeof(float);
+
+    // There are 2 streams on which computation is performed (hence the 2 *)
+    return 2 * (prefixSumOffsets + allDistances + sizeForFirstSelectPass);
+}
+
+size_t getIVFPQPerQueryTempMemory(
+        size_t k,
+        size_t nprobe,
+        size_t maxListLength,
+        bool usePrecomputedCodes,
+        size_t numSubQuantizers,
+        size_t numSubQuantizerCodes) {
+    // Residual PQ distances per each IVF partition (in case we are not using
+    // precomputed codes)
+    size_t residualDistances = usePrecomputedCodes
+            ? 0
+            : (nprobe * numSubQuantizers * numSubQuantizerCodes *
+               sizeof(float));
+
+    // There are 2 streams on which computation is performed (hence the 2 *)
+    // The IVF-generic temp memory allocation already takes this multi-streaming
+    // into account, but we need to do so for the PQ residual distances too
+    return (2 * residualDistances) +
+            getIVFPerQueryTempMemory(k, nprobe, maxListLength);
+}
+
+size_t getIVFQueryTileSize(
+        size_t numQueries,
+        size_t tempMemoryAvailable,
+        size_t sizePerQuery) {
+    // Our ideal minimum number of queries that we'd like to run concurrently
+    constexpr size_t kMinQueryTileSize = 8;
+
+    // Our absolute maximum number of queries that we can run concurrently
+    // (based on max Y grid dimension)
+    constexpr size_t kMaxQueryTileSize = 65536;
+
+    // First, see how many queries we can run within the limit of our available
+    // temporary memory. If all queries can run within the temporary memory
+    // limit, we'll just use that.
+    size_t withinTempMemoryNumQueries =
+            std::min(tempMemoryAvailable / sizePerQuery, numQueries);
+
+    // However, there is a maximum cap on the number of queries that we can run
+    // at once, even if memory were unlimited (due to max Y grid dimension)
+    withinTempMemoryNumQueries =
+            std::min(withinTempMemoryNumQueries, kMaxQueryTileSize);
+
+    // However, withinTempMemoryNumQueries could be really small, or even zero
+    // (in the case where there is no temporary memory available, or the memory
+    // resources required for a single query are really large). If we are below
+    // the ideal minimum number of queries to run concurrently, then we will
+    // ignore the temporary memory limit and fall back to a general device
+    // allocation.
+    // Note that if we only had a single query, then this is ok to run as-is
+    if (withinTempMemoryNumQueries < numQueries &&
+        withinTempMemoryNumQueries < kMinQueryTileSize) {
+        // Either the amount of temporary memory available is too low, or the
+        // amount of memory needed to run a single query is really high. Ignore
+        // the temporary memory available, and always attempt to use this amount
+        // of memory for temporary results
+        //
+        // FIXME: could look at amount of memory available on the current
+        // device, but there is no guarantee that all that memory available
+        // could be done in a single allocation, so we just pick a suitably
+        // large allocation that can yield enough efficiency but something that
+        // the GPU can likely allocate.
+        constexpr size_t kMinMemoryAllocation = 512 * 1024 * 1024; // 512 MiB
+
+        size_t withinMemoryNumQueries =
+                std::min(kMinMemoryAllocation / sizePerQuery, numQueries);
+
+        // It is possible that the per-query size is incredibly huge, in which
+        // case even the 512 MiB allocation will not fit it. In this case, we
+        // have no option except to try running a single one.
+        return std::max(withinMemoryNumQueries, size_t(1));
+    } else {
+        // withinTempMemoryNumQueries cannot be > numQueries.
+        // Either:
+        // 1. == numQueries, >= kMinQueryTileSize (i.e., we can satisfy all
+        // queries in one go, or are limited by max query tile size)
+        // 2. < numQueries, >= kMinQueryTileSize (i.e., we can't satisfy all
+        // queries in one go, but we have a large enough batch to run which is
+        // ok)
+        return withinTempMemoryNumQueries;
+    }
+}
+
 // Calculates the total number of intermediate distances to consider
 // for all queries
 __global__ void getResultLengths(
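As a quick behavioral check of getIVFQueryTileSize as defined above, here is a hypothetical, standalone test-style sketch (not part of the diff; it assumes linking against the implementation, and all numbers are made up for illustration):

#include <cassert>
#include <cstddef>

// Declaration from faiss/gpu/impl/IVFUtils.cuh
namespace faiss {
namespace gpu {
size_t getIVFQueryTileSize(
        size_t numQueries,
        size_t tempMemoryAvailable,
        size_t sizePerQuery);
} // namespace gpu
} // namespace faiss

int main() {
    using faiss::gpu::getIVFQueryTileSize;
    const size_t MiB = 1024 * 1024;

    // Plenty of temp memory: the whole batch runs in one tile.
    assert(getIVFQueryTileSize(100, 1024 * MiB, 1 * MiB) == 100);

    // Temp memory only fits 4 queries (< 8), so the 512 MiB fallback kicks
    // in: 512 MiB / 16 MiB = 32 queries per tile.
    assert(getIVFQueryTileSize(1000, 64 * MiB, 16 * MiB) == 32);

    // Even 512 MiB cannot hold one query's scratch: run 1 query at a time.
    assert(getIVFQueryTileSize(10, 0, 600 * MiB) == 1);

    // Huge batch of cheap queries is still capped at 65536 (max grid Y).
    assert(getIVFQueryTileSize(1000000, 1024 * MiB, 1) == 65536);
    return 0;
}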

faiss/gpu/impl/IVFUtils.cuh

Lines changed: 33 additions & 0 deletions

@@ -19,6 +19,39 @@ namespace gpu {
 
 class GpuResources;
 
+/// For the final k-selection of IVF query distances, we perform two passes.
+/// The first pass scans some number of per-IVF list distances reducing them to
+/// at most 8, then a second pass processes these <= 8 to the single final list
+/// of NN candidates
+size_t getIVFKSelectionPass2Chunks(size_t nprobe);
+
+/// Function to determine amount of temporary space that we allocate
+/// for storing basic IVF list scanning distances during query, based on the
+/// memory allocation per query. This is the memory requirement for
+/// IVFFlat/IVFSQ but IVFPQ will add some additional allocation as well (see
+/// getIVFPQPerQueryTempMemory)
+size_t getIVFPerQueryTempMemory(size_t k, size_t nprobe, size_t maxListLength);
+
+/// Function to determine amount of temporary space that we allocate
+/// for storing basic IVFPQ list scanning distances during query, based on the
+/// memory allocation per query.
+size_t getIVFPQPerQueryTempMemory(
+        size_t k,
+        size_t nprobe,
+        size_t maxListLength,
+        bool usePrecomputedCodes,
+        size_t numSubQuantizers,
+        size_t numSubQuantizerCodes);
+
+/// Based on the amount of temporary memory needed per IVF query (determined by
+/// one of the above functions) and the amount of current temporary memory
+/// available, determine how many queries we will run concurrently in a single
+/// tile so as to stay within reasonable temporary memory allocation limits.
+size_t getIVFQueryTileSize(
+        size_t numQueries,
+        size_t tempMemoryAvailable,
+        size_t sizePerQuery);
+
 /// Function for multi-pass scanning that collects the length of
 /// intermediate results for all (query, probe) pair
 void runCalcListOffsets(
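For a sense of scale of the PQ-specific term in getIVFPQPerQueryTempMemory, a small hypothetical computation (parameter values chosen for illustration only):

#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical IVFPQ setup without precomputed codes: nprobe = 32,
    // 16 sub-quantizers, each with 8-bit codes (256 centroids).
    size_t nprobe = 32, numSubQuantizers = 16, numSubQuantizerCodes = 256;

    // Residual distance tables per query, as in getIVFPQPerQueryTempMemory;
    // doubled for the two compute streams.
    size_t residualDistances =
            nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float);
    size_t extra = 2 * residualDistances; // 1 MiB per query, on top of the
                                          // IVF-generic requirement
    printf("extra PQ bytes per query = %zu\n", extra);
    return 0;
}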

faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh

Lines changed: 17 additions & 31 deletions

@@ -550,10 +550,6 @@ void runPQScanMultiPassNoPrecomputed(
         GpuResources* res) {
     auto stream = res->getDefaultStreamCurrentDevice();
 
-    constexpr idx_t kMinQueryTileSize = 8;
-    constexpr idx_t kMaxQueryTileSize = 65536; // typical max gridDim.y
-    constexpr idx_t kThrustMemSize = 16384;
-
     auto nprobe = coarseIndices.getSize(1);
 
     // If the maximum list length (in terms of number of vectors) times nprobe
@@ -566,39 +562,28 @@ void runPQScanMultiPassNoPrecomputed(
 
     // Make a reservation for Thrust to do its dirty work (global memory
     // cross-block reduction space); hopefully this is large enough.
+    constexpr idx_t kThrustMemSize = 16384;
+
     DeviceTensor<char, 1, true> thrustMem1(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true> thrustMem2(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true>* thrustMem[2] = {&thrustMem1, &thrustMem2};
 
-    // How much temporary storage is available?
-    // If possible, we'd like to fit within the space available.
-    idx_t sizeAvailable = res->getTempMemoryAvailableCurrentDevice();
-
-    // We run two passes of heap selection
-    // This is the size of the first-level heap passes
-    constexpr idx_t kNProbeSplit = 8;
-    idx_t pass2Chunks = std::min(nprobe, kNProbeSplit);
-
-    idx_t sizeForFirstSelectPass =
-            pass2Chunks * k * (sizeof(float) + sizeof(idx_t));
-
-    // How much temporary storage we need per each query
-    idx_t sizePerQuery = 2 * // streams
-            ((nprobe * sizeof(idx_t) + sizeof(idx_t)) + // prefixSumOffsets
-             nprobe * maxListLength * sizeof(float) + // allDistances
-             // residual distances
-             nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
-             sizeForFirstSelectPass);
-
-    idx_t queryTileSize = (sizeAvailable / sizePerQuery);
-
-    if (queryTileSize < kMinQueryTileSize) {
-        queryTileSize = kMinQueryTileSize;
-    } else if (queryTileSize > kMaxQueryTileSize) {
-        queryTileSize = kMaxQueryTileSize;
-    }
+    // How much temporary memory would we need to handle a single query?
+    size_t sizePerQuery = getIVFPQPerQueryTempMemory(
+            k,
+            nprobe,
+            maxListLength,
+            false, /* no precomputed codes */
+            numSubQuantizers,
+            numSubQuantizerCodes);
+
+    // How many queries do we wish to run at once?
+    idx_t queryTileSize = getIVFQueryTileSize(
+            queries.getSize(0),
+            res->getTempMemoryAvailableCurrentDevice(),
+            sizePerQuery);
 
     // Temporary memory buffers
     // Make sure there is space prior to the start which will be 0, and
@@ -664,6 +649,7 @@ void runPQScanMultiPassNoPrecomputed(
     DeviceTensor<float, 1, true>* allDistances[2] = {
             &allDistances1, &allDistances2};
 
+    idx_t pass2Chunks = getIVFKSelectionPass2Chunks(nprobe);
     DeviceTensor<float, 3, true> heapDistances1(
             res,
             makeTempAlloc(AllocType::Other, stream),

faiss/gpu/impl/PQScanMultiPassPrecomputed.cu

Lines changed: 17 additions & 29 deletions

@@ -562,10 +562,6 @@ void runPQScanMultiPassPrecomputed(
         GpuResources* res) {
     auto stream = res->getDefaultStreamCurrentDevice();
 
-    constexpr idx_t kMinQueryTileSize = 8;
-    constexpr idx_t kMaxQueryTileSize = 65536; // typical max gridDim.y
-    constexpr idx_t kThrustMemSize = 16384;
-
    auto nprobe = ivfListIds.getSize(1);
 
     // If the maximum list length (in terms of number of vectors) times nprobe
@@ -578,37 +574,28 @@ void runPQScanMultiPassPrecomputed(
 
     // Make a reservation for Thrust to do its dirty work (global memory
     // cross-block reduction space); hopefully this is large enough.
+    constexpr idx_t kThrustMemSize = 16384;
+
     DeviceTensor<char, 1, true> thrustMem1(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true> thrustMem2(
             res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
     DeviceTensor<char, 1, true>* thrustMem[2] = {&thrustMem1, &thrustMem2};
 
-    // How much temporary storage is available?
-    // If possible, we'd like to fit within the space available.
-    size_t sizeAvailable = res->getTempMemoryAvailableCurrentDevice();
-
-    // We run two passes of heap selection
-    // This is the size of the first-level heap passes
-    constexpr idx_t kNProbeSplit = 8;
-    idx_t pass2Chunks = std::min(nprobe, kNProbeSplit);
-
-    idx_t sizeForFirstSelectPass =
-            pass2Chunks * k * (sizeof(float) + sizeof(idx_t));
-
-    // How much temporary storage we need per each query
-    idx_t sizePerQuery = 2 * // # streams
-            ((nprobe * sizeof(idx_t) + sizeof(idx_t)) + // prefixSumOffsets
-             nprobe * maxListLength * sizeof(float) + // allDistances
-             sizeForFirstSelectPass);
-
-    idx_t queryTileSize = sizeAvailable / sizePerQuery;
-
-    if (queryTileSize < kMinQueryTileSize) {
-        queryTileSize = kMinQueryTileSize;
-    } else if (queryTileSize > kMaxQueryTileSize) {
-        queryTileSize = kMaxQueryTileSize;
-    }
+    // How much temporary memory would we need to handle a single query?
+    size_t sizePerQuery = getIVFPQPerQueryTempMemory(
+            k,
+            nprobe,
+            maxListLength,
+            true, /* precomputed codes */
+            numSubQuantizers,
+            numSubQuantizerCodes);
+
+    // How many queries do we wish to run at once?
+    idx_t queryTileSize = getIVFQueryTileSize(
+            queries.getSize(0),
+            res->getTempMemoryAvailableCurrentDevice(),
+            sizePerQuery);
 
     // Temporary memory buffers
     // Make sure there is space prior to the start which will be 0, and
@@ -647,6 +634,7 @@ void runPQScanMultiPassPrecomputed(
     DeviceTensor<float, 1, true>* allDistances[2] = {
             &allDistances1, &allDistances2};
 
+    idx_t pass2Chunks = getIVFKSelectionPass2Chunks(nprobe);
     DeviceTensor<float, 3, true> heapDistances1(
             res,
             makeTempAlloc(AllocType::Other, stream),
