Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,19 @@ class ChannelObject {
p.resize(finished);
return finished;
}
// read once only
size_t ReadOnce(std::vector<T>& p, size_t size) { // NOLINT
if (size == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
p.resize(size);
size_t finished = Read(size, &p[0], lock, true);
p.resize(finished);
Notify();

return finished;
}
size_t ReadAll(std::vector<T>& p) { // NOLINT
p.clear();
size_t finished = 0;
Expand Down Expand Up @@ -241,17 +253,21 @@ class ChannelObject {
return !closed_;
}

size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock) { // NOLINT
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock, // NOLINT
bool once = false) { // NOLINT
size_t finished = 0;
CHECK(n <= MaxCapacity() - reading_count_);
reading_count_ += n;
while (finished < n && WaitForRead(lock)) {
size_t m = std::min(n - finished, data_.size());
size_t m = (std::min)(n - finished, data_.size());
for (size_t i = 0; i < m; i++) {
p[finished++] = std::move(data_.front());
data_.pop_front();
}
reading_count_ -= m;
if (once && m > 0) {
break;
}
}
reading_count_ -= n - finished;
return finished;
Expand Down
112 changes: 103 additions & 9 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() {
return manager;
}

class BufferedLineFileReader {
typedef std::function<bool()> SampleFunc;
static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024;
class FILEReader {
public:
explicit FILEReader(FILE* fp) : fp_(fp) {}
int read(char* buf, int len) { return fread(buf, sizeof(char), len, fp_); }

private:
FILE* fp_;
};

public:
typedef std::function<bool(const std::string&)> LineFunc;

private:
template <typename T>
int read_lines(T* reader, LineFunc func, int skip_lines) {
int lines = 0;
size_t ret = 0;
char* ptr = NULL;
char* eol = NULL;
total_len_ = 0;
error_line_ = 0;

SampleFunc spfunc = get_sample_func();
std::string x;
while (!is_error() && (ret = reader->read(buff_, MAX_FILE_BUFF_SIZE)) > 0) {
total_len_ += ret;
ptr = buff_;
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
while (eol != NULL) {
int size = static_cast<int>((eol - ptr) + 1);
x.append(ptr, size - 1);
++lines;
if (lines > skip_lines && spfunc()) {
if (!func(x)) {
++error_line_;
}
}

x.clear();
ptr += size;
ret -= size;
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
}
if (ret > 0) {
x.append(ptr, ret);
}
}
if (!is_error() && !x.empty()) {
++lines;
if (lines > skip_lines && spfunc()) {
if (!func(x)) {
++error_line_;
}
}
}
return lines;
}

public:
BufferedLineFileReader()
: random_engine_(std::random_device()()),
uniform_distribution_(0.0f, 1.0f) {
total_len_ = 0;
sample_line_ = 0;
buff_ =
reinterpret_cast<char*>(calloc(MAX_FILE_BUFF_SIZE + 1, sizeof(char)));
}
~BufferedLineFileReader() { free(buff_); }

int read_file(FILE* fp, LineFunc func, int skip_lines) {
FILEReader reader(fp);
return read_lines<FILEReader>(&reader, func, skip_lines);
}
uint64_t file_size(void) { return total_len_; }
void set_sample_rate(float r) { sample_rate_ = r; }
size_t get_sample_line() { return sample_line_; }
bool is_error(void) { return (error_line_ > 10); }

private:
SampleFunc get_sample_func() {
if (std::abs(sample_rate_ - 1.0f) < 1e-5f) {
return [this](void) { return true; };
}
return [this](void) {
return (uniform_distribution_(random_engine_) < sample_rate_);
};
}

private:
char* buff_ = nullptr;
uint64_t total_len_ = 0;

std::default_random_engine random_engine_;
std::uniform_real_distribution<float> uniform_distribution_;
float sample_rate_ = 1.0f;
size_t sample_line_ = 0;
size_t error_line_ = 0;
};
void RecordCandidateList::ReSize(size_t length) {
mutex_.lock();
capacity_ = length;
Expand Down Expand Up @@ -301,7 +402,7 @@ int InMemoryDataFeed<T>::Next() {
<< ", thread_id=" << thread_id_;
}
} else {
VLOG(3) << "enable heter NEXT: " << offset_index_
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
if (offset_index_ >= batch_offsets_.size()) {
VLOG(3) << "offset_index: " << offset_index_
Expand All @@ -318,14 +419,7 @@ int InMemoryDataFeed<T>::Next() {
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
<< thread_id_;
}
/*
if (offset_index_ == batch_offsets_.size() - 1) {
std::vector<Record> data;
output_channel_->ReadAll(data);
consume_channel_->Write(std::move(data));
}
*/
VLOG(3) << "#15 enable heter NEXT: " << offset_index_
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size()
<< " baych_size: " << this->batch_size_;
}
Expand Down
Loading