Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/src/core/detail/interop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ inline MdspanType from_dlpack(DLManagedTensor* managed_tensor)
"ndim mismatch between return mdspan and DLTensor");

// auto exts = typename MdspanType::extents_type{tensor.shape};
std::array<int64_t, MdspanType::extents_type::rank()> shape{};
cuda::std::array<int64_t, MdspanType::extents_type::rank()> shape{};
for (int64_t i = 0; i < tensor.ndim; ++i) {
shape[i] = tensor.shape[i];
}
Expand Down
1 change: 1 addition & 0 deletions ci/build_cpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export RAPIDS_PACKAGE_VERSION
RAPIDS_ARTIFACTS_DIR=${RAPIDS_ARTIFACTS_DIR:-"${PWD}/artifacts"}
mkdir -p "${RAPIDS_ARTIFACTS_DIR}"
export RAPIDS_ARTIFACTS_DIR
source ./ci/use_conda_packages_from_prs.sh
Comment thread
divyegala marked this conversation as resolved.
Outdated

# populates `RATTLER_CHANNELS` array and `RATTLER_ARGS` array
source rapids-rattler-channel-string
Expand Down
2 changes: 2 additions & 0 deletions ci/build_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
PYTHON_CHANNEL=$(rapids-download-conda-from-github python)

source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Create test conda environment"
. /opt/conda/etc/profile.d/conda.sh

Expand Down
1 change: 1 addition & 0 deletions ci/build_go.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set -euo pipefail

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Create test conda environment"
. /opt/conda/etc/profile.d/conda.sh
Expand Down
1 change: 1 addition & 0 deletions ci/build_java.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ conda config --set channel_priority strict

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Generate Java testing dependencies"

Expand Down
1 change: 1 addition & 0 deletions ci/build_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rapids-print-env
rapids-logger "Begin py build"

CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
source ./ci/use_conda_packages_from_prs.sh

version=$(rapids-generate-version)
export RAPIDS_PACKAGE_VERSION=${version}
Expand Down
1 change: 1 addition & 0 deletions ci/build_rust.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set -euo pipefail

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Create test conda environment"
. /opt/conda/etc/profile.d/conda.sh
Expand Down
1 change: 1 addition & 0 deletions ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package_dir=$2
source rapids-configure-sccache
source rapids-date-string
source rapids-init-pip
source ./ci/use_wheels_from_prs.sh

rapids-generate-version > ./VERSION

Expand Down
1 change: 1 addition & 0 deletions ci/test_cpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ rapids-logger "Configuring conda strict channel priority"
conda config --set channel_priority strict

CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Generate C++ testing dependencies"
rapids-dependency-file-generator \
Expand Down
1 change: 1 addition & 0 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ conda config --set channel_priority strict
rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
PYTHON_CHANNEL=$(rapids-download-conda-from-github python)
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Generate Python testing dependencies"
rapids-dependency-file-generator \
Expand Down
2 changes: 2 additions & 0 deletions ci/test_wheel_cuvs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ set -euo pipefail

source rapids-init-pip

source ./ci/use_wheels_from_prs.sh

# Delete system libnccl.so to ensure the wheel is used
rm -rf /usr/lib64/libnccl*

Expand Down
30 changes: 30 additions & 0 deletions ci/use_conda_packages_from_prs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# download CI artifacts
LIBRAFT_CHANNEL=$(rapids-get-pr-artifact raft 2836 cpp conda)
PYLIBRAFT_CHANNEL=$(rapids-get-pr-artifact raft 2836 python conda)

# For `rattler` builds:
#
# Add these channels to the array checked by 'rapids-rattler-channel-string'.
# This ensures that when conda packages are built with strict channel priority enabled,
# the locally-downloaded packages will be preferred to remote packages (e.g. nightlies).
#
RAPIDS_PREPENDED_CONDA_CHANNELS=(
"${LIBRAFT_CHANNEL}"
"${PYLIBRAFT_CHANNEL}"
)
export RAPIDS_PREPENDED_CONDA_CHANNELS

# For tests and `conda-build` builds:
#
# Add these channels to the system-wide conda configuration.
# This results in PREPENDING them to conda's channel list, so
# these packages should be found first if strict channel priority is enabled.
#
for _channel in "${RAPIDS_PREPENDED_CONDA_CHANNELS[@]}"
do
conda config --system --add channels "${_channel}"
done
21 changes: 21 additions & 0 deletions ci/use_wheels_from_prs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# initialize PIP_CONSTRAINT
source rapids-init-pip

RAPIDS_PY_CUDA_SUFFIX=$(rapids-wheel-ctk-name-gen "${RAPIDS_CUDA_VERSION}")

# download wheels, store the directories holding them in variables
LIBRAFT_WHEELHOUSE=$(
RAPIDS_PY_WHEEL_NAME="libraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-get-pr-artifact raft 2836 cpp wheel
)
PYLIBRAFT_WHEELHOUSE=$(
RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-get-pr-artifact raft 2836 python wheel
)
# write a pip constraints file saying e.g. "whenever you encounter a requirement for 'librmm-cu12', use this wheel"
cat > "${PIP_CONSTRAINT}" <<EOF
libraft-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${LIBRAFT_WHEELHOUSE}"/libraft_*.whl)
pylibraft-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${PYLIBRAFT_WHEELHOUSE}"/pylibraft_*.whl)
EOF
11 changes: 8 additions & 3 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ function(find_and_configure_raft)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

# Set BUILD_SHARED_LIBS whenever building static dependencies
if(PKG_BUILD_STATIC_DEPS)
set(BUILD_SHARED_LIBS OFF)
endif()

# Determine whether to clone raft locally
if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "${rapids-cmake-checkout-tag}")
message(STATUS "cuVS: RAFT pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.")
set(CPM_DOWNLOAD_raft ON)
elseif(PKG_BUILD_STATIC_DEPS AND (NOT CPM_raft_SOURCE))
message(STATUS "cuVS: Cloning raft locally to build static libraries.")
set(CPM_DOWNLOAD_raft ON)
set(BUILD_SHARED_LIBS OFF)
endif()

set(RAFT_COMPONENTS "")
Expand Down Expand Up @@ -55,8 +60,8 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
FORK ${RAFT_FORK}
PINNED_TAG ${RAFT_PINNED_TAG}
FORK bdice
PINNED_TAG cccl-mdspan
ENABLE_MNMG_DEPENDENCIES OFF
ENABLE_NVTX OFF
BUILD_STATIC_DEPS ${CUVS_STATIC_RAPIDS_LIBRARIES}
Expand Down
45 changes: 27 additions & 18 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ void fit(raft::resources const& handle,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float, int> inertia,
raft::host_scalar_view<int, int> n_iter);
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Find clusters with k-means algorithm.
Expand Down Expand Up @@ -232,8 +232,8 @@ void fit(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t> X,
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float, int64_t> inertia,
raft::host_scalar_view<int64_t, int64_t> n_iter);
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Find clusters with k-means algorithm.
Expand Down Expand Up @@ -282,8 +282,8 @@ void fit(raft::resources const& handle,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<double, int> centroids,
raft::host_scalar_view<double, int> inertia,
raft::host_scalar_view<int, int> n_iter);
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Find clusters with k-means algorithm.
Expand Down Expand Up @@ -333,8 +333,8 @@ void fit(raft::resources const& handle,
raft::device_matrix_view<const double, int64_t> X,
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double, int64_t> inertia,
raft::host_scalar_view<int64_t, int64_t> n_iter);
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Find clusters with k-means algorithm.
Expand Down Expand Up @@ -383,8 +383,8 @@ void fit(raft::resources const& handle,
raft::device_matrix_view<const int8_t, int> X,
std::optional<raft::device_vector_view<const int8_t, int>> sample_weight,
raft::device_matrix_view<int8_t, int> centroids,
raft::host_scalar_view<int8_t, int> inertia,
raft::host_scalar_view<int, int> n_iter);
raft::host_scalar_view<int8_t> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Find balanced clusters with k-means algorithm.
Expand Down Expand Up @@ -581,6 +581,15 @@ void predict(raft::resources const& handle,
bool normalize_weight,
raft::host_scalar_view<float> inertia);

void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const float, int64_t> X,
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<const float, int64_t> centroids,
raft::device_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia);

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
Expand Down Expand Up @@ -632,10 +641,10 @@ void predict(raft::resources const& handle,
*/
void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int64_t, int> labels,
raft::device_matrix_view<const float, int64_t> X,
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<const float, int64_t> centroids,
raft::device_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia);

Expand Down Expand Up @@ -748,10 +757,10 @@ void predict(raft::resources const& handle,
*/
void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int64_t, int> labels,
raft::device_matrix_view<const double, int64_t> X,
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<const double, int64_t> centroids,
raft::device_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<double> inertia);

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
}
// Something is wrong: have to make a copy and produce an owning dataset
auto out_layout =
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});
raft::make_strided_layout(src.extents(), cuda::std::array<index_type, 2>{required_stride, 1});
auto out_array =
raft::make_device_matrix<value_type, index_type>(res, src.extent(0), required_stride);

Expand Down Expand Up @@ -310,7 +310,7 @@ auto make_strided_dataset(
const bool stride_matches = required_stride == src_stride;

auto out_layout =
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});
raft::make_strided_layout(src.extents(), cuda::std::array<index_type, 2>{required_stride, 1});

using out_mdarray_type = raft::device_matrix<value_type, index_type>;
using out_layout_type = typename out_mdarray_type::layout_type;
Expand Down
14 changes: 6 additions & 8 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,8 @@ constexpr typename list_spec<SizeT, IdxT>::list_extents list_spec<SizeT, IdxT>::
{
// how many elems of pq_dim fit into one kIndexGroupVecLen-byte chunk
auto pq_chunk = (kIndexGroupVecLen * 8u) / pq_bits;
return raft::make_extents<SizeT>(raft::div_rounding_up_safe<SizeT>(n_rows, kIndexGroupSize),
raft::div_rounding_up_safe<SizeT>(pq_dim, pq_chunk),
kIndexGroupSize,
kIndexGroupVecLen);
return list_extents{raft::div_rounding_up_safe<SizeT>(n_rows, kIndexGroupSize),
raft::div_rounding_up_safe<SizeT>(pq_dim, pq_chunk)};
}

template <typename IdxT, typename SizeT = uint32_t>
Expand Down Expand Up @@ -335,8 +333,8 @@ struct index : cuvs::neighbors::index {
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");

using pq_centers_extents = std::experimental::
extents<uint32_t, raft::dynamic_extent, raft::dynamic_extent, raft::dynamic_extent>;
using pq_centers_extents =
raft::extents<uint32_t, raft::dynamic_extent, raft::dynamic_extent, raft::dynamic_extent>;

public:
index(const index&) = delete;
Expand Down Expand Up @@ -2875,7 +2873,7 @@ void make_rotation_matrix(raft::resources const& res,
*/
void set_centers(raft::resources const& res,
index<int64_t>* index,
raft::device_matrix_view<const float, uint32_t> cluster_centers);
raft::device_matrix_view<const float, int64_t> cluster_centers);

/**
* @brief Public helper API for fetching a trained index's IVF centroids
Expand All @@ -2896,7 +2894,7 @@ void set_centers(raft::resources const& res,
*/
void extract_centers(raft::resources const& res,
const index<int64_t>& index,
raft::device_matrix_view<float, uint32_t, raft::row_major> cluster_centers);
raft::device_matrix_view<float, int64_t, raft::row_major> cluster_centers);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to change the signatures of ivf-pq getters to int64_t instead of uint32_t, right? For example things like list_sizes.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share a link reference?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes so we have n_lists() for example

uint32_t n_lists() const noexcept;

Do we want to change this signature to int64_t?

Copy link
Copy Markdown
Contributor

@tarang-jain tarang-jain Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words, what I am trying to say is that we might be considering an index-wide migration to int64_t from uint32_t for all extent types in IVF-PQ.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tfeher had made of aware of plans for this, but maybe that is outside the scope of this PR.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n_lists() returns lists_.size() and lists_ is an std::vector

std::vector<std::shared_ptr<list_data<IdxT>>> lists_;

It should be returning size_t. Anyway, the extent type updates in this PR are not related to index<IdxT> where IdxT is the index type, not the extent type.


/** @copydoc extract_centers */
void extract_centers(raft::resources const& res,
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/cuvs/neighbors/knn_merge_parts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ void knn_merge_parts(raft::resources const& res,
raft::device_matrix_view<const int64_t, int64_t> inV,
raft::device_matrix_view<float, int64_t> outK,
raft::device_matrix_view<int64_t, int64_t> outV,
raft::device_vector_view<int64_t> translations);
raft::device_vector_view<int64_t, int64_t> translations);
void knn_merge_parts(raft::resources const& res,
raft::device_matrix_view<const float, int64_t> inK,
raft::device_matrix_view<const uint32_t, int64_t> inV,
raft::device_matrix_view<float, int64_t> outK,
raft::device_matrix_view<uint32_t, int64_t> outV,
raft::device_vector_view<uint32_t> translations);
raft::device_vector_view<uint32_t, int64_t> translations);
void knn_merge_parts(raft::resources const& res,
raft::device_matrix_view<const float, int64_t> inK,
raft::device_matrix_view<const int32_t, int64_t> inV,
raft::device_matrix_view<float, int64_t> outK,
raft::device_matrix_view<int32_t, int64_t> outV,
raft::device_vector_view<int32_t> translations);
raft::device_vector_view<int32_t, int64_t> translations);
} // namespace cuvs::neighbors
6 changes: 3 additions & 3 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1117,9 +1117,9 @@ void kmeans_predict(raft::resources const& handle,
template <typename DataT, typename IndexT = int>
void kmeans_transform(raft::resources const& handle,
const cuvs::cluster::kmeans::params& pams,
raft::device_matrix_view<const DataT> X,
raft::device_matrix_view<const DataT> centroids,
raft::device_matrix_view<DataT> X_new)
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> X_new)
{
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("kmeans_transform");
raft::default_logger().set_level(pams.verbosity);
Expand Down
Loading