Skip to content
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 71 additions & 52 deletions paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,29 @@ namespace paddle {
namespace operators {
namespace reader {

static constexpr size_t kDoubleBufferSize = 2;
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 2;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 0; // kCacheSize - 2

class DoubleBufferReader : public framework::DecoratedReader {
public:
struct Item {
Item() : ctx_(nullptr) {}
Item(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
}
Item& operator=(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
return *this;
}

std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
Expand All @@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
for (size_t i = 0; i < kCacheSize; ++i) {
if (platform::is_gpu_place(place_)) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}

start_thread();
}

void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
#endif
StartPrefetcher();
}

bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;

~DoubleBufferReader() {
buffer_->Close();
prefetcher_.join();
delete buffer_;
void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}

bool HasNext() const override;
void EndPrefetcher() {
channel_->Close();
if (prefetcher_.joinable()) {
prefetcher_.join();
}
delete channel_;
channel_ = nullptr;
}

~DoubleBufferReader() { EndPrefetcher(); }

private:
void PrefetchThreadFunc();

std::thread prefetcher_;
framework::Channel<Item>* buffer_;
framework::Channel<Item>* channel_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
};

class CreateDoubleBufferReaderOp : public framework::OperatorBase {
Expand Down Expand Up @@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};

bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it will be better to use semaphore here. Or add TODO.

}
return channel_->CanReceive();
}

void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}

if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
}
}

void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
prefetcher_.join();
delete buffer_;
start_thread();
EndPrefetcher();
StartPrefetcher();
}

void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t gpu_ctx_offset = 0;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;

while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
gpu_batch[i].set_lod(cpu_batch[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
batch.payloads_ = gpu_batch;
batch.ctx_ = gpu_ctx;
} else {
// CPUPlace
batch.payloads_ = cpu_batch;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;

try {
buffer_->Send(&batch);
channel_->Send(&batch);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
}
buffer_->Close();
channel_->Close();
VLOG(5) << "Prefetch thread terminates.";
}

bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}

} // namespace reader
} // namespace operators
} // namespace paddle
Expand Down