Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ USE_INT_STAT(STAT_total_feasign_num_in_mem);
namespace paddle {
namespace framework {

// Accessor for the single DLManager used by this translation unit's data
// feeds. The instance is a function-local static, so it is constructed the
// first time this function is called.
DLManager& global_dlmanager_pool() {
  static DLManager dl_manager_instance;
  return dl_manager_instance;
}

void RecordCandidateList::ReSize(size_t length) {
mutex_.lock();
capacity_ = length;
Expand Down Expand Up @@ -366,6 +371,10 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
if (!so_parser_name_.empty()) {
LoadIntoMemoryFromSo();
return;
}
VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
std::string filename;
while (this->PickOneFile(&filename)) {
Expand Down Expand Up @@ -408,6 +417,51 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
#endif
}

// Reads every file handed out by PickOneFile() line by line, parses each line
// with the CustomParser obtained from the shared library named by
// so_parser_name_, and writes the resulting instances into input_channel_.
// The body is only compiled when _LINUX is defined; otherwise this is a no-op.
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemoryFromSo() {
#ifdef _LINUX
  VLOG(3) << "LoadIntoMemoryFromSo() begin, thread_id=" << thread_id_;

  string::LineFileReader reader;
  paddle::framework::CustomParser* parser =
      global_dlmanager_pool().Load(so_parser_name_, slot_conf_);
  // Load() returns nullptr when the library cannot be opened. Fail fast here
  // instead of dereferencing a null parser once per input line below.
  CHECK(parser != nullptr) << "Failed to load so parser: " << so_parser_name_;

  std::string filename;
  while (this->PickOneFile(&filename)) {
    VLOG(3) << "PickOneFile, filename=" << filename
            << ", thread_id=" << thread_id_;
    int err_no = 0;
    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
    CHECK(this->fp_ != nullptr);
    // Switch the stream to caller-managed locking.
    __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);

    paddle::framework::ChannelWriter<T> writer(input_channel_);
    T instance;
    platform::Timer timeline;
    timeline.Start();

    // One parsed instance per input line; stop when getline() reports that
    // no further line could be read.
    while (reader.getline(&*(fp_.get()))) {
      const char* str = reader.get();
      ParseOneInstanceFromSo(str, &instance, parser);

      writer << std::move(instance);
      // Reset the moved-from object to a fresh T before reusing it.
      instance = T();
    }

    writer.Flush();
    timeline.Pause();
    VLOG(3) << "LoadIntoMemoryFromSo() read all lines, file=" << filename
            << ", cost time=" << timeline.ElapsedSec()
            << " seconds, thread_id=" << thread_id_;
  }
  VLOG(3) << "LoadIntoMemoryFromSo() end, thread_id=" << thread_id_;
#endif
}

// explicit instantiation
template class InMemoryDataFeed<Record>;

Expand Down Expand Up @@ -827,16 +881,23 @@ void MultiSlotInMemoryDataFeed::Init(
inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
slot_conf_.resize(all_slot_num);
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;

slot_conf_[i].name = slot.name();
slot_conf_[i].type = slot.type();
slot_conf_[i].use_slots_index = use_slots_index_[i];

total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
slot_conf_[i].use_slots_is_dense = slot.is_dense();
std::vector<int> local_shape;
if (slot.is_dense()) {
for (int j = 0; j < slot.shape_size(); ++j) {
Expand Down Expand Up @@ -869,6 +930,7 @@ void MultiSlotInMemoryDataFeed::Init(
}
visit_.resize(all_slot_num, false);
pipe_command_ = data_feed_desc.pipe_command();
so_parser_name_ = data_feed_desc.so_parser_name();
finish_init_ = true;
input_type_ = data_feed_desc.input_type();
}
Expand All @@ -887,6 +949,12 @@ void MultiSlotInMemoryDataFeed::GetMsgFromLogKey(const std::string& log_key,
*rank = (uint32_t)strtoul(rank_str.c_str(), NULL, 16);
}

void MultiSlotInMemoryDataFeed::ParseOneInstanceFromSo(const char* str,
Record* instance,
CustomParser* parser) {
parser->ParseOneInstance(str, instance);
}

bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
Expand Down
95 changes: 95 additions & 0 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,94 @@ using PvInstance = PvInstanceObject*;

// Heap-allocates a new PvInstanceObject with `new` and returns the pointer.
// Nothing in this function releases it; the caller receives the raw pointer.
inline PvInstance make_pv_instance() {
  PvInstance pv = new PvInstanceObject();
  return pv;
}

// Per-slot configuration that is handed to CustomParser::Init().
// The integer members carry default initializers so that a default-constructed
// SlotConf never holds indeterminate values.
struct SlotConf {
  // Slot name.
  std::string name;
  // Slot type string.
  std::string type;
  // Index of this slot among the used slots; the feed's Init() stores -1 here
  // for slots that are not used.
  int use_slots_index = 0;
  // Non-zero when the slot is dense; only assigned for used slots.
  int use_slots_is_dense = 0;
};

// Abstract interface for a line parser. DLManager obtains implementations
// from a shared library through its "CreateParserObject" entry point.
class CustomParser {
 public:
  CustomParser() = default;
  virtual ~CustomParser() = default;

  // Receives the slot configuration before any line is parsed.
  virtual void Init(const std::vector<SlotConf>& slots) = 0;

  // Parses the single input line `str`; `instance` is the Record the
  // implementation is given to work on.
  virtual void ParseOneInstance(const char* str, Record* instance) = 0;
};

// Signature of the factory function a parser shared library exports: takes no
// arguments and returns a CustomParser pointer.
using CreateParserObjectFunc = paddle::framework::CustomParser* (*)();

// Loads custom-parser shared libraries and caches one CustomParser per
// library name. All access to the cache is serialized by mutex_.
//
// A CustomParser* returned by Load()/ReLoad() is owned by this manager and
// becomes invalid once Close()/ReLoad() is called for the same name or the
// manager is destroyed.
//
// The dynamic-loading code is only compiled when _LINUX is defined; otherwise
// Load() returns nullptr and Close() returns false.
class DLManager {
  // One loaded library: the dlopen() handle and the parser created from it.
  struct DLHandle {
    void* module = nullptr;
    paddle::framework::CustomParser* parser = nullptr;
  };

 public:
  DLManager() {}

  // Destroys every cached parser and unloads its library.
  ~DLManager() {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto& entry : handle_map_) {
      // Destroy the parser before dlclose(): its code lives in the module.
      delete entry.second.parser;
      dlclose(entry.second.module);
    }
    handle_map_.clear();
#endif
  }

  // Destroys the parser cached under `name` and unloads its library.
  // Returns true when nothing is left loaded under `name` (including the case
  // where it was never loaded); returns false where loading is unsupported.
  bool Close(const std::string& name) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it == handle_map_.end()) {
      return true;
    }
    // Destroy the parser before dlclose(): its code lives in the module.
    delete it->second.parser;
    dlclose(it->second.module);
    // Drop the entry so that a later Load() cannot hand out the deleted
    // parser and the destructor cannot free it a second time.
    handle_map_.erase(it);
    return true;
#else
    VLOG(0) << "Not implement in windows";
    return false;
#endif
  }

  // Returns the parser for the shared library `name`, loading the library and
  // calling Init(conf) on a freshly created parser the first time. Later
  // calls with the same name return the cached parser and ignore `conf`.
  // Returns nullptr if the library cannot be opened, does not export
  // "CreateParserObject", or the factory returns null.
  paddle::framework::CustomParser* Load(const std::string& name,
                                        std::vector<SlotConf>& conf) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail";
      return nullptr;
    }

    CreateParserObjectFunc create_parser_func =
        reinterpret_cast<CreateParserObjectFunc>(
            dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      VLOG(0) << "Find symbol CreateParserObject in " << name << " fail";
      // Do not leak the module handle when the entry point is missing.
      dlclose(handle.module);
      return nullptr;
    }

    handle.parser = create_parser_func();
    if (handle.parser == nullptr) {
      VLOG(0) << "CreateParserObject of " << name << " return null";
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});

    return handle.parser;
#else
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  // Closes `name` if it is loaded, then loads it again with `conf`.
  // The two steps take the lock separately.
  paddle::framework::CustomParser* ReLoad(const std::string& name,
                                          std::vector<SlotConf>& conf) {
    Close(name);
    return Load(name, conf);
  }

 private:
  std::mutex mutex_;                            // guards handle_map_
  std::map<std::string, DLHandle> handle_map_;  // library name -> handle
};

class DataFeed {
public:
DataFeed() {
Expand Down Expand Up @@ -252,6 +340,8 @@ class DataFeed {
bool finish_set_filelist_;
bool finish_start_;
std::string pipe_command_;
std::string so_parser_name_;
std::vector<SlotConf> slot_conf_;
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
Expand Down Expand Up @@ -324,10 +414,13 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetCurrentPhase(int current_phase);
virtual void LoadIntoMemory();
virtual void LoadIntoMemoryFromSo();

protected:
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void ParseOneInstanceFromSo(const char* str, T* instance,
CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;

int thread_id_;
Expand Down Expand Up @@ -688,6 +781,8 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
protected:
virtual bool ParseOneInstance(Record* instance);
virtual bool ParseOneInstanceFromPipe(Record* instance);
virtual void ParseOneInstanceFromSo(const char* str, Record* instance,
CustomParser* parser);
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/data_feed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ message DataFeedDesc {
optional string rank_offset = 6;
optional int32 pv_batch_size = 7 [ default = 32 ];
optional int32 input_type = 8 [ default = 0 ];
optional string so_parser_name = 9;
}
17 changes: 17 additions & 0 deletions python/paddle/fluid/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ def set_pipe_command(self, pipe_command):
"""
self.proto_desc.pipe_command = pipe_command

def set_so_parser_name(self, so_parser_name):
    """
    Set so parser name of current dataset

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset()
          dataset.set_so_parser_name("./abc.so")

    Args:
        so_parser_name(str): name (path) of the parser shared library,
            stored in the dataset's ``so_parser_name`` proto field

    """
    self.proto_desc.so_parser_name = so_parser_name

def set_rank_offset(self, rank_offset):
"""
Set rank_offset for merge_pv. It set the message of Pv.
Expand Down