5 changes: 3 additions & 2 deletions paddle/fluid/platform/device_context.cc
@@ -411,10 +411,11 @@ void CUDAContext::InitEigenContext() {
}

CUDAContext::CUDAContext(const CUDAPlace& place,
const stream::Priority& priority) {
const stream::Priority& priority,
const stream::StreamFlag& flag) {
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority));
stream_.reset(new stream::CUDAStream(place, priority, flag));
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
9 changes: 8 additions & 1 deletion paddle/fluid/platform/device_context.h
@@ -272,7 +272,8 @@ class CUDAContext {
CUDAContext() = default;
explicit CUDAContext(
const CUDAPlace& place,
const stream::Priority& priority = stream::Priority::kNormal);
const stream::Priority& priority = stream::Priority::kNormal,
const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);

~CUDAContext();

@@ -288,6 +289,12 @@

const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
auto* old_stream_ptr = stream_.release();
stream_.reset(new_stream_ptr);
return old_stream_ptr;
}

const gpuStream_t& RawStream() { return stream_->raw_stream(); }

#ifdef PADDLE_WITH_HIP
41 changes: 25 additions & 16 deletions paddle/fluid/platform/stream/cuda_stream.cc
@@ -21,38 +21,34 @@ namespace paddle {
namespace platform {
namespace stream {

#ifdef PADDLE_WITH_HIP
constexpr unsigned int kDefaultFlag = hipStreamDefault;
#else
constexpr unsigned int kDefaultFlag = cudaStreamDefault;
#endif

bool CUDAStream::Init(const Place& place, const Priority& priority) {
bool CUDAStream::Init(const Place& place, const Priority& priority,
const StreamFlag& flag) {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::InvalidArgument(
"Cuda stream must be created using cuda place."));
place_ = place;
CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device);
if (priority == Priority::kHigh) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
hipStreamCreateWithPriority(&stream_, kDefaultFlag, -1));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreateWithPriority(
&stream_, static_cast<unsigned int>(flag), -1));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreateWithPriority(
&stream_, static_cast<unsigned int>(flag), -1));
#endif
} else if (priority == Priority::kNormal) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
hipStreamCreateWithPriority(&stream_, kDefaultFlag, 0));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreateWithPriority(
&stream_, static_cast<unsigned int>(flag), 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreateWithPriority(
&stream_, static_cast<unsigned int>(flag), 0));
#endif
}
callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
VLOG(3) << "GPUStream Init stream: " << stream_
<< ", priority: " << static_cast<int>(priority);
<< ", priority: " << static_cast<int>(priority)
<< ", flag:" << static_cast<int>(flag);
return true;
}

@@ -118,6 +114,19 @@ CUDAStream* get_current_stream(int deviceId) {
#endif
}

CUDAStream* set_current_stream(CUDAStream* stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto& device = stream->GetPlace();
auto& pool = platform::DeviceContextPool::Instance();
return static_cast<platform::CUDADeviceContext*>(pool.Get(device))
->context()
->SetStream(stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
return nullptr;
#endif
}
} // namespace stream
} // namespace platform
} // namespace paddle
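A note on the new C++ `set_current_stream`: it calls `CUDAContext::SetStream`, which releases the context's current `CUDAStream` and returns the raw pointer to it, so ownership of the previous stream passes to the caller. Python reaches this through the private `core._set_current_stream` binding added below. A minimal swap-and-restore sketch, assuming a GPU build of Paddle (underscored names are private APIs):

```python
# Sketch only: swap the current stream and hand the original one back.
# Assumes Paddle was compiled with CUDA and a GPU device is visible.
import paddle
from paddle.fluid import core

s = paddle.device.cuda.Stream()      # streams are created non-blocking now
prev = core._set_current_stream(s)   # returns the previously current stream
try:
    pass                             # kernels launched here run on `s`
finally:
    core._set_current_stream(prev)   # restore the original stream
```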
18 changes: 15 additions & 3 deletions paddle/fluid/platform/stream/cuda_stream.h
@@ -33,18 +33,27 @@ enum class Priority : uint8_t {
kHigh = 0x1,
kNormal = 0x2,
};

enum class StreamFlag : uint8_t {
kDefaultFlag = 0x0,
kStreamNonBlocking = 0x1,
};

#endif
class CUDAStream final {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

public:
CUDAStream() = default;
explicit CUDAStream(const Place& place,
const Priority& priority = Priority::kNormal) {
Init(place, priority);
const Priority& priority = Priority::kNormal,
const StreamFlag& flag = StreamFlag::kDefaultFlag) {
Init(place, priority, flag);
}
virtual ~CUDAStream() { Destroy(); }

bool Init(const Place& place, const Priority& priority = Priority::kNormal);
bool Init(const Place& place, const Priority& priority = Priority::kNormal,
const StreamFlag& flag = StreamFlag::kDefaultFlag);

template <typename Callback>
void AddCallback(Callback&& callback) const {
@@ -125,6 +134,8 @@ class CUDAStream final {
#endif
}

const Place& GetPlace() const { return place_; }

private:
Place place_;
#ifdef PADDLE_WITH_HIP
@@ ... @@
};

CUDAStream* get_current_stream(int deviceId);
CUDAStream* set_current_stream(CUDAStream* stream);

} // namespace stream
} // namespace platform
27 changes: 23 additions & 4 deletions paddle/fluid/pybind/cuda_streams_py.cc
@@ -40,6 +40,18 @@ void BindCudaStream(py::module *m_ptr) {
},
py::return_value_policy::reference);

m.def("_set_current_stream",
[](paddle::platform::stream::CUDAStream &stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return paddle::platform::stream::set_current_stream(&stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot set cuda current "
"stream."));
#endif
},
py::return_value_policy::reference);

m.def("_device_synchronize", [](int device_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (device_id == -1) {
@@ -69,7 +81,7 @@ void BindCudaStream(py::module *m_ptr) {
If device is a positive integer, it must be less than the device count. Default: None.

priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
If prioriyt is None, the priority is 2(normal). Default: None.
If priority is None, the priority is 2(normal). Default: None.

Examples:
.. code-block:: python
@@ -200,14 +212,17 @@ void BindCudaStream(py::module *m_ptr) {
"Priority should be 1(high) or 2(normal) "));
}
auto prio = paddle::platform::stream::Priority(priority);
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;

Collaborator: What does the hard-coded 1 mean?

Contributor Author: 1 means a non-blocking stream. We initialize the CUDA stream with the non-blocking property by default, following the PyTorch implementation.

Collaborator: How about using paddle::platform::stream::StreamFlag::kStreamNonBlocking instead of 1?

Contributor Author: Done.

if (device == nullptr) {
int curr_device_id = platform::GetCurrentDeviceId();
auto device_tmp = platform::CUDAPlace(curr_device_id);
device = &device_tmp;
}

new (&self) paddle::platform::stream::CUDAStream(*device, prio);
new (&self) paddle::platform::stream::CUDAStream(*device, prio,
stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
@@ -224,6 +239,8 @@
"Priority should be 1(high) or 2(normal) "));
}
auto prio = paddle::platform::stream::Priority(priority);
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;

int device_count = platform::GetCUDADeviceCount();
if (device < 0) {
Expand All @@ -236,7 +253,7 @@ void BindCudaStream(py::module *m_ptr) {
}

new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device), prio);
platform::CUDAPlace(device), prio, stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
@@ -246,11 +263,13 @@
.def("__init__", [](paddle::platform::stream::CUDAStream &self) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto prio = paddle::platform::stream::Priority::kNormal;
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;

int device_id = platform::GetCurrentDeviceId();

new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device_id), prio);
platform::CUDAPlace(device_id), prio, stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
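With these bindings, every `CUDAStream` constructed from Python now passes `StreamFlag::kStreamNonBlocking`, so Python-created streams no longer synchronize implicitly with the legacy default stream. A sketch of the three bound constructor forms, mirroring the docstring above and assuming a GPU build:

```python
# Sketch: the three __init__ overloads bound above; all now create
# non-blocking streams under the hood. Requires a CUDA-enabled build.
import paddle

s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)  # explicit place, high priority
s2 = paddle.device.cuda.Stream(0, 2)                    # device id, normal priority
s3 = paddle.device.cuda.Stream()                        # current device, normal priority
```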
67 changes: 65 additions & 2 deletions python/paddle/device/cuda/__init__.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager

from .streams import Stream # noqa: F401
from .streams import Event # noqa: F401
@@ -24,6 +26,7 @@
'synchronize',
'device_count',
'empty_cache',
'stream_guard',
]


@@ -121,7 +124,7 @@ def device_count():


def empty_cache():
"""
'''
Releases idle cached memory held by the allocator so that it can be used by other GPU
applications and becomes visible in `nvidia-smi`. In most cases you don't need to use this function;
Paddle does not release the memory back to the OS when you delete Tensors on the GPU,
@@ -137,7 +140,67 @@
tensor = paddle.randn([512, 512, 512], "float")
del tensor
paddle.device.cuda.empty_cache()
"""
'''

if core.is_compiled_with_cuda():
core.cuda_empty_cache()


def _set_current_stream(stream):
'''
Set the current stream.

Parameters:
Contributor: Parameters -> Args

Contributor Author: Asked Chen Long: either Args or Parameters is fine. Keeping Parameters for consistency with the other APIs on this page.

stream(paddle.device.cuda.Stream): The selected stream.

Returns:
CUDAStream: The previous stream.

'''

if not isinstance(stream, paddle.device.cuda.Stream):
Collaborator: Can the check below subsume the None check above?

Contributor Author: May I ask whether we can unify these as TypeError? (I should not have written ValueError; I want to change them all to TypeError.)

raise TypeError("stream type should be paddle.device.cuda.Stream")

cur_stream = current_stream()
if id(stream) == id(cur_stream):
return stream
return core._set_current_stream(stream)
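Note the short-circuit above: when the requested stream is already the current one, the wrapper returns it without calling into C++. A small usage sketch, assuming a GPU build (`_set_current_stream` is private API):

```python
# Sketch: the wrapper only performs a swap when the stream actually changes.
import paddle
from paddle.device.cuda import current_stream, _set_current_stream

cur = current_stream()
_set_current_stream(cur)          # already current: returned as-is, no swap

new = paddle.device.cuda.Stream()
prev = _set_current_stream(new)   # swaps; `prev` is the old current stream
_set_current_stream(prev)         # restore
```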


@signature_safe_contextmanager
Collaborator: @dygraph_only

Contributor Author: TBD.

def stream_guard(stream):
'''
**Notes**:
**This API only supports dygraph mode currently.**

A context manager that specifies the current stream context by the given stream.

Parameters:
Contributor: Parameters -> Args

Contributor Author: Same as above.

stream(paddle.device.cuda.Stream): the selected stream. If stream is None, just yield. The default value is None.

Examples:
.. code-block:: python

# required: gpu
import paddle

s = paddle.device.cuda.Stream()
data1 = paddle.ones(shape=[20])
data2 = paddle.ones(shape=[20])
with paddle.device.cuda.stream_guard(s):
data3 = data1 + data2

'''

if stream is not None and not isinstance(stream, paddle.device.cuda.Stream):
raise TypeError("stream type should be paddle.device.cuda.Stream")

cur_stream = current_stream()
if stream is None or id(stream) == id(cur_stream):
yield
Contributor: Should the unit test also cover passing the same stream here?

Contributor Author: After discussion, no change is needed.

else:
pre_stream = _set_current_stream(stream)
Collaborator: Does this stream affect distributed environments?

Contributor Author: We will run offline tests; the results will be posted in the top comment of this PR.

try:
yield
finally:
stream = _set_current_stream(pre_stream)
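Putting it together, `stream_guard` is the swap-and-restore pattern wrapped in a context manager: it swaps on entry and restores in the `finally` clause, so the original stream comes back even if the body raises. A sketch that checks the restore, assuming a GPU build:

```python
# Sketch: stream_guard swaps the current stream on entry and restores
# it on exit. Requires a CUDA-enabled build with a visible GPU.
import paddle

s = paddle.device.cuda.Stream()
before = paddle.device.cuda.current_stream()

with paddle.device.cuda.stream_guard(s):
    data = paddle.ones([20]) + paddle.ones([20])  # issued on `s`

after = paddle.device.cuda.current_stream()       # same stream as `before`
```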
2 changes: 2 additions & 0 deletions python/paddle/fluid/core.py
@@ -276,6 +276,7 @@ def to_list(s):
from .core_avx import _set_cached_executor_build_strategy
from .core_avx import _device_synchronize
from .core_avx import _get_current_stream
from .core_avx import _set_current_stream
if sys.platform != 'win32':
from .core_avx import _set_process_pids
from .core_avx import _erase_process_pids
@@ -328,6 +329,7 @@ def to_list(s):
from .core_noavx import _set_cached_executor_build_strategy
from .core_noavx import _device_synchronize
from .core_noavx import _get_current_stream
from .core_noavx import _set_current_stream
if sys.platform != 'win32':
from .core_noavx import _set_process_pids
from .core_noavx import _erase_process_pids