-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add paddle.device.cuda.stream_guard API #35623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for your contribution! |
c00a85d to
4047772
Compare
4047772 to
19a17c9
Compare
| self.assertTrue(event_query_2) | ||
|
|
||
|
|
||
| class TestStreamGuard(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在PR上贴上验证的代码以及验证之后的效果
|
|
||
| cur_stream = current_stream() | ||
| if stream is None or id(stream) == id(cur_stream): | ||
| yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里单测是不是要加上同样的stream
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
经讨论后不需要修改。
| ''' | ||
| Set the current stream. | ||
| Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameters->Args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
询问了陈龙,Args 或者 Parameters 都可以,为了与本页面其他API 保持统一,不进行修改。
| #else | ||
| PADDLE_THROW(platform::errors::Unavailable( | ||
| "Class CUDAStream can only be initialized on the GPU platform.")); | ||
| #endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的Stream方法,是不是可以默认non_blocking方式
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| enum class StreamFlag : uint8_t { | ||
| kDefaultFlag = 0x0, | ||
| kStreamNonBlocking = 0x1, | ||
| kStreamPerThread = 0x2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个kStreamPerThread 可以去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| A context manager that specifies the current stream context by the given stream. | ||
| Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Paramters->Args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
| return core._set_current_stream(stream) | ||
|
|
||
|
|
||
| @signature_safe_contextmanager |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dygraph_only
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
待定
| if stream is None or id(stream) == id(cur_stream): | ||
| yield | ||
| else: | ||
| pre_stream = _set_current_stream(stream) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stream 是否影响分布式环境?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
会进行线下测试,相关结果后续会贴在开头的 comment 中。
| } | ||
| auto prio = paddle::platform::stream::Priority(priority); | ||
| auto stream_flag = paddle::platform::stream::StreamFlag(1); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the hard code 1 means?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 means non-blocking stream. We init CUDA Stream with default non-blocking property following pytorch implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using paddle::platform::stream::StreamFlag::kStreamNonBlocking instead of 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
|
||
| if stream is None: | ||
| raise ValueError("input stream should not be None.") | ||
| if not isinstance(stream, paddle.device.cuda.Stream): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下面的判断是否可以包含上面 None 的判断?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想问,可不可以统一成 TypeError?(其实我不应该写成 ValueError,想统一改成 TypeError)
wawltor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
XiaoguangHu01
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG API
MingMingShangTian
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Add paddle.cuda.device.stream_guard API
PR types
New features
PR changes
APIs
Describe
This API provide a way to switch Cuda Stream flexibly.

Offline Test
Async property test
From the picture above, we can see that CUDA Kernel and CUDA Memcpy can run asynchronously.