Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1076,11 +1076,11 @@ void optimize(
"Each input array is expected to have the same number of rows");
RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1),
"output graph cannot have more columns than input graph");
const uint32_t input_graph_degree = knn_graph.extent(1);
const uint32_t output_graph_degree = new_graph.extent(1);
const uint64_t input_graph_degree = knn_graph.extent(1);
const uint64_t output_graph_degree = new_graph.extent(1);
const uint64_t graph_size = new_graph.extent(0);
auto input_graph_ptr = knn_graph.data_handle();
auto output_graph_ptr = new_graph.data_handle();
const IdxT graph_size = new_graph.extent(0);

// MST optimization
auto mst_graph_num_edges = raft::make_host_vector<uint32_t, int64_t>(graph_size);
Expand Down Expand Up @@ -1148,7 +1148,7 @@ void optimize(
constexpr int MAX_DEGREE = 1024;
if (input_graph_degree > MAX_DEGREE) {
RAFT_FAIL(
"The degree of input knn graph is too large (%u). "
"The degree of input knn graph is too large (%zu). "
"It must be equal to or smaller than %d.",
input_graph_degree,
1024);
Expand Down Expand Up @@ -1217,11 +1217,12 @@ void optimize(
assert(next_num_detour != std::numeric_limits<uint32_t>::max());
num_detour = next_num_detour;
}
RAFT_EXPECTS(pk == output_graph_degree,
"Couldn't find the output_graph_degree (%u) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
static_cast<uint64_t>(i));
RAFT_EXPECTS(
pk == output_graph_degree,
"Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
i);
}

const double time_prune_end = cur_time();
Expand Down Expand Up @@ -1317,7 +1318,7 @@ void optimize(
uint32_t kf = 0;
uint32_t k = mst_graph_num_edges_ptr[i];

const uint64_t num_protected_edges = max(k, output_graph_degree / 2);
const auto num_protected_edges = std::max<uint64_t>(k, output_graph_degree / 2);
assert(num_protected_edges <= output_graph_degree);
if (num_protected_edges == output_graph_degree) continue;

Expand All @@ -1342,7 +1343,7 @@ void optimize(
assert(kf <= output_graph_degree);

// Replace some edges of the output graph with edges of the reverse graph.
uint32_t kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree);
auto kr = std::min<uint32_t>(rev_graph_count.data_handle()[i], output_graph_degree);
while (kr) {
kr -= 1;
if (my_rev_graph[kr] < graph_size) {
Expand Down