-
Notifications
You must be signed in to change notification settings - Fork 247
Use cuda::stream_ref for stream usage #2372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,8 @@ | |||||||||||||||
| #include <rmm/mr/per_device_resource.hpp> | ||||||||||||||||
| #include <rmm/resource_ref.hpp> | ||||||||||||||||
|
|
||||||||||||||||
| #include <cuda/stream_ref> | ||||||||||||||||
|
|
||||||||||||||||
| #include <type_traits> | ||||||||||||||||
|
|
||||||||||||||||
| namespace RMM_NAMESPACE { | ||||||||||||||||
|
|
@@ -84,7 +86,7 @@ class device_scalar { | |||||||||||||||
| * @param mr Optional, resource with which to allocate. | ||||||||||||||||
| */ | ||||||||||||||||
| explicit device_scalar( | ||||||||||||||||
| cuda_stream_view stream, | ||||||||||||||||
| cuda::stream_ref stream, | ||||||||||||||||
| cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref()) | ||||||||||||||||
| : _storage{1, stream, std::move(mr)} | ||||||||||||||||
| { | ||||||||||||||||
|
|
@@ -110,7 +112,7 @@ class device_scalar { | |||||||||||||||
| */ | ||||||||||||||||
| explicit device_scalar( | ||||||||||||||||
| value_type const& initial_value, | ||||||||||||||||
| cuda_stream_view stream, | ||||||||||||||||
| cuda::stream_ref stream, | ||||||||||||||||
| cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref()) | ||||||||||||||||
| : _storage{1, stream, std::move(mr)} | ||||||||||||||||
| { | ||||||||||||||||
|
|
@@ -131,7 +133,7 @@ class device_scalar { | |||||||||||||||
| */ | ||||||||||||||||
| device_scalar( | ||||||||||||||||
| device_scalar const& other, | ||||||||||||||||
| cuda_stream_view stream, | ||||||||||||||||
| cuda::stream_ref stream, | ||||||||||||||||
| cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref()) | ||||||||||||||||
| : _storage{other._storage, stream, std::move(mr)} | ||||||||||||||||
| { | ||||||||||||||||
|
|
@@ -153,7 +155,7 @@ class device_scalar { | |||||||||||||||
| * @return T The value of the scalar. | ||||||||||||||||
| * @param stream CUDA stream on which to perform the copy and synchronize. | ||||||||||||||||
| */ | ||||||||||||||||
| [[nodiscard]] value_type value(cuda_stream_view stream) const | ||||||||||||||||
| [[nodiscard]] value_type value(cuda::stream_ref stream) const | ||||||||||||||||
| { | ||||||||||||||||
| return _storage.front_element(stream); | ||||||||||||||||
| } | ||||||||||||||||
|
|
@@ -191,14 +193,14 @@ class device_scalar { | |||||||||||||||
| * @param value The host value which will be copied to device | ||||||||||||||||
| * @param stream CUDA stream on which to perform the copy | ||||||||||||||||
| */ | ||||||||||||||||
| void set_value_async(value_type const& value, cuda_stream_view stream) | ||||||||||||||||
| void set_value_async(value_type const& value, cuda::stream_ref stream) | ||||||||||||||||
| { | ||||||||||||||||
| _storage.set_element_async(0, value, stream); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // Disallow passing literals to set_value to avoid race conditions where the memory holding the | ||||||||||||||||
| // literal can be freed before the async memcpy / memset executes. | ||||||||||||||||
| void set_value_async(value_type&&, cuda_stream_view) = delete; | ||||||||||||||||
| void set_value_async(value_type&&, cuda::stream_ref) = delete; | ||||||||||||||||
|
|
||||||||||||||||
| /** | ||||||||||||||||
| * @brief Sets the value of the `device_scalar` to zero on the specified stream. | ||||||||||||||||
|
|
@@ -214,7 +216,7 @@ class device_scalar { | |||||||||||||||
| * | ||||||||||||||||
| * @param stream CUDA stream on which to perform the copy | ||||||||||||||||
| */ | ||||||||||||||||
| void set_value_to_zero_async(cuda_stream_view stream) | ||||||||||||||||
| void set_value_to_zero_async(cuda::stream_ref stream) | ||||||||||||||||
| { | ||||||||||||||||
| _storage.set_element_to_zero_async(value_type{0}, stream); | ||||||||||||||||
|
Comment on lines
+219
to
221
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify the callee signature and this call site side-by-side.
rg -n -C2 'set_element_to_zero_async\s*\(' cpp/include/rmm/device_uvector.hpp cpp/include/rmm/device_scalar.hppRepository: rapidsai/rmm Length of output: 1085 Use element index
Proposed fix- _storage.set_element_to_zero_async(value_type{0}, stream);
+ _storage.set_element_to_zero_async(size_type{0}, stream);📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||
| } | ||||||||||||||||
|
|
@@ -261,7 +263,7 @@ class device_scalar { | |||||||||||||||
| * | ||||||||||||||||
| * @param stream Stream to be used for deallocation | ||||||||||||||||
| */ | ||||||||||||||||
| void set_stream(cuda_stream_view stream) noexcept { _storage.set_stream(stream); } | ||||||||||||||||
| void set_stream(cuda::stream_ref stream) noexcept { _storage.set_stream(stream); } | ||||||||||||||||
|
|
||||||||||||||||
| private: | ||||||||||||||||
| rmm::device_uvector<T> _storage; | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: rapidsai/rmm
Length of output: 248
Wrap
cudaMemsetAsyncwithRMM_CUDA_TRYto detect CUDA errors.Line 94 has an unchecked
cudaMemsetAsynccall that could silently fail. Per coding guidelines, all CUDA API calls must be wrapped withRMM_CUDA_TRYto detect errors early.Suggested patch
📝 Committable suggestion
🤖 Prompt for AI Agents