Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ add_library(
src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/neighbors/brute_force_index.cu
src/neighbors/brute_force.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_int8.cu
Expand Down
82 changes: 70 additions & 12 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

#include "ann_types.hpp"
#include <cuvs/neighbors/ann_types.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace cuvs::neighbors::brute_force {

Expand All @@ -42,34 +44,90 @@ struct index : cuvs::neighbors::ann::index {
index& operator=(const index&) = delete;
index& operator=(index&&) = default;
~index() = default;
index(void* raft_index);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* The dataset will be copied to the device and the index will own the device memory.
*/
index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* This index will store a non-owning reference to the dataset.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms here.
* Having precomputed norms gives us a performance advantage at query time.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view,
std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/**
* Replace the dataset with a new dataset.
*/
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset);

/**
* Replace the dataset with a new dataset.
*
* We create a copy of the dataset on the device. The index manages the lifetime of this copy.
*/
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset);

/** Distance metric used for retrieval */
cuvs::distance::DistanceType metric() const noexcept;
cuvs::distance::DistanceType metric() const noexcept { return metric_; }

/** Metric argument */
T metric_arg() const noexcept;
T metric_arg() const noexcept { return metric_arg_; }

/** Total length of the index (number of vectors). */
size_t size() const noexcept;
size_t size() const noexcept { return dataset_view_.extent(0); }

/** Dimensionality of the data. */
size_t dim() const noexcept;
size_t dim() const noexcept { return dataset_view_.extent(1); }

/** Dataset [size, dim] */
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset() const noexcept;
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset() const noexcept
{
return dataset_view_;
}

/** Dataset norms */
raft::device_vector_view<const T, int64_t, raft::row_major> norms() const;
raft::device_vector_view<const T, int64_t, raft::row_major> norms() const
{
return norms_view_.value();
}

/** Whether or not this index has dataset norms */
bool has_norms() const noexcept;

// Get pointer to underlying RAFT index, not meant to be used outside of cuVS
inline const void* get_raft_index() const noexcept { return raft_index_.get(); }
inline bool has_norms() const noexcept { return norms_view_.has_value(); }

private:
std::unique_ptr<void*> raft_index_;
cuvs::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, raft::row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
std::optional<raft::device_vector_view<const T, int64_t>> norms_view_;
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view_;
T metric_arg_;
};
/**
* @}
Expand Down
115 changes: 86 additions & 29 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,98 @@
* limitations under the License.
*/

#include "./detail/knn_brute_force.cuh"
#include <cuvs/neighbors/brute_force.hpp>
#include <raft/neighbors/brute_force-inl.cuh>

#include <raft/core/copy.hpp>

namespace cuvs::neighbors::brute_force {
/**
 * Build a brute force index from a dataset held in host memory.
 *
 * The host overload of update_dataset() is used, so the index ends up owning a
 * device-side copy of `dataset`. Any norms supplied are moved into the index.
 *
 * @param res        resources handle used for allocation, copy and stream sync
 * @param dataset    host view of the dataset
 * @param norms      optional precomputed norms; ownership is taken by the index
 * @param metric     distance metric stored on the index
 * @param metric_arg metric argument stored on the index
 */
template <typename T>
index<T>::index(raft::resources const& res,
                raft::host_matrix_view<const T, int64_t, raft::row_major> dataset,
                std::optional<raft::device_vector<T, int64_t>>&& norms,
                cuvs::distance::DistanceType metric,
                T metric_arg)
  : ann::index(),
    metric_(metric),
    // Empty placeholder; the real allocation happens in update_dataset() below.
    dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
    norms_(std::move(norms)),
    metric_arg_(metric_arg)
{
  // Expose the owned norms (if any) through the read-only view the accessors return.
  if (norms_.has_value()) { norms_view_ = raft::make_const_mdspan(norms_->view()); }

  // Host overload: allocates device storage and copies the dataset into it.
  update_dataset(res, dataset);

  // NOTE(review): presumably synchronizes so the caller's host buffer is no longer
  // needed once the constructor returns -- confirm against raft::copy semantics.
  raft::resource::sync_stream(res);
}

/**
 * Build a brute force index from a dataset that is already on the device.
 *
 * The device overload of update_dataset() is used, so only the view is stored;
 * the dataset itself is not copied. Any norms supplied are moved into the index.
 *
 * @param res        resources handle
 * @param dataset    device view of the dataset (kept as a view by the index)
 * @param norms      optional precomputed norms; ownership is taken by the index
 * @param metric     distance metric stored on the index
 * @param metric_arg metric argument stored on the index
 */
template <typename T>
index<T>::index(raft::resources const& res,
                raft::device_matrix_view<const T, int64_t, raft::row_major> dataset,
                std::optional<raft::device_vector<T, int64_t>>&& norms,
                cuvs::distance::DistanceType metric,
                T metric_arg)
  : ann::index(),
    metric_(metric),
    // Owned storage stays empty on this path.
    dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
    norms_(std::move(norms)),
    metric_arg_(metric_arg)
{
  // Expose the owned norms (if any) through the read-only view the accessors return.
  if (norms_.has_value()) { norms_view_ = raft::make_const_mdspan(norms_->view()); }

  // Device overload: records the view only.
  update_dataset(res, dataset);
}

/**
 * Build a brute force index from device views of the dataset and its norms.
 *
 * Both views are stored as given; nothing is copied. The owned dataset storage
 * is initialized to an empty matrix and the owned norms are left unset.
 *
 * @param res          resources handle (used only to create the empty owned matrix)
 * @param dataset_view device view of the dataset
 * @param norms_view   optional device view of precomputed norms
 * @param metric       distance metric stored on the index
 * @param metric_arg   metric argument stored on the index
 */
template <typename T>
index<T>::index(raft::resources const& res,
                raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view,
                std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
                cuvs::distance::DistanceType metric,
                T metric_arg)
  // Initializers are listed in member declaration order (norms_view_ is declared
  // before dataset_view_), which is the order they actually run in and avoids
  // -Wreorder diagnostics.
  : ann::index(),
    metric_(metric),
    dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
    norms_view_(norms_view),
    dataset_view_(dataset_view),
    metric_arg_(metric_arg)
{
}

/**
 * Point the index at a different device-resident dataset.
 *
 * Only the view is replaced; no data is copied. The owned dataset storage and
 * the norms members are not touched here.
 *
 * @param res     resources handle (not used by this overload)
 * @param dataset device view that the index will refer to from now on
 */
template <typename T>
void index<T>::update_dataset(raft::resources const& res,
                              raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
{
  this->dataset_view_ = dataset;
}

/**
 * Replace the dataset with a copy of a host-resident dataset.
 *
 * A device matrix with the same extents is allocated and the host data is
 * copied into it. The index keeps that matrix and points its dataset view at it.
 * The norms members are not touched here.
 *
 * @param res     resources handle used for the allocation and the copy
 * @param dataset host view of the new dataset
 */
template <typename T>
void index<T>::update_dataset(raft::resources const& res,
                              raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
{
  auto const n_rows = dataset.extent(0);
  auto const n_cols = dataset.extent(1);

  // Assigning to dataset_ replaces whatever matrix the index held before.
  dataset_ = raft::make_device_matrix<T, int64_t>(res, n_rows, n_cols);
  raft::copy(res, dataset_.view(), dataset);

  // Re-point the read-only view at the freshly filled storage.
  dataset_view_ = raft::make_const_mdspan(dataset_.view());
}

#define CUVS_INST_BFKNN(T, IdxT) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
auto index_on_stack = raft::neighbors::brute_force::build( \
res, dataset, static_cast<raft::distance::DistanceType>(metric), metric_arg); \
auto index_on_heap = \
new raft::neighbors::brute_force::index<float>(std::move(index_on_stack)); \
return cuvs::neighbors::brute_force::index<float>(index_on_heap); \
} \
\
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, \
raft::device_matrix_view<T, IdxT, raft::row_major> distances) \
{ \
auto raft_idx = \
reinterpret_cast<const raft::neighbors::brute_force::index<T>*>(idx.get_raft_index()); \
raft::neighbors::brute_force::search(res, *raft_idx, queries, neighbors, distances); \
} \
\
#define CUVS_INST_BFKNN(T) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances) \
{ \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} \
\
template struct cuvs::neighbors::brute_force::index<T>;

CUVS_INST_BFKNN(float, int64_t);
// CUVS_INST_BFKNN(int8_t, int64_t);
// CUVS_INST_BFKNN(uint8_t, int64_t);
CUVS_INST_BFKNN(float);

#undef CUVS_INST_BFKNN

Expand Down
86 changes: 0 additions & 86 deletions cpp/src/neighbors/brute_force_index.cu

This file was deleted.

52 changes: 52 additions & 0 deletions cpp/src/neighbors/detail/faiss_distance_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file thirdparty/LICENSES/LICENSE.faiss
*/

#pragma once

namespace cuvs::neighbors::detail::faiss_select {
/**
 * Pick the tile dimensions used to split a (queries x centroids) computation.
 *
 * A memory budget is selected from `totalMem`, halved (two tiles are assumed
 * to be live at once because of double streaming) and converted from bytes to
 * a number of elements. The row tile defaults to 512, or 1024 when `dim` is
 * small (<= 32), and is capped by `numQueries`. The column tile is whatever
 * the budget allows for the preferred row count, capped by `numCentroids`.
 *
 * @param numQueries   number of query rows
 * @param numCentroids number of columns to tile over
 * @param dim          vector dimensionality
 * @param elementSize  size of one element in bytes; 0 is treated as 1
 * @param totalMem     total device memory in bytes
 * @param[out] tileRows chosen number of rows per tile
 * @param[out] tileCols chosen number of columns per tile
 */
inline void chooseTileSize(size_t numQueries,
                           size_t numCentroids,
                           size_t dim,
                           size_t elementSize,
                           size_t totalMem,
                           size_t& tileRows,
                           size_t& tileCols)
{
  constexpr size_t kMiB = size_t{1024} * 1024;
  constexpr size_t kGiB = kMiB * 1024;

  // The matrix multiplication should be large enough to be efficient, but if
  // it is too large, we seem to lose efficiency as opposed to
  // double-streaming. We ignore available temporary memory, as that is
  // adjusted independently by the user and can thus meet these requirements
  // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs,
  // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage.
  size_t targetUsage = 0;
  if (totalMem <= 4 * kGiB) {
    targetUsage = 512 * kMiB;
  } else if (totalMem <= 8 * kGiB) {
    targetUsage = 768 * kMiB;
  } else {
    targetUsage = kGiB;
  }

  // Each tile gets half of the budget; convert bytes to elements. Clamp the
  // element size so a zero value cannot cause a division by zero.
  const size_t bytesPerElement = std::max<size_t>(elementSize, 1);
  targetUsage /= 2 * bytesPerElement;

  // 512 rows is used as the default batch size. If the k size (vec dim) of
  // the matrix multiplication is small (<= 32), increase to 1024.
  const size_t preferredTileRows = (dim <= 32) ? size_t{1024} : size_t{512};

  tileRows = std::min(preferredTileRows, numQueries);

  // tileCols is the remainder size
  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
}
}  // namespace cuvs::neighbors::detail::faiss_select
Loading