-
Notifications
You must be signed in to change notification settings - Fork 184
Migrate IVF-PQ from RAFT to cuVS #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
3c40a06
c5a3d3a
3568f12
69bf330
7a5e8a0
4b86f94
8c8aaed
e444f4b
5228a68
3f924f2
e8907ee
64809e8
e88f82c
36bd63b
9c4781d
9cbfc53
1a8a789
6e7591a
5a719ba
844f209
2783c83
ec2125b
c024780
55d8851
b24a43c
091fe02
80a0c0b
40a860a
cb011f6
dff10bc
3a4c04b
8fd2389
9b9db71
e5b8500
849358c
8621e6f
197730e
7c15d0c
5b60ee4
73ae071
57d33a8
3cf245a
f7dbd55
e22aff7
de395c8
5239b93
10b9a7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,264 @@ | ||
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <raft/core/device_container_policy.hpp> | ||
| #include <raft/core/device_mdarray.hpp> | ||
| #include <raft/core/resource/thrust_policy.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <thrust/functional.h> | ||
|
|
||
| namespace cuvs::core { | ||
| /** | ||
| * @defgroup bitset Bitset | ||
| * @{ | ||
| */ | ||
| /** | ||
| * @brief View of a cuVS Bitset. | ||
| * | ||
| * This lightweight structure stores a pointer to a bitset in device memory with it's length. | ||
| * It provides a test() device function to check if a given index is set in the bitset. | ||
| * | ||
| * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. | ||
| * @tparam index_t Indexing type used. Default is uint32_t. | ||
| */ | ||
| template <typename bitset_t = uint32_t, typename index_t = uint32_t> | ||
| struct bitset_view { | ||
| static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; | ||
|
|
||
| _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len); | ||
| /** | ||
| * @brief Create a bitset view from a device vector view of the bitset. | ||
| * | ||
| * @param bitset_span Device vector view of the bitset | ||
| * @param bitset_len Number of bits in the bitset | ||
| */ | ||
| _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view<bitset_t, index_t> bitset_span, | ||
| index_t bitset_len); | ||
| /** | ||
| * @brief Device function to test if a given index is set in the bitset. | ||
| * | ||
| * @param sample_index Single index to test | ||
| * @return bool True if index has not been unset in the bitset | ||
| */ | ||
| _RAFT_DEVICE inline bool test(const index_t sample_index) const | ||
| { | ||
| const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; | ||
| const index_t bit_index = sample_index % bitset_element_size; | ||
| const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; | ||
| return is_bit_set; | ||
| } | ||
| /** | ||
| * @brief Device function to test if a given index is set in the bitset. | ||
| * | ||
| * @param sample_index Single index to test | ||
| * @return bool True if index has not been unset in the bitset | ||
| */ | ||
| _RAFT_DEVICE bool operator[](const index_t sample_index) const { return test(sample_index); } | ||
| /** | ||
| * @brief Device function to set a given index to set_value in the bitset. | ||
| * | ||
| * @param sample_index index to set | ||
| * @param set_value Value to set the bit to (true or false) | ||
| */ | ||
| _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const; | ||
|
|
||
| /** | ||
| * @brief Get the device pointer to the bitset. | ||
| */ | ||
| _RAFT_HOST_DEVICE bitset_t* data(); | ||
| _RAFT_HOST_DEVICE const bitset_t* data() const; | ||
| /** | ||
| * @brief Get the number of bits of the bitset representation. | ||
| */ | ||
| _RAFT_HOST_DEVICE index_t size() const; | ||
|
|
||
| /** | ||
| * @brief Get the number of elements used by the bitset representation. | ||
| */ | ||
| _RAFT_HOST_DEVICE index_t n_elements() const; | ||
|
|
||
| raft::device_vector_view<bitset_t, index_t> to_mdspan(); | ||
| raft::device_vector_view<const bitset_t, index_t> to_mdspan() const; | ||
|
|
||
| private: | ||
| bitset_t* bitset_ptr_; | ||
| index_t bitset_len_; | ||
| }; | ||
|
|
||
| /** | ||
| * @brief cuVS Bitset. | ||
| * | ||
| * This structure encapsulates a bitset in device memory. It provides a view() method to get a | ||
| * device-usable lightweight view of the bitset. | ||
| * Each index is represented by a single bit in the bitset. The total number of bytes used is | ||
| * ceil(bitset_len / 8). | ||
| * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. | ||
| * @tparam index_t Indexing type used. Default is uint32_t. | ||
| */ | ||
| template <typename bitset_t = uint32_t, typename index_t = uint32_t> | ||
| struct bitset { | ||
| static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; | ||
|
|
||
| /** | ||
| * @brief Construct a new bitset object with a list of indices to unset. | ||
| * | ||
| * @param res RAFT resources | ||
| * @param mask_index List of indices to unset in the bitset | ||
| * @param bitset_len Length of the bitset | ||
| * @param default_value Default value to set the bits to. Default is true. | ||
| */ | ||
| bitset(const raft::resources& res, | ||
| raft::device_vector_view<const index_t, index_t> mask_index, | ||
| index_t bitset_len, | ||
| bool default_value = true); | ||
|
|
||
| /** | ||
| * @brief Construct a new bitset object | ||
| * | ||
| * @param res RAFT resources | ||
| * @param bitset_len Length of the bitset | ||
| * @param default_value Default value to set the bits to. Default is true. | ||
| */ | ||
| bitset(const raft::resources& res, index_t bitset_len, bool default_value = true); | ||
| // Disable copy constructor | ||
| bitset(const bitset&) = delete; | ||
| bitset(bitset&&) = default; | ||
| bitset& operator=(const bitset&) = delete; | ||
| bitset& operator=(bitset&&) = default; | ||
|
|
||
| /** | ||
| * @brief Create a device-usable view of the bitset. | ||
| * | ||
| * @return bitset_view<bitset_t, index_t> | ||
| */ | ||
| cuvs::core::bitset_view<bitset_t, index_t> view(); | ||
| cuvs::core::bitset_view<const bitset_t, index_t> view() const; | ||
|
|
||
| /** | ||
| * @brief Get the device pointer to the bitset. | ||
| */ | ||
| bitset_t* data(); | ||
| const bitset_t* data() const; | ||
| /** | ||
| * @brief Get the number of bits of the bitset representation. | ||
| */ | ||
| index_t size() const; | ||
|
|
||
| /** | ||
| * @brief Get the number of elements used by the bitset representation. | ||
| */ | ||
| index_t n_elements() const; | ||
|
|
||
| /** @brief Get an mdspan view of the current bitset */ | ||
| raft::device_vector_view<bitset_t, index_t> to_mdspan(); | ||
| raft::device_vector_view<const bitset_t, index_t> to_mdspan() const; | ||
|
|
||
| /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to | ||
| * the default value. | ||
| * @param res RAFT resources | ||
| * @param new_bitset_len new size of the bitset | ||
| * @param default_value default value to initialize the new bits to | ||
| */ | ||
| void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true); | ||
|
|
||
| /** | ||
| * @brief Test a list of indices in a bitset. | ||
| * | ||
| * @tparam output_t Output type of the test. Default is bool. | ||
| * @param res RAFT resources | ||
| * @param queries List of indices to test | ||
| * @param output List of outputs | ||
| */ | ||
| /* | ||
| TODO: Disabled test() for cuVS migration | ||
| template <typename output_t = bool> | ||
| void test(const raft::resources& res, | ||
| raft::device_vector_view<const index_t, index_t> queries, | ||
| raft::device_vector_view<output_t, index_t> output) const | ||
| { | ||
| RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); | ||
| auto bitset_view = view(); | ||
| thrust::transform( | ||
| raft::resource::get_thrust_policy(res), | ||
| queries.data_handle(), | ||
| queries.data_handle() + queries.size(), | ||
| output.data_handle(), | ||
| [bitset_view] __device__(index_t query) { return output_t{bitset_view.test(query)}; }); | ||
| } | ||
| */ | ||
| /** | ||
| * @brief Set a list of indices in a bitset to set_value. | ||
| * | ||
| * @param res RAFT resources | ||
| * @param mask_index indices to remove from the bitset | ||
| * @param set_value Value to set the bits to (true or false) | ||
| */ | ||
| void set(const raft::resources& res, | ||
| raft::device_vector_view<const index_t, index_t> mask_index, | ||
| bool set_value = false); | ||
| /** | ||
| * @brief Flip all the bits in a bitset. | ||
| * @param res RAFT resources | ||
| */ | ||
| void flip(const raft::resources& res); | ||
| /** | ||
| * @brief Reset the bits in a bitset. | ||
| * | ||
| * @param res RAFT resources | ||
| * @param default_value Value to set the bits to (true or false) | ||
| */ | ||
| void reset(const raft::resources& res, bool default_value = true); | ||
| /** | ||
| * @brief Returns the number of bits set to true in count_gpu_scalar. | ||
| * | ||
| * @param[in] res RAFT resources | ||
| * @param[out] count_gpu_scalar Device scalar to store the count | ||
| */ | ||
| void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar); | ||
|
|
||
| /** | ||
| * @brief Returns the number of bits set to true. | ||
| * | ||
| * @param res RAFT resources | ||
| * @return index_t Number of bits set to true | ||
| */ | ||
| index_t count(const raft::resources& res); | ||
|
|
||
| /** | ||
| * @brief Checks if any of the bits are set to true in the bitset. | ||
| * @param res RAFT resources | ||
| */ | ||
| bool any(const raft::resources& res) { return count(res) > 0; } | ||
| /** | ||
| * @brief Checks if all of the bits are set to true in the bitset. | ||
| * @param res RAFT resources | ||
| */ | ||
| bool all(const raft::resources& res) { return count(res) == bitset_len_; } | ||
| /** | ||
| * @brief Checks if none of the bits are set to true in the bitset. | ||
| * @param res RAFT resources | ||
| */ | ||
| bool none(const raft::resources& res) { return count(res) == 0; } | ||
|
|
||
| private: | ||
| raft::device_uvector<bitset_t> bitset_; | ||
| index_t bitset_len_; | ||
| }; | ||
|
|
||
| /** @} */ | ||
| } // end namespace cuvs::core | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't be a user-facing API and thus shouldn't be defined in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having all of the declaration in
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've done relative paths for cagra so far and haven't seen them to be too terrible. Is ivf-pq somehow making it more challenging to work with? From a development perspective, I generally tend to prefer the use of the relative quotationed paths for things that are local to src/, (and thus internal) rather than muddying the line between the two and making it harder to discern which things are public APIs and which aren't. |
||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
| #include <raft/core/nvtx.hpp> | ||
|
|
||
| namespace cuvs::common::nvtx::domain { | ||
| /** @brief This NVTX domain is supposed to be used within cuvs. */ | ||
| struct cuvs { | ||
| static constexpr const char* name = "cuvs"; | ||
| }; | ||
| }; // namespace cuvs::common::nvtx::domain | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should consider keeping this in raft with the other vocabulary types so it can continue to be used across different libraries that use raft. Also- we can't be defining device functions in an hpp file in cuVS, since it's not header-only. The only APIs users should be interacting with in cuVS should be pre-compiled runtime APIs.