Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/platform/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class CublasHandleHolder {
}
#endif

const cublasHandle_t& GetCublasHandle() const { return handle_; }

~CublasHandleHolder() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_));
Expand All @@ -117,7 +119,7 @@ class CublasHandleHolder {
}

template <typename Callback>
inline void Call(Callback &&callback) const {
inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}

cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}

CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() const;
#endif

/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;

/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/device_context_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) {
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
#endif
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context;
}
}
Expand Down