-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Restructure device context #4593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
wangkuiyi
wants to merge
6
commits into
PaddlePaddle:develop
from
wangkuiyi:restructure-device-context
Closed
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6a498e4
incomplete work to push
5f51d0a
Add -D PADDLE_WITH_CUDA in cmake/configure.cmake
47b5de1
Merge branch 'paddle_only_cpu' into restructure-device-context
8517345
Merge branch 'develop' of https://github.com/paddlepaddle/paddle into…
b83ead5
Staging my work
d66595e
Resolve conflicts
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,142 +1,108 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/platform/device_context.h" | ||
| #include "paddle/memory/memory.h" | ||
|
|
||
| namespace paddle { | ||
| namespace platform { | ||
|
|
||
| template <> | ||
| Eigen::DefaultDevice* DeviceContext::GetEigenDevice< | ||
| platform::CPUPlace, Eigen::DefaultDevice>() const { | ||
| return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); | ||
| } | ||
| #ifdef PADDLE_WITH_CUDA | ||
|
|
||
| CPUDeviceContext::CPUDeviceContext() { | ||
| eigen_device_.reset(new Eigen::DefaultDevice()); | ||
| CUDADeviceContext::EigenCudaStreamDevice::EigenCudaStreamDevice() | ||
| : scratch_(nullptr), semaphore_(nullptr) { | ||
| Eigen::initializeDeviceProp(); | ||
| } | ||
| CUDADeviceContext::EigenCudaStreamDevice::~EigenCudaStreamDevice() override {} | ||
|
|
||
| CPUDeviceContext::CPUDeviceContext(CPUPlace place) { | ||
| eigen_device_.reset(new Eigen::DefaultDevice()); | ||
| void CUDADeviceContext::EigenCudaStreamDevice::SetValues( | ||
| const cudaStream_t* cuda_stream, GPUPlace place) { | ||
| stream_ = cuda_stream; | ||
| place_ = place; | ||
| device_prop_ = &Eigen::m_deviceProperties[place.device]; | ||
| } | ||
|
|
||
| Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { | ||
| return eigen_device_.get(); | ||
| const cudaStream_t& void CUDADeviceContext::EigenCudaStreamDevice::stream() | ||
| const override { | ||
| return *stream_; | ||
| } | ||
|
|
||
| Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
|
|
||
| template <> | ||
| Eigen::GpuDevice* | ||
| DeviceContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { | ||
| return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); | ||
| const cudaDeviceProp& void | ||
| CUDADeviceContext::EigenCudaStreamDevice::deviceProperties() const override { | ||
| return *device_prop_; | ||
| } | ||
|
|
||
| class EigenCudaStreamDevice : public Eigen::StreamInterface { | ||
| public: | ||
| EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { | ||
| Eigen::initializeDeviceProp(); | ||
| } | ||
| ~EigenCudaStreamDevice() override {} | ||
|
|
||
| void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) { | ||
| stream_ = cuda_stream; | ||
| place_ = place; | ||
| device_prop_ = &Eigen::m_deviceProperties[place.device]; | ||
| } | ||
|
|
||
| const cudaStream_t& stream() const override { return *stream_; } | ||
|
|
||
| const cudaDeviceProp& deviceProperties() const override { | ||
| return *device_prop_; | ||
| } | ||
|
|
||
| void* allocate(size_t num_bytes) const override { | ||
| return paddle::memory::Alloc(place_, num_bytes); | ||
| } | ||
| void* void CUDADeviceContext::EigenCudaStreamDevice::allocate( | ||
| size_t num_bytes) const override { | ||
| return paddle::memory::Alloc(place_, num_bytes); | ||
| } | ||
|
|
||
| void deallocate(void* buffer) const override { | ||
| paddle::memory::Free(place_, buffer); | ||
| } | ||
| void void CUDADeviceContext::EigenCudaStreamDevice::deallocate( | ||
| void* buffer) const override { | ||
| paddle::memory::Free(place_, buffer); | ||
| } | ||
|
|
||
| void* scratchpad() const override { | ||
| if (scratch_ == NULL) { | ||
| scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int)); | ||
| } | ||
| return scratch_; | ||
| void* void CUDADeviceContext::EigenCudaStreamDevice::scratchpad() | ||
| const override { | ||
| if (scratch_ == NULL) { | ||
| scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int)); | ||
| } | ||
| return scratch_; | ||
| } | ||
|
|
||
| unsigned int* semaphore() const override { | ||
| if (semaphore_ == NULL) { | ||
| char* scratch = | ||
| static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize; | ||
| semaphore_ = reinterpret_cast<unsigned int*>(scratch); | ||
| PADDLE_ENFORCE( | ||
| cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_)); | ||
| } | ||
| return semaphore_; | ||
| unsigned int* void CUDADeviceContext::EigenCudaStreamDevice::semaphore() | ||
| const override { | ||
| if (semaphore_ == NULL) { | ||
| char* scratch = static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize; | ||
| semaphore_ = reinterpret_cast<unsigned int*>(scratch); | ||
| PADDLE_ENFORCE( | ||
| cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_)); | ||
| } | ||
|
|
||
| private: | ||
| GPUPlace place_; | ||
| const cudaStream_t* stream_; // not owned; | ||
| const cudaDeviceProp* device_prop_; // not owned; | ||
| mutable void* scratch_; | ||
| mutable unsigned int* semaphore_; | ||
| }; | ||
| return semaphore_; | ||
| } | ||
|
|
||
| CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { | ||
| // Create CUDA stream on the given device. | ||
| SetDeviceId(place_.device); | ||
| PADDLE_ENFORCE(cudaStreamCreate(&stream_)); | ||
|
|
||
| // Set the CUDA stream into the EigenCudaStreamDevice instance. | ||
| eigen_stream_.reset(new EigenCudaStreamDevice()); | ||
| eigen_stream_->Reinitialize(&stream_, place); | ||
| eigen_stream_->SetValues(&stream_, place); | ||
|
|
||
| // Initialize Eigen::CpuDevice using EigenCudaStreamDevice. | ||
| eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); | ||
|
|
||
| // Create other handles in addition to the CUDA stream. | ||
| PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); | ||
| PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); | ||
| PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); | ||
| PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); | ||
| } | ||
|
|
||
| CUDADeviceContext::~CUDADeviceContext() { | ||
| // Wait for the completion of all operations before destructing. | ||
| SetDeviceId(place_.device); | ||
| Wait(); | ||
|
|
||
| // Note: the destruction order must be the same with the | ||
| // construction order. | ||
| PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); | ||
| PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); | ||
| eigen_stream_.reset(); | ||
| eigen_device_.reset(); | ||
| PADDLE_ENFORCE(cudaStreamDestroy(stream_)); | ||
| } | ||
|
|
||
| Place CUDADeviceContext::GetPlace() const { return place_; } | ||
|
|
||
| void CUDADeviceContext::Wait() const { | ||
| PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); | ||
| } | ||
|
|
||
| Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { | ||
| return eigen_device_.get(); | ||
| } | ||
|
|
||
| cublasHandle_t CUDADeviceContext::cublas_handle() const { | ||
| return cublas_handle_; | ||
| } | ||
|
|
||
| cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } | ||
|
|
||
| cudaStream_t CUDADeviceContext::stream() const { return stream_; } | ||
|
|
||
| #endif // PADDLE_ONLY_CPU | ||
| #endif // PADDLE_WITH_CUDA | ||
|
|
||
| } // namespace platform | ||
| } // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,13 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
|
|
@@ -27,74 +27,82 @@ limitations under the License. */ | |
| namespace paddle { | ||
| namespace platform { | ||
|
|
||
| template <typename T> | ||
| struct EigenDeviceConverter; | ||
|
|
||
| template <> | ||
| struct EigenDeviceConverter<platform::CPUPlace> { | ||
| using EigenDeviceType = Eigen::DefaultDevice; | ||
| }; | ||
|
|
||
| class DeviceContext { | ||
| public: | ||
| virtual ~DeviceContext() {} | ||
| virtual Place GetPlace() const = 0; | ||
|
|
||
| template <typename PlaceType, | ||
| typename DeviceType = | ||
| typename EigenDeviceConverter<PlaceType>::EigenDeviceType> | ||
| DeviceType* GetEigenDevice() const; | ||
|
|
||
| virtual void Wait() const {} | ||
| }; | ||
|
|
||
| class CPUDeviceContext : public DeviceContext { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we use boost::variant instead of instead of inheritance, we should remove derivation. |
||
| public: | ||
| CPUDeviceContext(); | ||
| explicit CPUDeviceContext(CPUPlace place); | ||
|
|
||
| Eigen::DefaultDevice* eigen_device() const; | ||
| CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } | ||
| explicit CPUDeviceContext(CPUPlace place) { | ||
| eigen_device_.reset(new Eigen::DefaultDevice()); | ||
| } | ||
|
|
||
| Place GetPlace() const override; | ||
| Eigen::DefaultDevice* GetEigenDevice() const { return eigen_device_.get(); } | ||
| Place GetPlace() const { return CPUPlace(); } | ||
|
|
||
| private: | ||
| std::unique_ptr<Eigen::DefaultDevice> eigen_device_; | ||
| }; | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| template <> | ||
| struct EigenDeviceConverter<platform::GPUPlace> { | ||
| using EigenDeviceType = Eigen::GpuDevice; | ||
| }; | ||
|
|
||
| class EigenCudaStreamDevice; | ||
|
|
||
| #ifdef PADDLE_WITH_CUDA | ||
|
|
||
| // The CUDADeviceContext is a parameter to framework::OperatorBase::Run: | ||
| /* | ||
| virtual void Run(const Scope& scope, | ||
| const platform::DeviceContext& dev_ctx) const = 0; | ||
| */ | ||
| // To call Eigen functions in Run, we'd need to provide a parameter of | ||
| // type Eigen::CpuDevice, from CUDADeviceContext::GetEigenDevice(). | ||
| // | ||
| // SomeEigenFunction(dev_ctx.GetEigenDevice(), ...); | ||
| // | ||
| // If we are going to call CUDA, cuDNN, cuBLAS function, we need to | ||
| // pass them handles returned by stream, cudnn_handle, cublas_handle. | ||
| // For example: | ||
| // | ||
| // SomeCUDNNFunction(dev_ctx.cudnn_handle(), ...); | ||
| // | ||
| class CUDADeviceContext : public DeviceContext { | ||
| public: | ||
| explicit CUDADeviceContext(GPUPlace place); | ||
| virtual ~CUDADeviceContext(); | ||
|
|
||
| /*! \brief Wait for all operations completion in the stream. */ | ||
| void Wait() const override; | ||
|
|
||
| /*! \brief Return place in the device context. */ | ||
| Place GetPlace() const override; | ||
|
|
||
| /*! \brief Return eigen device in the device context. */ | ||
| Eigen::GpuDevice* eigen_device() const; | ||
|
|
||
| /*! \brief Return cublas handle in the device context. */ | ||
| cublasHandle_t cublas_handle() const; | ||
| Eigen::GpuDevice* GetEigenDevice() const { return eigen_device_.get(); } | ||
| Place GetPlace() const override { return place_; } | ||
|
|
||
| /*! \brief Return cudnn handle in the device context. */ | ||
| cudnnHandle_t cudnn_handle() const; | ||
| /*! \brief Wait for all operations completion in the stream. */ | ||
| void Wait() const override { PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } | ||
|
|
||
| /*! \brief Return cuda stream in the device context. */ | ||
| cudaStream_t stream() const; | ||
| cublasHandle_t cublas_handle() const { return cublas_handle_; } | ||
| cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } | ||
| cudaStream_t stream() const { return stream_; } | ||
|
|
||
| private: | ||
| GPUPlace place_; | ||
| // Eigen requires that a Eigen::GpuDevice instance being initialized | ||
| // from a class derived from Eigen::StreamInterface. | ||
| class EigenCudaStreamDevice : public Eigen::StreamInterface { | ||
| public: | ||
| EigenCudaStreamDevice(); | ||
| ~EigenCudaStreamDevice() override {} | ||
|
|
||
| // https://github.com/PaddlePaddle/Paddle/pull/3497#issue-250238535 | ||
| // explained that initializing CUDA stream in the constructor | ||
| // would cause SEGFAULT, so we add this method. | ||
| void SetValues(const cudaStream_t* cuda_stream, GPUPlace place); | ||
|
|
||
| const cudaStream_t& stream() const override; | ||
| const cudaDeviceProp& deviceProperties() const override; | ||
| void* allocate(size_t num_bytes) const override; | ||
| void deallocate(void* buffer) const override; | ||
| void* scratchpad() const override; | ||
| unsigned int* semaphore() const override; | ||
|
|
||
| private: | ||
| GPUPlace place_; | ||
| const cudaStream_t* stream_; // not owned; | ||
| const cudaDeviceProp* device_prop_; // not owned; | ||
| mutable void* scratch_; | ||
| mutable unsigned int* semaphore_; | ||
| }; | ||
|
|
||
| GPUPlace place_; | ||
| std::unique_ptr<Eigen::GpuDevice> eigen_device_; | ||
| std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; | ||
|
|
||
|
|
@@ -103,7 +111,13 @@ class CUDADeviceContext : public DeviceContext { | |
| cublasHandle_t cublas_handle_; | ||
| }; | ||
|
|
||
| #endif | ||
| #endif // PADDLE_WITH_CUDA | ||
|
|
||
| #ifdef PADDLE_WITH_CUDA | ||
| typedef boost::variant<CPUDeviceContext, CUDADeviceContext> DeviceContext; | ||
| #else | ||
| typedef boost::variant<CPUDeviceContext> DeviceContext; | ||
| #endif // PADDLE_WITH_CUDA | ||
|
|
||
| } // namespace platform | ||
| } // namespace paddle | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now have
-DPADDLE_WITH_GPUalready, please merge the latest develop branch first.