Skip to content

Commit 2644992

Browse files
mdouze and facebook-github-bot
authored and committed
mem mapping and zero-copy python fixes (facebookresearch#4212)
Summary: Pull Request resolved: facebookresearch#4212. Add files to TARGETS; fix python. Reviewed By: mengdilin. Differential Revision: D69984379. fbshipit-source-id: 9437b4ad92ef49333a44ea37ec194364123fe825
1 parent 5ebec1d commit 2644992

8 files changed

Lines changed: 92 additions & 8 deletions

File tree

faiss/impl/mapped_io.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
18
#include <stdio.h>
29
#include <string.h>
310

@@ -11,8 +18,8 @@
1118

1219
#elif defined(_WIN32)
1320

14-
#include <Windows.h>
15-
#include <io.h>
21+
#include <Windows.h> // @manual
22+
#include <io.h> // @manual
1623

1724
#endif
1825

@@ -278,4 +285,4 @@ int MappedFileIOReader::filedescriptor() {
278285
return -1;
279286
}
280287

281-
} // namespace faiss
288+
} // namespace faiss

faiss/impl/mapped_io.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
18
#pragma once
29

310
#include <cstddef>
@@ -41,4 +48,4 @@ struct MappedFileIOReader : IOReader {
4148
int filedescriptor() override;
4249
};
4350

44-
} // namespace faiss
51+
} // namespace faiss

faiss/impl/maybe_owned_vector.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
18
#pragma once
29

310
#include <cstddef>
@@ -50,6 +57,13 @@ struct MaybeOwnedVector {
5057
c_size = owned_data.size();
5158
}
5259

60+
explicit MaybeOwnedVector(const std::vector<T>& vec)
61+
: faiss::MaybeOwnedVector<T>(vec.size()) {
62+
if (vec.size() > 0) {
63+
memcpy(owned_data.data(), vec.data(), sizeof(T) * vec.size());
64+
}
65+
}
66+
5367
MaybeOwnedVector(const MaybeOwnedVector& other) {
5468
is_owned = other.is_owned;
5569
owned_data = other.owned_data;
@@ -229,4 +243,4 @@ struct is_maybe_owned_vector<MaybeOwnedVector<T>> : std::true_type {};
229243
template <typename T>
230244
inline constexpr bool is_maybe_owned_vector_v = is_maybe_owned_vector<T>::value;
231245

232-
} // namespace faiss
246+
} // namespace faiss

faiss/impl/zerocopy_io.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
18
#include <faiss/impl/zerocopy_io.h>
29
#include <cstring>
310

@@ -37,6 +44,9 @@ void ZeroCopyIOReader::reset() {
3744
}
3845

3946
size_t ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) {
47+
if (size * nitems == 0) {
48+
return 0;
49+
}
4050
if (rp_ >= total_) {
4151
return 0;
4252
}
@@ -53,4 +63,4 @@ int ZeroCopyIOReader::filedescriptor() {
5363
return -1; // Indicating no file descriptor available for memory buffer
5464
}
5565

56-
} // namespace faiss
66+
} // namespace faiss

faiss/impl/zerocopy_io.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
18
#pragma once
29

310
#include <cstdint>
@@ -22,4 +29,4 @@ struct ZeroCopyIOReader : public faiss::IOReader {
2229
int filedescriptor() override;
2330
};
2431

25-
} // namespace faiss
32+
} // namespace faiss

faiss/python/swigfaiss.swig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#pragma SWIG nowarn=341
3333
#pragma SWIG nowarn=512
3434
#pragma SWIG nowarn=362
35+
#pragma SWIG nowarn=509
3536

3637
// we need explicit control of these typedefs...
3738
// %include <stdint.i>

tests/test_fast_scan_ivf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,9 @@ def test_equiv_pq(self):
270270
index_pq = faiss.index_factory(32, "PQ16x4np")
271271
index_pq.pq = index.pq
272272
index_pq.is_trained = True
273-
index_pq.codes = faiss. downcast_InvertedLists(
273+
codevec = faiss.downcast_InvertedLists(
274274
index.invlists).codes.at(0)
275+
index_pq.codes = faiss.MaybeOwnedVectorUInt8(codevec)
275276
index_pq.ntotal = index.ntotal
276277
Dnew, Inew = index_pq.search(xq, 4)
277278

tests/test_io.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,40 @@ def test_reader(self):
481481
finally:
482482
if os.path.exists(fname):
483483
os.unlink(fname)
484+
485+
486+
class TestIOFlatMMap(unittest.TestCase):
487+
488+
def test_mmap(self):
489+
xt, xb, xq = get_dataset_2(32, 0, 100, 50)
490+
index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2)
491+
# does not need training
492+
index.add(xb)
493+
Dref, Iref = index.search(xq, 10)
494+
495+
fd, fname = tempfile.mkstemp()
496+
os.close(fd)
497+
try:
498+
faiss.write_index(index, fname)
499+
index2 = faiss.read_index(fname, faiss.IO_FLAG_MMAP_IFC)
500+
Dnew, Inew = index2.search(xq, 10)
501+
np.testing.assert_array_equal(Iref, Inew)
502+
np.testing.assert_array_equal(Dref, Dnew)
503+
finally:
504+
if os.path.exists(fname):
505+
os.unlink(fname)
506+
507+
def test_zerocopy(self):
508+
xt, xb, xq = get_dataset_2(32, 0, 100, 50)
509+
index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2)
510+
# does not need training
511+
index.add(xb)
512+
Dref, Iref = index.search(xq, 10)
513+
514+
serialized_index = faiss.serialize_index(index)
515+
reader = faiss.ZeroCopyIOReader(
516+
faiss.swig_ptr(serialized_index), serialized_index.size)
517+
index2 = faiss.read_index(reader)
518+
Dnew, Inew = index2.search(xq, 10)
519+
np.testing.assert_array_equal(Iref, Inew)
520+
np.testing.assert_array_equal(Dref, Dnew)

0 commit comments

Comments
 (0)