-
Notifications
You must be signed in to change notification settings - Fork 5.9k
port allocation from majel to paddle #2217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
fa0e6e1
456e6d9
8b3220a
f202377
57b67d6
da380e8
45886cf
80af107
13bfb93
c3831ed
30daf3f
7822c86
2d6a2be
5dafd75
be29b9c
2258c23
4af8933
74b9d91
99f6048
47de9b4
7ae5e38
3c8e470
000a1e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
#include "paddle/majel/allocation.h"

#include <cstdint>  // uint8_t in Allocation::end()
#include <new>      // std::bad_alloc

#include <boost/variant.hpp>

#include "paddle/majel/malloc.h"
|
|
||
| namespace majel { | ||
| namespace detail { | ||
|
|
||
| class Allocator : public boost::static_visitor<void*> { | ||
| public: | ||
| Allocator(size_t size) : size_(size) {} | ||
|
|
||
| void* operator()(const CpuPlace& p) const { | ||
| void* address = majel::malloc::malloc(p, size_); | ||
| return address; | ||
| } | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| void* operator()(const GpuPlace& p) const { | ||
| void* address = majel::malloc::malloc(p, size_); | ||
| return address; | ||
| } | ||
| #endif | ||
|
|
||
| private: | ||
| size_t size_; | ||
| }; | ||
|
|
||
| class Deallocator : public boost::static_visitor<> { | ||
| public: | ||
| Deallocator(void* ptr) : ptr_(ptr) {} | ||
|
|
||
| void operator()(CpuPlace p) const { | ||
| if (ptr_) { | ||
| majel::malloc::free(p, ptr_); | ||
| } | ||
| } | ||
| #ifndef PADDLE_ONLY_CPU | ||
| void operator()(GpuPlace p) const { | ||
| if (ptr_) { | ||
| majel::malloc::free(p, ptr_); | ||
| } | ||
| } | ||
| #endif | ||
| private: | ||
| void* ptr_; | ||
| }; | ||
|
|
||
| } // namespace detail | ||
| } // namespace majel | ||
|
|
||
| namespace majel { | ||
|
|
||
| Allocation::Allocation() : Allocation(0, get_place()) {} | ||
|
|
||
| Allocation::Allocation(size_t size) : Allocation(size, get_place()) {} | ||
|
|
||
| Allocation::Allocation(size_t size, Place place) | ||
| : owned_(true), size_(size), place_(place) { | ||
| if (size > 0) { | ||
| majel::detail::Allocator allocator(size_); | ||
| ptr_ = boost::apply_visitor(allocator, place_); | ||
| if (ptr_ == nullptr) { | ||
| throw std::bad_alloc(); | ||
| } | ||
| } else { | ||
| ptr_ = nullptr; | ||
| } | ||
| } | ||
|
|
||
| Allocation::Allocation(void* ptr, size_t size, Place place) | ||
| : owned_(false), ptr_(ptr), size_(size), place_(place) {} | ||
|
|
||
| Allocation::~Allocation() { | ||
| // If we don't own this allocation don't try to deallocate it | ||
| if (!owned_) { | ||
| return; | ||
| } | ||
|
|
||
| if (ptr_ != nullptr) { | ||
| majel::detail::Deallocator deallocator(ptr_); | ||
|
|
||
| boost::apply_visitor(deallocator, place_); | ||
| } | ||
| } | ||
|
|
||
| void* Allocation::ptr() const { return ptr_; } | ||
|
|
||
| void* Allocation::end() const { return (uint8_t*)ptr_ + size_; } | ||
|
|
||
| size_t Allocation::size() const { return size_; } | ||
|
|
||
| Place Allocation::place() const { return place_; } | ||
|
|
||
| } // namespace majel |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #pragma once | ||
|
|
||
| #include "paddle/majel/place.h" | ||
|
|
||
| namespace majel { | ||
|
|
||
| class Allocation { | ||
| public: | ||
| Allocation(); | ||
| Allocation(size_t size); | ||
| Allocation(size_t size, Place place); | ||
|
|
||
| // Creates a non-owned allocation (an allocation not owned by the Majel | ||
| // memory allocator); non-owned allocations are not cleaned up in the | ||
| // destructor. | ||
| Allocation(void* ptr, size_t size, Place place); | ||
|
|
||
| ~Allocation(); | ||
| // No copying! | ||
| Allocation(const Allocation&) = delete; | ||
| // No assigning! | ||
| Allocation& operator=(const Allocation&) = delete; | ||
|
|
||
| void* ptr() const; | ||
| void* end() const; | ||
| Place place() const; | ||
| size_t size() const; | ||
|
|
||
| private: | ||
| bool owned_; | ||
| void* ptr_; | ||
| size_t size_; | ||
| Place place_; | ||
| }; | ||
|
|
||
| } // namespace majel |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| #pragma once | ||
|
|
||
| #include "paddle/majel/allocation.h" | ||
| #include "paddle/majel/place.h" | ||
|
|
||
| namespace majel { | ||
|
|
||
| class Buffer { | ||
| public: | ||
| Buffer() | ||
| : external_address_(nullptr), | ||
| allocation_(std::make_shared<Allocation>(0)) {} | ||
| Buffer(void* address) | ||
| : external_address_(address), | ||
| allocation_(std::make_shared<Allocation>(0)) {} | ||
| Buffer(void* address, Place p) | ||
| : external_address_(address), | ||
| allocation_(std::make_shared<Allocation>(0, p)) {} | ||
| Buffer(std::shared_ptr<Allocation> allocation) | ||
| : external_address_(nullptr), allocation_(allocation) {} | ||
|
|
||
| public: | ||
| void* get_address() const { | ||
| if (allocation_->ptr() == nullptr) { | ||
| return external_address_; | ||
| } | ||
|
|
||
| return allocation_->ptr(); | ||
| } | ||
|
|
||
| Place get_place() const { return allocation_->place(); } | ||
|
|
||
| std::shared_ptr<Allocation> data() const { return allocation_; } | ||
|
|
||
| private: | ||
| void* external_address_; | ||
| std::shared_ptr<Allocation> allocation_; | ||
| }; | ||
|
|
||
| } // namespace majel |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
#include "paddle/majel/malloc.h"

#include <stdlib.h>  // posix_memalign (POSIX; not guaranteed by <cstdlib>)

#include <memory>

#include <glog/logging.h>
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| #include <cuda_runtime.h> | ||
| #endif | ||
|
|
||
| namespace majel { | ||
| namespace malloc { | ||
| namespace detail { | ||
| #ifndef PADDLE_ONLY_CPU | ||
| const char* get_cuda_error_string() { | ||
| cudaError_t err = cudaGetLastError(); | ||
| return cudaGetErrorString(err); | ||
| } | ||
|
|
||
| const char* get_cuda_error_string(size_t err) { | ||
| return cudaGetErrorString((cudaError_t)err); | ||
| } | ||
|
|
||
| void* malloc_cuda(size_t size) { | ||
| void* dest_d; | ||
| cudaError_t result = cudaMalloc((void**)&dest_d, size); | ||
| if (result == cudaSuccess) { | ||
| return dest_d; | ||
| } | ||
|
|
||
| cudaGetLastError(); | ||
| return nullptr; | ||
| } | ||
|
|
||
| void free_cuda(void* dest_d) { | ||
| CHECK_NOTNULL(dest_d); | ||
|
|
||
| cudaError_t err = cudaFree(dest_d); | ||
| CHECK(cudaSuccess == err || cudaErrorCudartUnloading == err) | ||
| << get_cuda_error_string(); | ||
| } | ||
| #endif | ||
|
|
||
| class DefaultAllocator { | ||
| public: | ||
| static void* malloc(majel::Place place, size_t size); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error __must_check malloc(majel::Place place, size_t size, void** ptr);
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At first, I think that if
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is common to expose the out-of-memory error to the client code, and I think it is the client code's responsibility to recover the error by either print some error message or try another device. But I'd make the conclusion here to keep @QiJune 's original function definition because it follows C malloc's signature. |
||
|
|
||
| static void free(majel::Place, void* ptr); | ||
| }; | ||
|
|
||
| class DefaultAllocatorMallocVisitor : public boost::static_visitor<void*> { | ||
| public: | ||
| DefaultAllocatorMallocVisitor(size_t size) : size_(size) {} | ||
|
|
||
| void* operator()(majel::CpuPlace p) { | ||
| void* address; | ||
| posix_memalign(&address, 32ul, size_); | ||
| return address; | ||
| } | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| void* operator()(majel::GpuPlace p) { | ||
| void* address = malloc_cuda(size_); | ||
| return address; | ||
| } | ||
| #endif | ||
|
|
||
| private: | ||
| size_t size_; | ||
| }; | ||
|
|
||
| class DefaultAllocatorFreeVisitor : public boost::static_visitor<void> { | ||
| public: | ||
| DefaultAllocatorFreeVisitor(void* ptr) : ptr_(ptr) {} | ||
| void operator()(majel::CpuPlace p) { | ||
| if (ptr_) { | ||
| ::free(ptr_); | ||
| } | ||
| } | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| void operator()(majel::GpuPlace p) { | ||
| if (ptr_) { | ||
| free_cuda(ptr_); | ||
| } | ||
| } | ||
| #endif | ||
|
|
||
| private: | ||
| void* ptr_; | ||
| }; | ||
|
|
||
| void* DefaultAllocator::malloc(majel::Place place, size_t size) { | ||
| DefaultAllocatorMallocVisitor visitor(size); | ||
| return boost::apply_visitor(visitor, place); | ||
| } | ||
|
|
||
| void DefaultAllocator::free(majel::Place place, void* ptr) { | ||
| DefaultAllocatorFreeVisitor visitor(ptr); | ||
| boost::apply_visitor(visitor, place); | ||
| } | ||
|
|
||
| } // namespace detail | ||
| } // namespace malloc | ||
| } // namespace majel | ||
| namespace majel { | ||
| namespace malloc { | ||
|
|
||
| void* malloc(majel::Place place, size_t size) { | ||
| return detail::DefaultAllocator::malloc(place, size); | ||
| } | ||
|
|
||
| void free(majel::Place place, void* ptr) { | ||
| detail::DefaultAllocator::free(place, ptr); | ||
| } | ||
| } // namespace malloc | ||
| } // namespace majel | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| #pragma once | ||
| #include "paddle/majel/place.h" | ||
|
|
||
| namespace majel { | ||
| namespace malloc { | ||
|
|
||
| void* malloc(majel::Place place, size_t size); | ||
| void free(majel::Place place, void* ptr); | ||
|
|
||
| } // namespace malloc | ||
| } // namespace majel |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,33 +14,28 @@ struct CpuPlace { | |
| }; | ||
|
|
||
| struct GpuPlace { | ||
| GpuPlace(int d) : device(d) {} | ||
| GpuPlace() {} | ||
|
|
||
| // needed for variant equality comparison | ||
| inline bool operator==(const GpuPlace& o) const { return device == o.device; } | ||
| inline bool operator==(const GpuPlace&) const { return true; } | ||
|
|
||
| inline bool operator!=(const GpuPlace& o) const { return !(*this == o); } | ||
|
|
||
| GpuPlace() : GpuPlace(0) {} | ||
| int device; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't get it -- why we should make GpuPlace no longer distinguish GPUs? Is that because we want to use the CUDA context to determine the current GPU? If so, I think what we need to do is not removing
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One CPU thread is binding to one GPU card. Every cpu thread will set the GPU card first, using
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am afraid that we cannot assume this? How if we are going to support OpenCL/FPGA other than CUDA? Would this assumption become a bug?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think communications between devices are costful, one neural network will be run on one device. If we are going to support OpenCL/FPGA, we can just define Then, we implement methods, such as malloc/free of corresponding device.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that we can run a neural network on one device, but when we aggregate gradients/parameters from these devices, it seems that we need to copy data from/to exact places?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will create one The gpu data block does not hold device id information, but the |
||
| inline bool operator!=(const GpuPlace&) const { return false; } | ||
| }; | ||
|
|
||
| class IsGpuPlace : public boost::static_visitor<bool> { | ||
| public: | ||
| bool operator()(const CpuPlace&) const { return false; } | ||
|
|
||
| bool operator()(const GpuPlace& gpu) const { return true; } | ||
| bool operator()(const GpuPlace&) const { return true; } | ||
| }; | ||
|
|
||
| typedef boost::variant<GpuPlace, CpuPlace> Place; | ||
|
|
||
| void set_place(const Place&); | ||
| #ifndef PADDLE_ONLY_CPU | ||
| typedef boost::variant<CpuPlace, GpuPlace> Place; | ||
| #else | ||
| typedef boost::variant<CpuPlace> Place; | ||
| #endif | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果编译CPU版本的话,那么Place里面只能接受CpuPlace;这个时候给一个Array传递GpuPlace,就会在编译的时候报错。 |
||
|
|
||
| const Place& get_place(); | ||
|
|
||
| const GpuPlace default_gpu(); | ||
| const CpuPlace default_cpu(); | ||
|
|
||
| bool is_gpu_place(const Place&); | ||
| bool is_cpu_place(const Place&); | ||
| bool places_are_same_class(const Place&, const Place&); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to its name, class
DefaultAllocatorshould be inallocation.{h,cc}; instead ofmalloc.{h,cc}?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Majel provides various memory management policies. Every memory management policy can be abstracted into a allocator class. Here, we just implement a simple one first,
DefalutAllocator.mallocis a global method which is responsible for memory allocation.mallocwill choose a specific memory allocation policy. AndAllocationis a memory block handled byArrayand will callmallocmethod.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with every sentence in your comment. And it seems that's the reason we should move class
Allocationtoallocation.{h,cc}?