-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Integrate IVF-PQ from RAFT #3044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 222 commits
Commits
Show all changes
253 commits
Select commit
Hold shift + click to select a range
8baee52
Updating raft ivf flat
cjnolet 2eb94f1
adding raftIVFFlat implementation
cjnolet f7d4185
Isolating quantizer training
cjnolet 26491cb
iAdding todos where we need to plug in raft functionality
cjnolet b4d08c4
Invocatino of index building has been compiled successfully. Still ne…
cjnolet bf876f9
Adding call to search.
cjnolet 9b1fc84
Adding stubs for remaining calls that need to be made from RAFT side in
cjnolet 884bfa5
iUpdating function calls for copyFrom to include populating the quant…
cjnolet 0958d2e
Implement some helpers
achirkin b7144a9
Make it compile
achirkin 38733bb
Make the tests to not crash... sometimes
achirkin 881fbc3
Merge pull request #1 from achirkin/raft_ivf_flat
cjnolet 173c459
Updates
cjnolet 8b7afe0
Merge branch 'raft_ivf_flat' of github.com:cjnolet/faiss into raft_iv…
cjnolet 548e0f0
More updates
cjnolet baa34d7
One test running so far.
cjnolet edc5991
Setting add_data_on_build = false;
cjnolet 10f89b4
Copying centroids directly and adding some prints for the test outputs
cjnolet 7c69020
reconstructions seems to be reasonable
cjnolet 8be7746
iUpdates to tests to compare against brute force as ground truth
cjnolet 933582a
Starting to look at resulting runtimes in raft ivf flat tests
cjnolet d2a6541
Adding timing info to raft test
cjnolet 986407a
Updating for rapids-cmake updates and RAFT updates
cjnolet ae4ed98
Adding RaftIndexIVFPQ
cjnolet d8894b8
Merge branch 'main' into raft_integration
cjnolet 410b2c6
Updates
cjnolet d7ca6b4
Adding FAISS_ENABLE_RAFT option to INSTALL.md
cjnolet 9875dad
Making build.sh work for quick building of proposal
cjnolet 60388dc
Merge branch 'main' into raft_integration
cjnolet c09d09b
Merging upstream
cjnolet 0081ed9
Integrating more deeply with `use_raft` option in the index config that
cjnolet a7e0cdd
IVF Flat
cjnolet fbf7e34
More updates
cjnolet a9b6963
Getting things building again. Adding raft handle to gpu resources.
cjnolet b640ba8
Getting FAISS building again w/ RaftIVFFlat
cjnolet af6d1e9
Adding the append vectors to raft index IVF flat.
cjnolet 545b3d2
Add ing flatindex for the fused l2 knn
cjnolet 2ac5a5b
Validating dispatch of flatindex
cjnolet 68944a5
1. Verified FlatIndex tests are passing (and using RAFT for k<=64 L2 …
cjnolet 3a37031
Calling train() on copyFrom() with reconstructed vectors and filling in
cjnolet 3f51425
IVFFlat gtests run through to completion without crash. Distances look
cjnolet db1801e
Some of the IVFFlat tests are passing.
cjnolet f0bbd41
CLeaning up the diff a bit
cjnolet f7da008
Removing the RaftIndex* files.
cjnolet 5ab762b
Using current raft 22.12
cjnolet 3684cd3
Checking in a little cleanup
cjnolet 35a46b2
Disabling raft from pulling in nn dependencies (e.g. faiss)
cjnolet eb9f6e9
Merge branch 'main' into raft_integration
cjnolet e7bf2e5
Updating raft for 23.02. Still working on failing tests.
cjnolet a8e2ad0
Isolating differences in results- it looks like it's related to the s…
cjnolet f19fd00
Add and query results appear to match well. LargeBatch tests are fail…
cjnolet 3ff97ab
Merge branch 'main' into raft_integration
cjnolet 1d2baed
Merge branch 'main' into raft_integration
cjnolet 6269ed1
Using facebook for licenses in cmake files
cjnolet b13593a
Adding small note to build.sh that the file is temporary.
cjnolet 81fbe64
Merge branch 'main' into raft_integration
cjnolet 333761c
Merge branch 'main' into raft_integration
cjnolet bc8885d
Fixing style
cjnolet 093579b
Merge branch 'master' into raft_integration
cjnolet 10f8080
Merge branch 'raft_integration' of github.com:cjnolet/faiss into raft…
cjnolet 19f38d4
Second pass of fixing formatting
cjnolet 17df798
Third pass at fixing format style
cjnolet 2993441
Adding nvidia license for traceability
cjnolet 5e7eb6d
Updates
cjnolet d1b0036
Merge remote-tracking branch 'faiss/main' into raft_integration
cjnolet ddc75ac
Merging
cjnolet 4ada77c
Merge branch 'main' into raft_integration
cjnolet ccc3bad
Merge branch 'main' into raft_integration
cjnolet 37ec2fa
Fix PR problems (#2839)
c07208e
Merge remote-tracking branch 'faiss/main' into raft_integration
cjnolet d3a98cc
Fixing cmakelists
cjnolet d91de3c
Merge remote-tracking branch 'alexanderguzhva/export-D45054275' into …
cjnolet 0af95a4
Updates
cjnolet 36bed23
Merge branch 'main' into raft_integration
cjnolet fc7f1f8
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 8769115
Merge branch 'main' into raft_integration
cjnolet 576f58f
Fixing merge
cjnolet eef2b28
Merge branch 'raft_integration' of https://github.com/cjnolet/faiss i…
tarang-jain 092721f
Removing indexflat tests from changeset
cjnolet 1c621ad
First version of copyFrom and copyTo
tarang-jain 62b568b
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 808f1d1
Merge branch 'raft_integration' of https://github.com/cjnolet/faiss i…
tarang-jain b8d616d
Update copyFrom and copyTo
tarang-jain cf87175
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 444c58d
Passing tests
tarang-jain 2887af8
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain f148f09
Passing copyTo
tarang-jain 284937b
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 5756508
All tests passing
tarang-jain c5edbf7
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 82e9791
cleanup
tarang-jain 8486b9b
cleanup
tarang-jain 38215bc
cleanup
tarang-jain ac67897
cleanup
tarang-jain 94817aa
cleanup
tarang-jain 91b1e32
cleanup
tarang-jain 613ca7a
cleanup
tarang-jain c43c83f
Separate out nan filtering
tarang-jain 7eb5209
Add USE_NVIDIA_RAFT
tarang-jain db1774b
Update test
tarang-jain 8cf7e05
update quantizer
tarang-jain ad9a596
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain a17b1f3
All except LongIVFList passing
tarang-jain d90d923
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 3c33ebb
Formatting
tarang-jain 971a6b2
Format
tarang-jain 5c0592e
remove debug statements
tarang-jain 7618b44
LargeBatch test added and now passing
tarang-jain bd5a217
final update to gtests
tarang-jain 2022a14
Pull latest
tarang-jain d0f8385
Merge branch 'main' into raft_integration
tarang-jain 5e52aca
im
tarang-jain a39d208
Initial machinery for RAFT support
tarang-jain 59236cb
merge main
tarang-jain c5db6cd
some ivf-flat assertions
tarang-jain 15ee365
More changes to IVF-Flat, fill in more funcs for IVF-PQ
tarang-jain 5c1e671
Filling in funcs for copyFrom and copyTo
tarang-jain c59f551
Update copyfrom
tarang-jain 6170c8c
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 4400d4f
testing
tarang-jain ff77e55
testing
tarang-jain f942371
update cmake for testing
tarang-jain ba2f8d5
expt
tarang-jain c917c81
Update impl, testing
tarang-jain 3e6fbea
GpuIndexIVF train working
tarang-jain 5943124
Raft index external update successful
tarang-jain f000652
ivfflat tests passing
tarang-jain 6b790f2
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 4a3254b
all tests passing except assertions
tarang-jain baaa844
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 22bf90c
cleanup
tarang-jain 66222ff
cleanup
tarang-jain 229ceda
correction
tarang-jain c63764a
Style
tarang-jain c4e48bf
cleanup
tarang-jain e940c75
style
tarang-jain dfebcb9
remove cudart_utils import
tarang-jain 925682e
Update warning msg
tarang-jain a718bd5
PQ assign_index WARN
tarang-jain ae1b243
Mark raft symbols with hidden visibility
robertmaynard 40864df
Update index_cpu_to_gpu
tarang-jain 867c0e8
benchmarking script for ivfflat
tarang-jain 8e5a3dc
Update IVF-PQ BM script
tarang-jain 5b6db4e
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 0a806bb
Update IVF-Flat BM script
tarang-jain ececb62
Docs, rename API
tarang-jain 217c7a6
format
tarang-jain 15411db
format
tarang-jain f81a9db
resolve build issue
tarang-jain 7cb7316
resolve failing copyTo_Raft
tarang-jain e993c01
style
tarang-jain 934a835
resolve failing LargeBatch
tarang-jain fd7c84d
update BM script, LargeBatch_Raft test
tarang-jain 575fe82
make build with raft-ann-bench, update bench to sift1M
tarang-jain 9a8492f
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain dbcc576
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain c8ea4a4
add sync_stream to search
tarang-jain a145998
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 0c82e30
Update raft add and search
tarang-jain 53bf4c0
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 4a74523
first commit
tarang-jain 9600fe0
Update AllocRequest
tarang-jain d73129d
use only custom MR for managed allocs
tarang-jain 480859a
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 05bcf38
remove dbg statements
tarang-jain 990c9d0
remove bugs
tarang-jain b894248
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain e0c3d05
bug fix
tarang-jain 401e640
address PR reviews
tarang-jain 394738b
Merge branch 'main' into rmm-mem-alloc
tarang-jain 2078c54
Merge branch 'main' into rmm-mem-alloc
tarang-jain 045537b
updated until rmm-pool-alloc cleanup 92e75b52433553d48e27937bf86fd922…
tarang-jain bb87df3
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 69fcfeb
remove bug due to nullptr indices
tarang-jain 8ce165e
Merge branch 'main' into rmm-mem-alloc
tarang-jain c9a23a5
update benchmarking, testing scripts
tarang-jain bc33838
Merge branch 'main' into raft_integration
tarang-jain 366dda3
format
tarang-jain ebec04c
Merge branch 'raft_integration' of https://github.com/tarang-jain/fai…
tarang-jain cf98b7d
Merge branch 'rmm-mem-alloc' of https://github.com/tarang-jain/faiss …
tarang-jain f3cc680
Merge branch 'main' into raft_integration
tarang-jain e2777d9
Removing build.sh
cjnolet 80742c9
Merge branch 'main' into raft_integration2
cjnolet 83124ae
Trying to pin fmt
cjnolet 130caa8
Merge branch 'raft_integration2' of github.com:cjnolet/faiss into raf…
cjnolet 2065e94
Adding conditional to disable RAFT if arch == pascal
cjnolet f24272b
Fixing a few more places
cjnolet abfc0e1
Fixing some syntax errors
cjnolet d7caab4
Fixing style
cjnolet 9c0ed23
DIsabling below volta
cjnolet 23f3041
A couple additional fixes
cjnolet f57e467
More updates
cjnolet 4b00a16
Missed one
cjnolet 40e4c11
Fixing style
cjnolet f66d1fe
More style fixes
cjnolet 24b6b74
Fixing another bug
cjnolet 6234107
Fixing RaftUtils in cmakelists
cjnolet f0985f4
Updates after PR reviews
tarang-jain 692e9bd
Merge branch 'raft_integration' of https://github.com/tarang-jain/fai…
tarang-jain 0469924
merge
tarang-jain 131b191
Merge branch 'main' into raft_integration
tarang-jain 46afd1e
Updates after PR reviews
tarang-jain bee418b
merge
tarang-jain f90b3e6
merge
tarang-jain 44bd50a
FIxing cmakelists
cjnolet 92475aa
changes to gpu CMakeLists
tarang-jain 2397bf3
changes to gpu CMakeLists
tarang-jain dfb7d32
changes to gpu CMakeLists
tarang-jain c2afb4c
changes to gpu CMakeLists
tarang-jain 1483389
restore comment
tarang-jain 496e923
changes to gpu CMakeLists
tarang-jain da6103d
FIxing cmakelists
cjnolet fcfadd4
changes to gpu CMakeLists
tarang-jain 03e7edb
changes to gpu CMakeLists
tarang-jain bfbc9a4
rebase
tarang-jain ca234e7
fix gpu CMakeLists after rebase
tarang-jain 248c1b9
fix gpu CMakeLists after rebase
tarang-jain 8d0eade
update latest changes from upstream to gpu CMakeLists
tarang-jain ac8c9a1
small style fix
tarang-jain 1d79f3c
GCC visibility for more files
tarang-jain 5790b48
Changes to gtests for default use_raft=true
tarang-jain 7f1d848
more small changes to IVFPQ tests
tarang-jain 5579056
small change in comment
tarang-jain 315d1b8
cleanup
tarang-jain 53329b7
temporary fix for failing CI test
tarang-jain 6125884
some updates to python tests
tarang-jain b5e8fcc
small change, remove fmt from faiss-gpu-raft meta.yaml
tarang-jain b4d9b78
use_raft in trainResidualQuantizer_
tarang-jain 77336a8
change interleavedLayout assertion to warning
tarang-jain a33029f
undo small change
tarang-jain 89368d4
update python tests
tarang-jain 66a8d63
format
tarang-jain 9be05b5
empty commit
tarang-jain bb9e0a3
more use_raft = False for failing tests
tarang-jain 2d73896
undo last change
tarang-jain 4e0e3de
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain be09adc
small change to ivf training
tarang-jain a6706dd
format
tarang-jain da2a425
updates after PR reviews
tarang-jain 24340ab
update niter
tarang-jain fda9ac8
undo change to getListLength
tarang-jain 1b05400
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain 7a0b61c
update should_use_raft, fix typo
tarang-jain c48f55c
Merge branch 'main' into raft_integration
tarang-jain 1fcd272
Merge branch 'main' into raft_integration
tarang-jain 4db273b
change default use_raft in GpuClonerOptions and GpuDistance
tarang-jain 059d036
Merge branch 'raft_integration' of https://github.com/tarang-jain/fai…
tarang-jain 2b33f28
Merge branch 'main' of https://github.com/facebookresearch/faiss into…
tarang-jain fad85a1
changes to pass CI
tarang-jain 17203b2
passing py tests
tarang-jain File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,193 @@ | ||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||
| # | ||
| # This source code is licensed under the MIT license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| # Copyright (c) 2023, NVIDIA CORPORATION. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| import numpy as np | ||
| import faiss | ||
| import time | ||
| import argparse | ||
| import rmm | ||
|
|
||
| ###################################################### | ||
| # Command-line parsing | ||
| ###################################################### | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
|
|
||
|
|
||
| def aa(*args, **kwargs): | ||
| group.add_argument(*args, **kwargs) | ||
|
|
||
|
|
||
| group = parser.add_argument_group('benchmarking options') | ||
|
|
||
| aa('--bm_train', default=False, action='store_true', | ||
| help='whether to benchmark train operation on GPU index') | ||
| aa('--bm_add', default=False, action='store_true', | ||
| help='whether to benchmark add operation on GPU index') | ||
| aa('--bm_search', default=True, | ||
| help='whether to benchmark search operation on GPU index') | ||
| aa('--raft_only', default=False, action='store_true', | ||
| help='whether to only produce RAFT enabled benchmarks') | ||
|
|
||
|
|
||
| group = parser.add_argument_group('IVF options') | ||
| aa('--n_centroids', default=256, type=int, | ||
| help="number of IVF centroids") | ||
|
|
||
|
|
||
| group = parser.add_argument_group('searching') | ||
|
|
||
| aa('--k', default=100, type=int, help='nb of nearest neighbors') | ||
| aa('--nprobe', default=50, help='nb of IVF lists to probe') | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| print("args:", args) | ||
|
|
||
| rs = np.random.RandomState(123) | ||
|
|
||
| res = faiss.StandardGpuResources() | ||
|
|
||
| # Use an RMM pool memory resource for device allocations | ||
| mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource()) | ||
| rmm.mr.set_current_device_resource(mr) | ||
|
|
||
| def bench_train_milliseconds(index, trainVecs, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| co.use_raft = use_raft | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| t0 = time.time() | ||
| index_gpu.train(trainVecs) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
|
|
||
| if args.bm_train: | ||
| print("=" * 40) | ||
| print("GPU Train Benchmarks") | ||
| print("=" * 40) | ||
| trainset_sizes = [5000, 10000, 100000, 1000000, 5000000] | ||
| dataset_dims = [128, 256, 1024] | ||
| for n_rows in trainset_sizes: | ||
| for n_cols in dataset_dims: | ||
| index = faiss.index_factory(n_cols, "IVF{},Flat".format(args.n_centroids)) | ||
| trainVecs = rs.rand(n_rows, n_cols).astype('float32') | ||
| raft_gpu_train_time = bench_train_milliseconds( | ||
| index, trainVecs, True) | ||
| if args.raft_only: | ||
| print("Method: IVFFlat, Operation: TRAIN, dim: %d, n_centroids %d, numTrain: %d, RAFT enabled GPU train time: %.3f milliseconds" % ( | ||
| n_cols, args.n_centroids, n_rows, raft_gpu_train_time)) | ||
| else: | ||
| classical_gpu_train_time = bench_train_milliseconds( | ||
| index, trainVecs, False) | ||
| print("Method: IVFFlat, Operation: TRAIN, dim: %d, n_centroids %d, numTrain: %d, classical GPU train time: %.3f milliseconds, RAFT enabled GPU train time: %.3f milliseconds" % ( | ||
| n_cols, args.n_centroids, n_rows, classical_gpu_train_time, raft_gpu_train_time)) | ||
|
|
||
|
|
||
| def bench_add_milliseconds(index, addVecs, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| co.use_raft = use_raft | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| index_gpu.copyFrom(index) | ||
| t0 = time.time() | ||
| index_gpu.add(addVecs) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
|
|
||
| if args.bm_add: | ||
| print("=" * 40) | ||
| print("GPU Add Benchmarks") | ||
| print("=" * 40) | ||
| addset_sizes = [5000, 10000, 100000, 1000000] | ||
| dataset_dims = [128, 256, 1024] | ||
| n_train = 10000 | ||
| trainVecs = rs.rand(n_train, n_cols).astype('float32') | ||
| index = faiss.index_factory( | ||
| n_cols, "IVF" + str(args.n_centroids) + ",Flat") | ||
| index.train(trainVecs) | ||
| for n_rows in addset_sizes: | ||
| for n_cols in dataset_dims: | ||
| addVecs = rs.rand(n_rows, n_cols).astype('float32') | ||
| raft_gpu_add_time = bench_add_milliseconds(index, addVecs, True) | ||
| if args.raft_only: | ||
| print("Method: IVFFlat, Operation: ADD, dim: %d, n_centroids %d, numAdd: %d, RAFT enabled GPU add time: %.3f milliseconds" % ( | ||
| n_train, n_rows, n_cols, args.n_centroids, raft_gpu_add_time)) | ||
| else: | ||
| classical_gpu_add_time = bench_add_milliseconds( | ||
| index, addVecs, False) | ||
| print("Method: IVFFlat, Operation: ADD, dim: %d, n_centroids %d, numAdd: %d, classical GPU add time: %.3f milliseconds, RAFT enabled GPU add time: %.3f milliseconds" % ( | ||
| n_train, n_rows, n_cols, args.n_centroids, classical_gpu_add_time, raft_gpu_add_time)) | ||
|
|
||
|
|
||
| def bench_search_milliseconds(index, addVecs, queryVecs, nprobe, k, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| co.use_raft = use_raft | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| index_gpu.copyFrom(index) | ||
| index_gpu.add(addVecs) | ||
| index_gpu.nprobe = nprobe | ||
| t0 = time.time() | ||
| index_gpu.search(queryVecs, k) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
|
|
||
| if args.bm_search: | ||
| print("=" * 40) | ||
| print("GPU Search Benchmarks") | ||
| print("=" * 40) | ||
| queryset_sizes = [5000, 10000, 100000, 500000] | ||
| n_train = 10000 | ||
| n_add = 100000 | ||
| search_bm_dims = [8, 16, 32] | ||
| for n_cols in search_bm_dims: | ||
| index = faiss.index_factory(n_cols, "IVF{},Flat".format(args.n_centroids)) | ||
| trainVecs = rs.rand(n_train, n_cols).astype('float32') | ||
| index.train(trainVecs) | ||
| addVecs = rs.rand(n_add, n_cols).astype('float32') | ||
| for n_rows in queryset_sizes: | ||
| queryVecs = rs.rand(n_rows, n_cols).astype('float32') | ||
| raft_gpu_search_time = bench_search_milliseconds( | ||
| index, addVecs, queryVecs, args.nprobe, args.k, True) | ||
| if args.raft_only: | ||
| print("Method: IVFFlat, Operation: SEARCH, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, RAFT enabled GPU search time: %.3f milliseconds" % ( | ||
| n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, raft_gpu_search_time)) | ||
| else: | ||
| classical_gpu_search_time = bench_search_milliseconds( | ||
| index, addVecs, queryVecs, args.nprobe, args.k, False) | ||
| print("Method: IVFFlat, Operation: SEARCH, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU search time: %.3f milliseconds, RAFT enabled GPU search time: %.3f milliseconds" % ( | ||
| n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, classical_gpu_search_time, raft_gpu_search_time)) | ||
|
|
||
| print("=" * 40) | ||
| print("Large RAFT Enabled Benchmarks") | ||
| print("=" * 40) | ||
| # Avoid classical GPU Benchmarks for large datasets because of OOM for more than 500000 queries and/or large dims as well as for large k | ||
| queryset_sizes = [100000, 500000, 1000000] | ||
| large_search_bm_dims = [128, 256, 1024] | ||
| for n_cols in large_search_bm_dims: | ||
| trainVecs = rs.rand(n_train, n_cols).astype('float32') | ||
| index = faiss.index_factory( | ||
| n_cols, "IVF" + str(args.n_centroids) + ",Flat") | ||
| index.train(trainVecs) | ||
| addVecs = rs.rand(n_add, n_cols).astype('float32') | ||
| for n_rows in queryset_sizes: | ||
| queryVecs = rs.rand(n_rows, n_cols).astype('float32') | ||
| raft_gpu_search_time = bench_search_milliseconds( | ||
| index, addVecs, queryVecs, args.nprobe, args.k, True) | ||
| print("Method: IVFFlat, Operation: SEARCH, numTrain: %d, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, RAFT enabled GPU search time: %.3f milliseconds" % ( | ||
| n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, raft_gpu_search_time)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||
| # | ||
| # This source code is licensed under the MIT license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| # Copyright (c) 2023, NVIDIA CORPORATION. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| import numpy as np | ||
| import faiss | ||
| import time | ||
| import argparse | ||
| import rmm | ||
|
|
||
| ###################################################### | ||
| # Command-line parsing | ||
| ###################################################### | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
|
|
||
| from datasets import load_sift1M, evaluate | ||
|
|
||
|
|
||
| print("load data") | ||
| xb, xq, xt, gt = load_sift1M() | ||
|
|
||
| def aa(*args, **kwargs): | ||
| group.add_argument(*args, **kwargs) | ||
|
|
||
|
|
||
| group = parser.add_argument_group('benchmarking options') | ||
| aa('--raft_only', default=False, action='store_true', | ||
| help='whether to only produce RAFT enabled benchmarks') | ||
|
|
||
| group = parser.add_argument_group('IVF options') | ||
| aa('--bits_per_code', default=8, type=int, help='bits per code. Note that < 8 is only supported when RAFT is enabled') | ||
| aa('--pq_len', default=2, type=int, help='number of vector elements represented by one PQ code') | ||
| aa('--use_precomputed', default=True, type=bool, help='use precomputed codes (not with RAFT enabled)') | ||
|
|
||
| group = parser.add_argument_group('searching') | ||
| aa('--k', default=10, type=int, help='nb of nearest neighbors') | ||
| aa('--nprobe', default=50, type=int, help='nb of IVF lists to probe') | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| print("args:", args) | ||
|
|
||
| rs = np.random.RandomState(123) | ||
|
|
||
| res = faiss.StandardGpuResources() | ||
|
|
||
| # Use an RMM pool memory resource for device allocations | ||
| mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource()) | ||
| rmm.mr.set_current_device_resource(mr) | ||
|
|
||
| # A heuristic to select a suitable number of lists | ||
| def compute_nlist(numVecs): | ||
| nlist = np.sqrt(numVecs) | ||
| if (numVecs / nlist < 1000): | ||
| nlist = numVecs / 1000 | ||
| return int(nlist) | ||
|
|
||
|
|
||
| def bench_train_milliseconds(index, trainVecs, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| # use float 16 lookup tables to save space | ||
| co.useFloat16LookupTables = True | ||
| co.use_raft = use_raft | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| t0 = time.time() | ||
| index_gpu.train(trainVecs) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
| n_rows, n_cols = xb.shape | ||
| n_train, _ = xt.shape | ||
| M = n_cols // args.pq_len | ||
| nlist = compute_nlist(n_rows) | ||
| index = faiss.index_factory(n_cols, "IVF{},PQ{}x{}np".format(nlist, M, args.bits_per_code)) | ||
|
|
||
| print("=" * 40) | ||
| print("GPU Train Benchmarks") | ||
| print("=" * 40) | ||
| raft_gpu_train_time = bench_train_milliseconds(index, xt, True) | ||
| if args.raft_only: | ||
| print("Method: IVFPQ, Operation: TRAIN, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numTrain: %d, RAFT enabled GPU train time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_train, raft_gpu_train_time)) | ||
| else: | ||
| classical_gpu_train_time = bench_train_milliseconds( | ||
| index, xt, False) | ||
| print("Method: IVFPQ, Operation: TRAIN, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numTrain: %d, classical GPU train time: %.3f milliseconds, RAFT enabled GPU train time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_train, classical_gpu_train_time, raft_gpu_train_time)) | ||
|
|
||
|
|
||
| def bench_add_milliseconds(index, addVecs, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| # use float 16 lookup tables to save space | ||
| co.useFloat16LookupTables = True | ||
| co.use_raft = use_raft | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| index_gpu.copyFrom(index) | ||
| t0 = time.time() | ||
| index_gpu.add(addVecs) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
| print("=" * 40) | ||
| print("GPU Add Benchmarks") | ||
| print("=" * 40) | ||
| index.train(xt) | ||
| raft_gpu_add_time = bench_add_milliseconds(index, xb, True) | ||
| if args.raft_only: | ||
| print("Method: IVFPQ, Operation: ADD, dim: %d, n_centroids %d numSubQuantizers %d, bitsPerCode %d, numAdd %d, RAFT enabled GPU add time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_rows, raft_gpu_add_time)) | ||
| else: | ||
| classical_gpu_add_time = bench_add_milliseconds( | ||
| index, xb, False) | ||
| print("Method: IVFFPQ, Operation: ADD, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numAdd %d, classical GPU add time: %.3f milliseconds, RAFT enabled GPU add time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_rows, classical_gpu_add_time, raft_gpu_add_time)) | ||
|
|
||
|
|
||
| def bench_search_milliseconds(index, addVecs, queryVecs, nprobe, k, use_raft): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| co.use_raft = use_raft | ||
| co.useFloat16LookupTables = True | ||
| index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) | ||
| index_gpu.copyFrom(index) | ||
| index_gpu.add(addVecs) | ||
| index_gpu.nprobe = nprobe | ||
| t0 = time.time() | ||
| index_gpu.search(queryVecs, k) | ||
| return 1000*(time.time() - t0) | ||
|
|
||
|
|
||
| if args.bm_search: | ||
| print("=" * 40) | ||
| print("GPU Search Benchmarks") | ||
| print("=" * 40) | ||
| queryset_sizes = [1, 10, 100, 1000, 10000] | ||
| n_train, n_cols = xt.shape | ||
| n_add, _ = xb.shape | ||
| print(xq.shape) | ||
| M = n_cols // args.pq_len | ||
| nlist = compute_nlist(n_add) | ||
| index = faiss.index_factory(n_cols, "IVF{},PQ{}x{}np".format(nlist, M, args.bits_per_code)) | ||
| index.train(xt) | ||
| for n_rows in queryset_sizes: | ||
| queryVecs = xq[np.random.choice(xq.shape[0], n_rows, replace=False)] | ||
| raft_gpu_search_time = bench_search_milliseconds( | ||
| index, xb, queryVecs, args.nprobe, args.k, True) | ||
| if args.raft_only: | ||
| print("Method: IVFPQ, Operation: SEARCH, dim: %d, n_centroids: %d, numSubQuantizers %d, bitsPerCode %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, RAFT enabled GPU search time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_add, n_rows, args.nprobe, args.k, raft_gpu_search_time)) | ||
| else: | ||
| classical_gpu_search_time = bench_search_milliseconds( | ||
| index, xb, queryVecs, args.nprobe, args.k, False) | ||
| print("Method: IVFPQ, Operation: SEARCH, dim: %d, n_centroids: %d, numSubQuantizers %d, bitsPerCode %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU search time: %.3f milliseconds, RAFT enabled GPU search time: %.3f milliseconds" % ( | ||
| n_cols, nlist, M, args.bits_per_code, n_add, n_rows, args.nprobe, args.k, classical_gpu_search_time, raft_gpu_search_time)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you drop this? See below.