Skip to content

Commit 1b117ff

Browse files
authored
Merge branch 'main' into faiss_with_search_codes
2 parents 87fe2f4 + 467f70e commit 1b117ff

20 files changed

+378
-198
lines changed

.circleci/config.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ executors:
99
environment:
1010
CONDA_ARCH: Linux-x86_64
1111
machine:
12-
image: linux-cuda-11:2023.02.1
12+
image: linux-cuda-11:default
1313
resource_class: gpu.nvidia.medium
1414
linux-arm64-cpu:
1515
environment:
@@ -91,7 +91,7 @@ jobs:
9191
- run:
9292
name: Install conda build tools
9393
command: |
94-
conda config --set solver libmamba
94+
# conda config --set solver libmamba
9595
# conda config --set verbosity 3
9696
conda update -y -q conda
9797
conda install -y -q conda-build
@@ -171,7 +171,7 @@ jobs:
171171
sudo update-alternatives --set cuda /usr/local/cuda-<<parameters.cuda>>
172172
cd conda
173173
conda build faiss-gpu-raft --variants '{ "cudatoolkit": "<<parameters.cuda>>", "c_compiler_version": "<<parameters.compiler_version>>", "cxx_compiler_version": "<<parameters.compiler_version>>" }' \
174-
-c pytorch -c nvidia -c rapidsai -c conda-forge
174+
-c pytorch -c nvidia -c rapidsai-nightly -c conda-forge
175175
- when:
176176
condition:
177177
and:
@@ -186,7 +186,7 @@ jobs:
186186
sudo update-alternatives --set cuda /usr/local/cuda-<<parameters.cuda>>
187187
cd conda
188188
conda build faiss-gpu-raft --variants '{ "cudatoolkit": "<<parameters.cuda>>", "c_compiler_version": "<<parameters.compiler_version>>", "cxx_compiler_version": "<<parameters.compiler_version>>" }' \
189-
--user pytorch --label <<parameters.label>> -c pytorch -c nvidia -c rapidsai -c conda-forge
189+
--user pytorch --label <<parameters.label>> -c pytorch -c nvidia -c rapidsai-nightly -c conda-forge
190190
191191
build_cmake:
192192
parameters:
@@ -236,7 +236,7 @@ jobs:
236236
- run:
237237
name: Install libraft
238238
command: |
239-
conda install -y -q libraft cudatoolkit=11.4 -c rapidsai-nightly -c nvidia -c pkgs/main -c conda-forge
239+
conda install -y -q libraft cuda-version=11.4 -c rapidsai-nightly -c nvidia -c pkgs/main -c conda-forge
240240
- run:
241241
name: Build all targets
242242
no_output_timeout: 30m
@@ -283,7 +283,7 @@ jobs:
283283
- run:
284284
name: Python tests (CPU + GPU)
285285
command: |
286-
conda install -y -q pytorch pytorch-cuda -c pytorch -c nvidia
286+
conda install -y -q pytorch pytorch-cuda=11 -c pytorch -c nvidia
287287
pytest --junitxml=test-results/pytest/results.xml tests/test_*.py
288288
pytest --junitxml=test-results/pytest/results-torch.xml tests/torch_*.py
289289
cp tests/common_faiss_tests.py faiss/gpu/test

cmake/thirdparty/fetch_rapids.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# or implied. See the License for the specific language governing permissions and limitations under
1616
# the License.
1717
# =============================================================================
18-
set(RAPIDS_VERSION "23.08")
18+
set(RAPIDS_VERSION "23.12")
1919

2020
if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
2121
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake

conda/faiss-gpu-raft/meta.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ outputs:
4747
host:
4848
- mkl =2023 # [x86_64]
4949
- openblas # [not x86_64]
50-
- cudatoolkit {{ cudatoolkit }}
51-
- libraft =23.08
50+
- cuda-version {{ cudatoolkit }}
51+
- libraft =23.12
5252
run:
5353
- mkl =2023 # [x86_64]
5454
- openblas # [not x86_64]
55-
- {{ pin_compatible('cudatoolkit', max_pin='x.x') }}
56-
- libraft =23.08
55+
- {{ pin_compatible('cuda-version', max_pin='x') }}
56+
- libraft =23.12
5757
test:
5858
requires:
5959
- conda-build
@@ -90,6 +90,8 @@ outputs:
9090
- numpy
9191
- scipy
9292
- pytorch
93+
- pytorch-cuda =11.8
94+
- cuda-version =11.8
9395
commands:
9496
- python -X faulthandler -m unittest discover -v -s tests/ -p "test_*"
9597
- python -X faulthandler -m unittest discover -v -s tests/ -p "torch_*"

conda/faiss-gpu/meta.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ outputs:
5151
run:
5252
- mkl =2023 # [x86_64]
5353
- openblas # [not x86_64]
54-
- {{ pin_compatible('cudatoolkit', max_pin='x.x') }}
54+
- {{ pin_compatible('cudatoolkit', max_pin='x') }}
5555
test:
5656
requires:
5757
- conda-build
@@ -88,6 +88,8 @@ outputs:
8888
- numpy
8989
- scipy
9090
- pytorch
91+
- pytorch-cuda =11.8
92+
- cudatoolkit =11.8
9193
commands:
9294
- python -X faulthandler -m unittest discover -v -s tests/ -p "test_*"
9395
- python -X faulthandler -m unittest discover -v -s tests/ -p "torch_*"

faiss/Index.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct Index {
9999
* Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
100100
* This function slices the input vectors in chunks smaller than
101101
* blocksize_add and calls add_core.
102+
* @param n number of vectors
102103
* @param x input matrix, size n * d
103104
*/
104105
virtual void add(idx_t n, const float* x) = 0;
@@ -108,7 +109,9 @@ struct Index {
108109
* The default implementation fails with an assertion, as it is
109110
* not supported by all indexes.
110111
*
111-
* @param xids if non-null, ids to store for the vectors (size n)
112+
* @param n number of vectors
113+
* @param x input vectors, size n * d
114+
* @param xids if non-null, ids to store for the vectors (size n)
112115
*/
113116
virtual void add_with_ids(idx_t n, const float* x, const idx_t* xids);
114117

@@ -117,9 +120,11 @@ struct Index {
117120
* return at most k vectors. If there are not enough results for a
118121
* query, the result array is padded with -1s.
119122
*
123+
* @param n number of vectors
120124
* @param x input vectors to search, size n * d
121-
* @param labels output labels of the NNs, size n*k
125+
* @param k number of extracted vectors
122126
* @param distances output pairwise distances, size n*k
127+
* @param labels output labels of the NNs, size n*k
123128
*/
124129
virtual void search(
125130
idx_t n,
@@ -135,6 +140,7 @@ struct Index {
135140
* indexes do not implement the range_search (only the k-NN search
136141
* is mandatory).
137142
*
143+
* @param n number of vectors
138144
* @param x input vectors to search, size n * d
139145
* @param radius search radius
140146
* @param result result table
@@ -149,8 +155,10 @@ struct Index {
149155
/** return the indexes of the k vectors closest to the query x.
150156
*
151157
* This function is identical to search but only returns the labels of the neighbors.
158+
* @param n number of vectors
152159
* @param x input vectors to search, size n * d
153160
* @param labels output labels of the NNs, size n*k
161+
* @param k number of nearest neighbours
154162
*/
155163
virtual void assign(idx_t n, const float* x, idx_t* labels, idx_t k = 1)
156164
const;
@@ -174,7 +182,7 @@ struct Index {
174182
/** Reconstruct several stored vectors (or an approximation if lossy coding)
175183
*
176184
* this function may not be defined for some indexes
177-
* @param n number of vectors to reconstruct
185+
* @param n number of vectors to reconstruct
178186
* @param keys ids of the vectors to reconstruct (size n)
179187
* @param recons reconstructed vectors (size n * d)
180188
*/
@@ -184,6 +192,8 @@ struct Index {
184192
/** Reconstruct vectors i0 to i0 + ni - 1
185193
*
186194
* this function may not be defined for some indexes
195+
* @param i0 index of the first vector in the sequence
196+
* @param ni number of vectors in the sequence
187197
* @param recons reconstructed vectors (size ni * d)
188198
*/
189199
virtual void reconstruct_n(idx_t i0, idx_t ni, float* recons) const;
@@ -194,6 +204,11 @@ struct Index {
194204
* If there are not enough results for a query, the resulting arrays
195205
* are padded with -1s.
196206
*
207+
* @param n number of vectors
208+
* @param x input vectors to search, size n * d
209+
* @param k number of extracted vectors
210+
* @param distances output pairwise distances, size n*k
211+
* @param labels output labels of the NNs, size n*k
197212
* @param recons reconstructed vectors size (n, k, d)
198213
**/
199214
virtual void search_and_reconstruct(

faiss/IndexFlat.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ namespace faiss {
1818

1919
/** Index that stores the full vectors and performs exhaustive search */
2020
struct IndexFlat : IndexFlatCodes {
21-
explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
21+
explicit IndexFlat(
22+
idx_t d, ///< dimensionality of the input vectors
23+
MetricType metric = METRIC_L2);
2224

2325
void search(
2426
idx_t n,
@@ -82,6 +84,9 @@ struct IndexFlatL2 : IndexFlat {
8284
// and l2 norms.
8385
std::vector<float> cached_l2norms;
8486

87+
/**
88+
* @param d dimensionality of the input vectors
89+
*/
8590
explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
8691
IndexFlatL2() {}
8792

faiss/IndexFlatCodes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ struct IndexFlatCodes : Index {
3434

3535
void reset() override;
3636

37-
/// reconstruction using the codec interface
3837
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
3938

4039
void reconstruct(idx_t key, float* recons) const override;

faiss/IndexPQ.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ struct IndexPQ : IndexFlatCodes {
3131
* @param M number of subquantizers
3232
* @param nbits number of bits per subvector index
3333
*/
34-
IndexPQ(int d, ///< dimensionality of the input vectors
35-
size_t M, ///< number of subquantizers
36-
size_t nbits, ///< number of bit per subvector index
37-
MetricType metric = METRIC_L2);
34+
IndexPQ(int d, size_t M, size_t nbits, MetricType metric = METRIC_L2);
3835

3936
IndexPQ();
4037

faiss/IndexRefine.cpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,26 @@ void IndexRefine::search(
9696
idx_t k,
9797
float* distances,
9898
idx_t* labels,
99-
const SearchParameters* params) const {
100-
FAISS_THROW_IF_NOT_MSG(
101-
!params, "search params not supported for this index");
99+
const SearchParameters* params_in) const {
100+
const IndexRefineSearchParameters* params = nullptr;
101+
if (params_in) {
102+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103+
FAISS_THROW_IF_NOT_MSG(
104+
params, "IndexRefine params have incorrect type");
105+
}
106+
107+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108+
: idx_t(k * k_factor);
109+
SearchParameters* base_index_params =
110+
(params != nullptr) ? params->base_index_params : nullptr;
111+
112+
FAISS_THROW_IF_NOT(k_base >= k);
113+
114+
FAISS_THROW_IF_NOT(base_index);
115+
FAISS_THROW_IF_NOT(refine_index);
116+
102117
FAISS_THROW_IF_NOT(k > 0);
103118
FAISS_THROW_IF_NOT(is_trained);
104-
idx_t k_base = idx_t(k * k_factor);
105119
idx_t* base_labels = labels;
106120
float* base_distances = distances;
107121
ScopeDeleter<idx_t> del1;
@@ -114,7 +128,8 @@ void IndexRefine::search(
114128
del2.set(base_distances);
115129
}
116130

117-
base_index->search(n, x, k_base, base_distances, base_labels);
131+
base_index->search(
132+
n, x, k_base, base_distances, base_labels, base_index_params);
118133

119134
for (int i = 0; i < n * k_base; i++)
120135
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +240,26 @@ void IndexRefineFlat::search(
225240
idx_t k,
226241
float* distances,
227242
idx_t* labels,
228-
const SearchParameters* params) const {
229-
FAISS_THROW_IF_NOT_MSG(
230-
!params, "search params not supported for this index");
243+
const SearchParameters* params_in) const {
244+
const IndexRefineSearchParameters* params = nullptr;
245+
if (params_in) {
246+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247+
FAISS_THROW_IF_NOT_MSG(
248+
params, "IndexRefineFlat params have incorrect type");
249+
}
250+
251+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252+
: idx_t(k * k_factor);
253+
SearchParameters* base_index_params =
254+
(params != nullptr) ? params->base_index_params : nullptr;
255+
256+
FAISS_THROW_IF_NOT(k_base >= k);
257+
258+
FAISS_THROW_IF_NOT(base_index);
259+
FAISS_THROW_IF_NOT(refine_index);
260+
231261
FAISS_THROW_IF_NOT(k > 0);
232262
FAISS_THROW_IF_NOT(is_trained);
233-
idx_t k_base = idx_t(k * k_factor);
234263
idx_t* base_labels = labels;
235264
float* base_distances = distances;
236265
ScopeDeleter<idx_t> del1;
@@ -243,7 +272,8 @@ void IndexRefineFlat::search(
243272
del2.set(base_distances);
244273
}
245274

246-
base_index->search(n, x, k_base, base_distances, base_labels);
275+
base_index->search(
276+
n, x, k_base, base_distances, base_labels, base_index_params);
247277

248278
for (int i = 0; i < n * k_base; i++)
249279
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);

faiss/IndexRefine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
namespace faiss {
1313

14+
struct IndexRefineSearchParameters : SearchParameters {
15+
float k_factor = 1;
16+
SearchParameters* base_index_params = nullptr; // non-owning
17+
18+
virtual ~IndexRefineSearchParameters() = default;
19+
};
20+
1421
/** Index that queries in a base_index (a fast one) and refines the
1522
* results with an exact search, hopefully improving the results.
1623
*/

0 commit comments

Comments
 (0)