Skip to content
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/majel/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
cc_library(place SRCS place.cc)
cc_library(ddim SRCS ddim.cc)
cc_library(malloc SRCS malloc.cc)
cc_library(allocation SRCS allocation.cc)

if(WITH_TESTING)
add_subdirectory(test)
Expand Down
94 changes: 94 additions & 0 deletions paddle/majel/allocation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "paddle/majel/allocation.h"
#include <boost/variant.hpp>
#include "paddle/majel/malloc.h"

namespace majel {
namespace detail {

// Visitor that allocates size_ bytes on whichever device the visited
// Place refers to, dispatching to the majel::malloc free functions.
class Allocator : public boost::static_visitor<void*> {
public:
  // explicit: an Allocator must never be created by implicit
  // conversion from an integer.
  explicit Allocator(size_t size) : size_(size) {}

  // Host allocation.
  void* operator()(const CpuPlace& p) const {
    return majel::malloc::malloc(p, size_);
  }

#ifndef PADDLE_ONLY_CPU
  // Device allocation; compiled out in CPU-only builds.
  void* operator()(const GpuPlace& p) const {
    return majel::malloc::malloc(p, size_);
  }
#endif

private:
  size_t size_;  // number of bytes to request
};

// Visitor that returns ptr_ to the device it was allocated on.
// A null pointer is silently ignored, mirroring std::free semantics.
class Deallocator : public boost::static_visitor<> {
public:
  // explicit: a Deallocator must never be created by implicit
  // conversion from a raw pointer.
  explicit Deallocator(void* ptr) : ptr_(ptr) {}

  void operator()(CpuPlace p) const {
    if (ptr_) {
      majel::malloc::free(p, ptr_);
    }
  }

#ifndef PADDLE_ONLY_CPU
  void operator()(GpuPlace p) const {
    if (ptr_) {
      majel::malloc::free(p, ptr_);
    }
  }
#endif

private:
  void* ptr_;  // memory to release; may be nullptr
};

} // namespace detail
} // namespace majel

namespace majel {

// Default-construct an empty (zero-byte) allocation at the
// process-wide default place (majel::get_place()).
Allocation::Allocation() : Allocation(0, get_place()) {}

// Allocate `size` bytes at the process-wide default place.
Allocation::Allocation(size_t size) : Allocation(size, get_place()) {}

// Allocate `size` bytes on `place`. Throws std::bad_alloc when the
// underlying allocator cannot satisfy the request. A zero-byte
// allocation deliberately holds no storage.
Allocation::Allocation(size_t size, Place place)
    : owned_(true), size_(size), place_(place) {
  ptr_ = nullptr;
  if (size_ == 0) {
    return;  // empty allocation: nothing to acquire
  }

  majel::detail::Allocator allocator(size_);
  ptr_ = boost::apply_visitor(allocator, place_);
  if (ptr_ == nullptr) {
    throw std::bad_alloc();
  }
}

// Wrap caller-owned memory: records ptr/size/place but, because owned_
// is false, the destructor will not free it.
Allocation::Allocation(void* ptr, size_t size, Place place)
    : owned_(false), ptr_(ptr), size_(size), place_(place) {}

// Release owned storage; non-owned wrappers and empty allocations
// leave the memory alone.
Allocation::~Allocation() {
  const bool must_release = owned_ && (ptr_ != nullptr);
  if (must_release) {
    majel::detail::Deallocator deallocator(ptr_);
    boost::apply_visitor(deallocator, place_);
  }
}

// Start of the allocated region (nullptr for empty allocations).
void* Allocation::ptr() const { return ptr_; }

// One past the last byte; equals ptr() when size() == 0.
// static_cast replaces the original C-style (uint8_t*) cast.
void* Allocation::end() const { return static_cast<uint8_t*>(ptr_) + size_; }

// Number of bytes requested at construction time.
size_t Allocation::size() const { return size_; }

// Device the memory lives on.
Place Allocation::place() const { return place_; }

} // namespace majel
36 changes: 36 additions & 0 deletions paddle/majel/allocation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <cstddef>  // size_t (was previously pulled in only transitively)

#include "paddle/majel/place.h"

namespace majel {

// A contiguous block of device memory together with the Place it lives
// on. Owned allocations are acquired in the constructor and released in
// the destructor (RAII); copying is disabled so ownership stays unique.
class Allocation {
public:
  // Empty (zero-byte) allocation at the default place.
  Allocation();
  // Allocate `size` bytes at the default place.
  Allocation(size_t size);
  // Allocate `size` bytes at `place`; throws std::bad_alloc on failure.
  Allocation(size_t size, Place place);

  // Creates a non-owned allocation (an allocation not owned by the Majel
  // memory allocator); non-owned allocations are not cleaned up in the
  // destructor.
  Allocation(void* ptr, size_t size, Place place);

  // Frees the memory iff this object owns it.
  ~Allocation();
  // No copying!
  Allocation(const Allocation&) = delete;
  // No assigning!
  Allocation& operator=(const Allocation&) = delete;

  void* ptr() const;    // start of the block (nullptr when empty)
  void* end() const;    // one past the last byte
  Place place() const;  // device the block lives on
  size_t size() const;  // byte count requested at construction

private:
  bool owned_;    // true when the destructor must free ptr_
  void* ptr_;
  size_t size_;
  Place place_;
};

}  // namespace majel
40 changes: 40 additions & 0 deletions paddle/majel/buffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#pragma once

#include <memory>   // std::shared_ptr, std::make_shared (was missing)
#include <utility>  // std::move

#include "paddle/majel/allocation.h"
#include "paddle/majel/place.h"

namespace majel {

// Pairs an optional externally-owned raw address with a shared
// Allocation. When the allocation is empty, the external address is the
// usable memory; otherwise the allocation's own storage wins.
class Buffer {
public:
  // Empty buffer: no external address, zero-byte allocation at the
  // default place.
  Buffer()
      : external_address_(nullptr),
        allocation_(std::make_shared<Allocation>(0)) {}
  // Wrap an externally managed address (never freed by majel).
  Buffer(void* address)
      : external_address_(address),
        allocation_(std::make_shared<Allocation>(0)) {}
  // Wrap an external address that logically lives at place `p`.
  Buffer(void* address, Place p)
      : external_address_(address),
        allocation_(std::make_shared<Allocation>(0, p)) {}
  // Adopt (share ownership of) an existing allocation; the by-value
  // parameter is moved into place to avoid an extra refcount bump.
  Buffer(std::shared_ptr<Allocation> allocation)
      : external_address_(nullptr), allocation_(std::move(allocation)) {}

public:
  // The usable address: the allocation's memory when present, otherwise
  // the external address.
  void* get_address() const {
    if (allocation_->ptr() == nullptr) {
      return external_address_;
    }

    return allocation_->ptr();
  }

  Place get_place() const { return allocation_->place(); }

  std::shared_ptr<Allocation> data() const { return allocation_; }

private:
  void* external_address_;  // non-owned; may be nullptr
  std::shared_ptr<Allocation> allocation_;  // never null
};

}  // namespace majel
115 changes: 115 additions & 0 deletions paddle/majel/malloc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "paddle/majel/malloc.h"

#include <cstdlib>  // posix_memalign, ::free
#include <memory>

#include <glog/logging.h>

#ifndef PADDLE_ONLY_CPU
#include <cuda_runtime.h>
#endif

namespace majel {
namespace malloc {
namespace detail {
#ifndef PADDLE_ONLY_CPU
const char* get_cuda_error_string() {
cudaError_t err = cudaGetLastError();
return cudaGetErrorString(err);
}

const char* get_cuda_error_string(size_t err) {
return cudaGetErrorString((cudaError_t)err);
}

void* malloc_cuda(size_t size) {
void* dest_d;
cudaError_t result = cudaMalloc((void**)&dest_d, size);
if (result == cudaSuccess) {
return dest_d;
}

cudaGetLastError();
return nullptr;
}

// cudaFree wrapper. Aborts (glog CHECK) on a null pointer or any
// failure other than cudaErrorCudartUnloading, which is tolerated
// because it occurs when the CUDA runtime is already tearing down at
// process exit.
void free_cuda(void* dest_d) {
  CHECK_NOTNULL(dest_d);

  cudaError_t err = cudaFree(dest_d);
  CHECK(cudaSuccess == err || cudaErrorCudartUnloading == err)
      << get_cuda_error_string();
}
#endif

class DefaultAllocator {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to its name, class DefaultAllocator should be in allocation.{h,cc}; instead of malloc.{h,cc}?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Majel provides various memory management policies. Every memory management policy can be abstracted into an allocator class. Here, we just implement a simple one first, DefaultAllocator.
malloc is a global method that is responsible for memory allocation. malloc will choose a specific memory allocation policy. And Allocation is a memory block handled by Array; it will call the malloc method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with every sentence in your comment. And it seems that's the reason we should move class Allocation to allocation.{h,cc}?

public:
static void* malloc(majel::Place place, size_t size);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error __must_check malloc(majel::Place place, size_t size, void** ptr);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, I think that if the malloc method gets an error, the client can hardly do anything to overcome the problem. So, just let it be fatal.
Second, if we check the result state of malloc, then we also have to check the result state when we construct an array. It will be quite fussy.

Copy link
Collaborator

@wangkuiyi wangkuiyi May 31, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is common to expose the out-of-memory error to the client code, and I think it is the client code's responsibility to recover the error by either print some error message or try another device.

But I'd make the conclusion here to keep @QiJune 's original function definition because it follows C malloc's signature.


static void free(majel::Place, void* ptr);
};

// Visitor implementing DefaultAllocator's per-device allocation.
class DefaultAllocatorMallocVisitor : public boost::static_visitor<void*> {
public:
  explicit DefaultAllocatorMallocVisitor(size_t size) : size_(size) {}

  // Host path: 32-byte-aligned allocation. Checks posix_memalign's
  // return code — on failure it leaves the output pointer unmodified,
  // so the original code returned an uninitialized pointer; we now
  // return nullptr instead.
  void* operator()(majel::CpuPlace p) {
    void* address = nullptr;
    if (posix_memalign(&address, 32ul, size_) != 0) {
      return nullptr;
    }
    return address;
  }

#ifndef PADDLE_ONLY_CPU
  // Device path: defers to the cudaMalloc wrapper (nullptr on failure).
  void* operator()(majel::GpuPlace p) { return malloc_cuda(size_); }
#endif

private:
  size_t size_;  // number of bytes to request
};

// Visitor implementing DefaultAllocator's per-device deallocation.
// Null pointers are ignored on every path.
class DefaultAllocatorFreeVisitor : public boost::static_visitor<void> {
public:
  explicit DefaultAllocatorFreeVisitor(void* ptr) : ptr_(ptr) {}

  // Host path: memory from posix_memalign is released with ::free.
  void operator()(majel::CpuPlace p) {
    if (ptr_) {
      ::free(ptr_);
    }
  }

#ifndef PADDLE_ONLY_CPU
  // Device path: defers to the cudaFree wrapper.
  void operator()(majel::GpuPlace p) {
    if (ptr_) {
      free_cuda(ptr_);
    }
  }
#endif

private:
  void* ptr_;  // memory to release; may be nullptr
};

// Route an allocation request to the visitor arm matching `place`.
void* DefaultAllocator::malloc(majel::Place place, size_t size) {
  DefaultAllocatorMallocVisitor malloc_visitor(size);
  return boost::apply_visitor(malloc_visitor, place);
}

// Route a deallocation request to the visitor arm matching `place`.
void DefaultAllocator::free(majel::Place place, void* ptr) {
  DefaultAllocatorFreeVisitor free_visitor(ptr);
  boost::apply_visitor(free_visitor, place);
}

} // namespace detail
} // namespace malloc
} // namespace majel
namespace majel {
namespace malloc {

// Public entry point: allocate `size` bytes on `place` using the
// default allocation policy.
void* malloc(majel::Place place, size_t size) {
  return detail::DefaultAllocator::malloc(place, size);
}

// Public entry point: release memory previously obtained from
// majel::malloc::malloc on the same `place`.
void free(majel::Place place, void* ptr) {
  detail::DefaultAllocator::free(place, ptr);
}
} // namespace malloc
} // namespace majel
11 changes: 11 additions & 0 deletions paddle/majel/malloc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <cstddef>  // size_t (was previously pulled in only transitively)

#include "paddle/majel/place.h"

namespace majel {
namespace malloc {

// Allocate `size` bytes on the device described by `place`.
void* malloc(majel::Place place, size_t size);

// Release memory previously returned by majel::malloc::malloc for the
// same place. Passing nullptr is a no-op.
void free(majel::Place place, void* ptr);

}  // namespace malloc
}  // namespace majel
9 changes: 1 addition & 8 deletions paddle/majel/place.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,14 @@ class PlacePrinter : public boost::static_visitor<> {

void operator()(const CpuPlace&) { os_ << "CpuPlace"; }

void operator()(const GpuPlace& p) { os_ << "GpuPlace(" << p.device << ")"; }
void operator()(const GpuPlace&) { os_ << "GpuPlace"; }
};

} // namespace detail

static Place the_default_place;

void set_place(const Place& place) { the_default_place = place; }

const Place& get_place() { return the_default_place; }

const GpuPlace default_gpu() { return GpuPlace(0); }

const CpuPlace default_cpu() { return CpuPlace(); }

bool is_gpu_place(const Place& p) {
return boost::apply_visitor(IsGpuPlace(), p);
}
Expand Down
23 changes: 9 additions & 14 deletions paddle/majel/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,28 @@ struct CpuPlace {
};

struct GpuPlace {
GpuPlace(int d) : device(d) {}
GpuPlace() {}

// needed for variant equality comparison
inline bool operator==(const GpuPlace& o) const { return device == o.device; }
inline bool operator==(const GpuPlace&) const { return true; }

inline bool operator!=(const GpuPlace& o) const { return !(*this == o); }

GpuPlace() : GpuPlace(0) {}
int device;
Copy link
Collaborator

@wangkuiyi wangkuiyi May 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it -- why we should make GpuPlace no longer distinguish GPUs? Is that because we want to use the CUDA context to determine the current GPU?

If so, I think what we need to do is not removing int device; from GpuPlace, but to redefine get_place to call cuCtxGetDevice?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One CPU thread is bound to one GPU card. Every CPU thread will set the GPU card first, using cudaSetDevice. There is no need for the tensor to hold a place recording which specific GPU card it is on.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am afraid that we cannot assume this. What if we are going to support OpenCL/FPGA in addition to CUDA? Would this assumption become a bug?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think communication between devices is costly, so one neural network will be run on one device. If we are going to support OpenCL/FPGA, we can just define Place as follows:

typedef boost::variant<CpuPlace, CudaPlace, OpenclPlace, FpgaPlace> Place;

Then, we implement methods, such as malloc/free of corresponding device.

Copy link
Collaborator

@wangkuiyi wangkuiyi May 24, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we can run a neural network on one device, but when we aggregate gradients/parameters from these devices, it seems that we need to copy data from/to exact places?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will create one Context for each device, and then set the device id on that specific Context. The device id will be handled in Paddle, not in the tensor library. We can certainly copy data between different GPU cards. Following is an example:

    cudaSetDevice(1);
    cudaDeviceEnablePeerAccess(2,flags); //flags=0
    cudaSetDevice(2);
    cudaDeviceEnablePeerAccess(1,flags); //flags=0

    // Allocate some data
    float *gpu1data, *gpu2data;
    cudaSetDevice(1);
    cudaMalloc(&gpu1data, nbytes);
    cudaSetDevice(2);
    cudaMalloc(&gpu2data, nbytes);
    
    // Do the p2p copy!
    cudaMemcpy(gpu1data, gpu2data, cudaMemcpyDefault);

The gpu data block does not hold device id information, but the Context does.

inline bool operator!=(const GpuPlace&) const { return false; }
};

class IsGpuPlace : public boost::static_visitor<bool> {
public:
bool operator()(const CpuPlace&) const { return false; }

bool operator()(const GpuPlace& gpu) const { return true; }
bool operator()(const GpuPlace&) const { return true; }
};

typedef boost::variant<GpuPlace, CpuPlace> Place;

void set_place(const Place&);
#ifndef PADDLE_ONLY_CPU
typedef boost::variant<CpuPlace, GpuPlace> Place;
#else
typedef boost::variant<CpuPlace> Place;
#endif
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果编译CPU版本的话,那么Place里面只能接受CpuPlace;这个时候给一个Array传递GpuPlace,就会在编译的时候报错。


const Place& get_place();

const GpuPlace default_gpu();
const CpuPlace default_cpu();

bool is_gpu_place(const Place&);
bool is_cpu_place(const Place&);
bool places_are_same_class(const Place&, const Place&);
Expand Down
4 changes: 4 additions & 0 deletions paddle/majel/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ cc_test(ddim_test
SRCS ddim_test.cc
DEPS ddim)

cc_test(allocation_test
SRCS allocation_test.cc
DEPS allocation malloc place)

if(WITH_GPU)
nv_test(cuda_test SRCS cuda_test.cu)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
Expand Down
Loading