Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/faiss_distance_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ inline void chooseTileSize(size_t numQueries,
tileRows = std::min(preferredTileRows, numQueries);

// tileCols is the remainder size
tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
tileCols = std::min(targetUsage / tileRows, numCentroids);
}
} // namespace cuvs::neighbors::detail::faiss_select
36 changes: 17 additions & 19 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ void tiled_brute_force_knn(const raft::resources& handle,
const uint32_t* filter_bitmap = nullptr)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
size_t tile_cols = 0;
auto stream = raft::resource::get_cuda_stream(handle);
auto device_memory = raft::resource::get_workspace_resource(handle);
auto total_mem = rmm::available_device_memory().second;
size_t tile_rows = 0;
size_t tile_cols = 0;
auto stream = raft::resource::get_cuda_stream(handle);

// total memory is not relevant in the heuristic for data below 512 MB
auto total_mem =
(sizeof(DistanceT) * m * n < 1 << 29) ? (1ul << 36) : rmm::available_device_memory().second;
Comment thread
mfoerste4 marked this conversation as resolved.
Outdated
cuvs::neighbors::detail::faiss_select::chooseTileSize(
m, n, d, sizeof(DistanceT), total_mem, tile_rows, tile_cols);

Expand Down Expand Up @@ -356,27 +357,26 @@ void brute_force_knn_impl(

ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size");

std::vector<IdxType>* id_ranges;
if (translations == nullptr) {
std::vector<IdxType> id_ranges;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

longer term - we can probably remove the code that handles translations entirely. Its not being used in the public api anymore, and is just left over from the RAFT version. (doesn't need to change in this PR though)

if (translations != nullptr) {
// use the given translations
id_ranges.insert(id_ranges.end(), translations->begin(), translations->end());
} else if (input.size() > 1) {
// If we don't have explicit translations
// for offsets of the indices, build them
// from the local partitions
id_ranges = new std::vector<IdxType>();
IdxType total_n = 0;
for (size_t i = 0; i < input.size(); i++) {
id_ranges->push_back(total_n);
id_ranges.push_back(total_n);
total_n += sizes[i];
}
} else {
// otherwise, use the given translations
id_ranges = translations;
}

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

rmm::device_uvector<IdxType> trans(id_ranges->size(), userStream);
raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream);
rmm::device_uvector<IdxType> trans(0, userStream);
if (id_ranges.size() > 0) {
trans.resize(id_ranges.size(), userStream);
raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), userStream);
}

rmm::device_uvector<DistType> all_D(0, userStream);
rmm::device_uvector<IdxType> all_I(0, userStream);
Expand Down Expand Up @@ -513,8 +513,6 @@ void brute_force_knn_impl(
// no translations or partitions to combine, it can be skipped.
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
}

if (translations == nullptr) delete id_ranges;
};

template <typename T,
Expand Down