Skip to content

Commit 48e8d4d

Browse files
cuVS Cagra FP16 support (#4384)
Summary: Supporting fp16 for cuVS cagra, and introducing new extended APIs for this. Discussions related to this issue: facebookresearch/faiss#4324 Added tests in `faiss/gpu/test/TestGpuIndexCagra.cu` and `faiss/gpu/test/test_cagra.py` for example usage. Pull Request resolved: facebookresearch/faiss#4384 Reviewed By: junjieqi Differential Revision: D76480612 Pulled By: mnorris11 fbshipit-source-id: 863d8671eab461733110f74550ffc56650f77407
1 parent b28b4c5 commit 48e8d4d

15 files changed

Lines changed: 1439 additions & 312 deletions

faiss/Index.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#define FAISS_INDEX_H
1212

1313
#include <faiss/MetricType.h>
14+
#include <faiss/impl/FaissAssert.h>
15+
1416
#include <cstdio>
1517
#include <sstream>
1618
#include <string>
@@ -56,6 +58,23 @@ struct IDSelector;
5658
struct RangeSearchResult;
5759
struct DistanceComputer;
5860

61+
enum NumericType {
62+
Float32,
63+
Float16,
64+
};
65+
66+
inline size_t get_numeric_type_size(NumericType numeric_type) {
67+
switch (numeric_type) {
68+
case NumericType::Float32:
69+
return 4;
70+
case NumericType::Float16:
71+
return 2;
72+
default:
73+
FAISS_THROW_MSG(
74+
"Unknown Numeric Type. Only supports Float32, Float16");
75+
}
76+
}
77+
5978
/** Parent class for the optional search paramenters.
6079
*
6180
* Sub-classes with additional search parameters should inherit this class.
@@ -107,6 +126,14 @@ struct Index {
107126
*/
108127
virtual void train(idx_t n, const float* x);
109128

129+
virtual void train(idx_t n, const void* x, NumericType numeric_type) {
130+
if (numeric_type == NumericType::Float32) {
131+
train(n, static_cast<const float*>(x));
132+
} else {
133+
FAISS_THROW_MSG("Index::train: unsupported numeric type");
134+
}
135+
}
136+
110137
/** Add n vectors of dimension d to the index.
111138
*
112139
* Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
@@ -117,6 +144,14 @@ struct Index {
117144
*/
118145
virtual void add(idx_t n, const float* x) = 0;
119146

147+
virtual void add(idx_t n, const void* x, NumericType numeric_type) {
148+
if (numeric_type == NumericType::Float32) {
149+
add(n, static_cast<const float*>(x));
150+
} else {
151+
FAISS_THROW_MSG("Index::add: unsupported numeric type");
152+
}
153+
}
154+
120155
/** Same as add, but stores xids instead of sequential ids.
121156
*
122157
* The default implementation fails with an assertion, as it is
@@ -127,6 +162,17 @@ struct Index {
127162
* @param xids if non-null, ids to store for the vectors (size n)
128163
*/
129164
virtual void add_with_ids(idx_t n, const float* x, const idx_t* xids);
165+
virtual void add_with_ids(
166+
idx_t n,
167+
const void* x,
168+
NumericType numeric_type,
169+
const idx_t* xids) {
170+
if (numeric_type == NumericType::Float32) {
171+
add_with_ids(n, static_cast<const float*>(x), xids);
172+
} else {
173+
FAISS_THROW_MSG("Index::add_with_ids: unsupported numeric type");
174+
}
175+
}
130176

131177
/** query n vectors of dimension d to the index.
132178
*
@@ -147,6 +193,26 @@ struct Index {
147193
idx_t* labels,
148194
const SearchParameters* params = nullptr) const = 0;
149195

196+
virtual void search(
197+
idx_t n,
198+
const void* x,
199+
NumericType numeric_type,
200+
idx_t k,
201+
float* distances,
202+
idx_t* labels,
203+
const SearchParameters* params = nullptr) const {
204+
if (numeric_type == NumericType::Float32) {
205+
search(n,
206+
static_cast<const float*>(x),
207+
k,
208+
distances,
209+
labels,
210+
params);
211+
} else {
212+
FAISS_THROW_MSG("Index::search: unsupported numeric type");
213+
}
214+
}
215+
150216
/** query n vectors of dimension d to the index.
151217
*
152218
* return all vectors with distance < radius. Note that many

faiss/IndexHNSW.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <random>
2020

2121
#include <cstdint>
22+
#include "faiss/Index.h"
2223

2324
#include <faiss/Index2Layer.h>
2425
#include <faiss/IndexFlat.h>
@@ -893,15 +894,31 @@ IndexHNSWCagra::IndexHNSWCagra() {
893894
is_trained = true;
894895
}
895896

896-
IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
897-
: IndexHNSW(
898-
(metric == METRIC_L2)
899-
? static_cast<IndexFlat*>(new IndexFlatL2(d))
900-
: static_cast<IndexFlat*>(new IndexFlatIP(d)),
901-
M) {
897+
IndexHNSWCagra::IndexHNSWCagra(
898+
int d,
899+
int M,
900+
MetricType metric,
901+
NumericType numeric_type)
902+
: IndexHNSW(d, M, metric) {
902903
FAISS_THROW_IF_NOT_MSG(
903904
((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
904905
"unsupported metric type for IndexHNSWCagra");
906+
numeric_type_ = numeric_type;
907+
if (numeric_type == NumericType::Float32) {
908+
// Use flat storage with full precision for fp32
909+
storage = (metric == METRIC_L2)
910+
? static_cast<Index*>(new IndexFlatL2(d))
911+
: static_cast<Index*>(new IndexFlatIP(d));
912+
} else if (numeric_type == NumericType::Float16) {
913+
auto qtype = ScalarQuantizer::QT_fp16;
914+
storage = new IndexScalarQuantizer(d, qtype, metric);
915+
} else {
916+
FAISS_THROW_MSG(
917+
"Unsupported numeric_type: only F16 and F32 are supported for IndexHNSWCagra");
918+
}
919+
920+
metric_arg = storage->metric_arg;
921+
905922
own_fields = true;
906923
is_trained = true;
907924
init_level0 = true;
@@ -967,4 +984,12 @@ void IndexHNSWCagra::search(
967984
}
968985
}
969986

987+
faiss::NumericType IndexHNSWCagra::get_numeric_type() const {
988+
return numeric_type_;
989+
}
990+
991+
void IndexHNSWCagra::set_numeric_type(faiss::NumericType numeric_type) {
992+
numeric_type_ = numeric_type;
993+
}
994+
970995
} // namespace faiss

faiss/IndexHNSW.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include <vector>
13+
#include "faiss/Index.h"
1314

1415
#include <faiss/IndexFlat.h>
1516
#include <faiss/IndexPQ.h>
@@ -170,7 +171,11 @@ struct IndexHNSW2Level : IndexHNSW {
170171

171172
struct IndexHNSWCagra : IndexHNSW {
172173
IndexHNSWCagra();
173-
IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
174+
IndexHNSWCagra(
175+
int d,
176+
int M,
177+
MetricType metric = METRIC_L2,
178+
NumericType numeric_type = NumericType::Float32);
174179

175180
/// When set to true, the index is immutable.
176181
/// This option is used to copy the knn graph from GpuIndexCagra
@@ -195,6 +200,10 @@ struct IndexHNSWCagra : IndexHNSW {
195200
float* distances,
196201
idx_t* labels,
197202
const SearchParameters* params = nullptr) const override;
203+
204+
faiss::NumericType get_numeric_type() const;
205+
void set_numeric_type(faiss::NumericType numeric_type);
206+
NumericType numeric_type_;
198207
};
199208

200209
} // namespace faiss

faiss/gpu/GpuCloner.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ Index* ToCPUCloner::clone_Index(const Index* index) {
9494
#if defined USE_NVIDIA_CUVS
9595
else if (auto icg = dynamic_cast<const GpuIndexCagra*>(index)) {
9696
IndexHNSWCagra* res = new IndexHNSWCagra();
97+
if (icg->get_numeric_type() == faiss::NumericType::Float16) {
98+
res->base_level_only = true;
99+
}
97100
icg->copyTo(res);
98101
return res;
99102
}
@@ -235,7 +238,7 @@ Index* ToGpuCloner::clone_Index(const Index* index) {
235238
config.device = device;
236239
GpuIndexCagra* res =
237240
new GpuIndexCagra(provider, icg->d, icg->metric_type, config);
238-
res->copyFrom(icg);
241+
res->copyFrom(icg, icg->get_numeric_type());
239242
return res;
240243
}
241244
#endif

0 commit comments

Comments
 (0)