Skip to content

Commit 1816fc2

Browse files
huwei02 and root authored
search and fill slot_feature (#20)
* search and fill slot_feature
* search and fill slot_feature, fix compile error
* search and fill slot_feature, rename 8 as slot_num_

Co-authored-by: root <[email protected]>
1 parent 750e343 commit 1816fc2

File tree

12 files changed

+520
-31
lines changed

12 files changed

+520
-31
lines changed

paddle/fluid/distributed/ps/table/common_graph_table.cc

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ int32_t GraphTable::Load_to_ssd(const std::string &path,
4545
}
4646

4747
paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
48-
int ntype_id, std::vector<uint64_t> &node_ids, int slot_num) {
48+
std::vector<uint64_t> &node_ids, int slot_num) {
4949
std::vector<std::vector<uint64_t>> bags(task_pool_size_);
5050
for (auto x : node_ids) {
5151
int location = x % shard_num % task_pool_size_;
@@ -63,7 +63,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
6363
std::vector<uint64_t> feature_ids;
6464
for (size_t j = 0; j < bags[i].size(); j++) {
6565
// TODO use FEATURE_TABLE instead
66-
Node *v = find_node(1, ntype_id, bags[i][j]);
66+
Node *v = find_node(1, bags[i][j]);
6767
x.node_id = bags[i][j];
6868
if (v == NULL) {
6969
x.feature_size = 0;
@@ -85,10 +85,6 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
8585
}
8686
x.feature_size = total_feature_size;
8787
node_fea_array[i].push_back(x);
88-
VLOG(2) << "node_fea_array[i].size() = ["
89-
<< node_fea_array[i].size() << "]";
90-
VLOG(2) << "feature_array[i].size() = [" << feature_array[i].size()
91-
<< "]";
9288
}
9389
}
9490
return 0;
@@ -102,8 +98,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
10298
tot_len += feature_array[i].size();
10399
}
104100
VLOG(0) << "Loaded feature table on cpu, feature_list_size[" << tot_len
105-
<< "] node_ids_size[" << node_ids.size() << "] ntype_id[" << ntype_id
106-
<< "]";
101+
<< "] node_ids_size[" << node_ids.size() << "]";
107102
res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num);
108103
unsigned int offset = 0, ind = 0;
109104
for (int i = 0; i < task_pool_size_; i++) {
@@ -1240,6 +1235,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,
12401235

12411236
return 0;
12421237
}
1238+
1239+
// Looks up the node with the given id in every shard group of the selected
// kind, without requiring the caller to know which group holds it.
//
// type_id: 0 searches edge_shards; any other value searches feature_shards.
// id:      node id; its shard is `id % shard_num`.
//
// Returns the first match found, or nullptr when the id's shard lies outside
// [shard_start, shard_end) or no shard group contains the id.
Node *GraphTable::find_node(int type_id, uint64_t id) {
  const size_t shard_id = id % shard_num;
  if (shard_id < shard_start || shard_id >= shard_end) {
    return nullptr;
  }
  // Position of the shard inside each group's local shard vector.
  const size_t index = shard_id - shard_start;
  auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
  for (auto &search_shard : search_shards) {
    PADDLE_ENFORCE_NOT_NULL(search_shard[index]);
    Node *found = search_shard[index]->find_node(id);
    if (found != nullptr) {
      return found;
    }
  }
  return nullptr;
}
12431256

12441257
Node *GraphTable::find_node(int type_id, int idx, uint64_t id) {
12451258
size_t shard_id = id % shard_num;
@@ -1537,6 +1550,30 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
15371550
return std::make_pair<int32_t, std::string>(-1, "");
15381551
}
15391552

1553+
std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int slice_num) {
1554+
std::vector<std::vector<uint64_t>> res(slice_num);
1555+
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
1556+
std::vector<std::future<std::vector<uint64_t>>> tasks;
1557+
for (int idx = 0; idx < search_shards.size(); idx++) {
1558+
for (int j = 0; j < search_shards[idx].size(); j++) {
1559+
tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
1560+
[&search_shards, idx, j]() -> std::vector<uint64_t> {
1561+
return search_shards[idx][j]->get_all_id();
1562+
}));
1563+
}
1564+
}
1565+
for (size_t i = 0; i < tasks.size(); ++i) {
1566+
tasks[i].wait();
1567+
}
1568+
for (size_t i = 0; i < tasks.size(); i++) {
1569+
auto ids = tasks[i].get();
1570+
for (auto &id : ids) {
1571+
res[(uint64_t)(id) % slice_num].push_back(id);
1572+
}
1573+
}
1574+
return res;
1575+
}
1576+
15401577
std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
15411578
int slice_num) {
15421579
std::vector<std::vector<uint64_t>> res(slice_num);
@@ -1559,6 +1596,28 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
15591596
}
15601597
return res;
15611598
}
1599+
1600+
std::vector<std::vector<uint64_t>> GraphTable::get_all_feature_ids(int type_id, int idx,
1601+
int slice_num) {
1602+
std::vector<std::vector<uint64_t>> res(slice_num);
1603+
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
1604+
std::vector<std::future<std::vector<uint64_t>>> tasks;
1605+
for (int i = 0; i < search_shards.size(); i++) {
1606+
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
1607+
[&search_shards, i]() -> std::vector<uint64_t> {
1608+
return search_shards[i]->get_all_feature_ids();
1609+
}));
1610+
}
1611+
for (size_t i = 0; i < tasks.size(); ++i) {
1612+
tasks[i].wait();
1613+
}
1614+
for (size_t i = 0; i < tasks.size(); i++) {
1615+
auto ids = tasks[i].get();
1616+
for (auto &id : ids) res[id % slice_num].push_back(id);
1617+
}
1618+
return res;
1619+
}
1620+
15621621
int32_t GraphTable::pull_graph_list(int type_id, int idx, int start,
15631622
int total_size,
15641623
std::unique_ptr<char[]> &buffer,

paddle/fluid/distributed/ps/table/common_graph_table.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ class GraphShard {
7070
}
7171
return res;
7272
}
73+
// Gathers the feature ids of every node held in this shard's `bucket`.
// Ids are concatenated in bucket order and are not deduplicated.
std::vector<uint64_t> get_all_feature_ids() {
  // TODO by huwei02, dedup
  std::vector<uint64_t> total_res;
  std::vector<uint64_t> node_ids;  // scratch buffer reused for each node
  for (size_t i = 0; i < bucket.size(); i++) {
    // Not every get_feature_ids() implementation clears its output, so reset
    // the buffer here to avoid carrying over the previous node's ids.
    node_ids.clear();
    // The int return value is a status code, not a feature id; it must not
    // be appended to the result.
    bucket[i]->get_feature_ids(&node_ids);
    total_res.insert(total_res.end(), node_ids.begin(), node_ids.end());
  }
  return total_res;
}
7383
GraphNode *add_graph_node(uint64_t id);
7484
GraphNode *add_graph_node(Node *node);
7585
FeatureNode *add_feature_node(uint64_t id);
@@ -475,8 +485,11 @@ class GraphTable : public Table {
475485
int32_t load_edges(const std::string &path, bool reverse,
476486
const std::string &edge_type);
477487

488+
std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
478489
std::vector<std::vector<uint64_t>> get_all_id(int type, int idx,
479490
int slice_num);
491+
std::vector<std::vector<uint64_t>> get_all_feature_ids(int type, int idx,
492+
int slice_num);
480493
int32_t load_nodes(const std::string &path, std::string node_type);
481494

482495
int32_t add_graph_node(int idx, std::vector<uint64_t> &id_list,
@@ -486,6 +499,7 @@ class GraphTable : public Table {
486499

487500
int32_t get_server_index_by_id(uint64_t id);
488501
Node *find_node(int type_id, int idx, uint64_t id);
502+
Node *find_node(int type_id, uint64_t id);
489503

490504
virtual int32_t Pull(TableContext &context) { return 0; }
491505
virtual int32_t Push(TableContext &context) { return 0; }
@@ -561,7 +575,7 @@ class GraphTable : public Table {
561575
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
562576
int idx, std::vector<uint64_t> ids);
563577
virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
564-
int ntype_id, std::vector<uint64_t> &node_ids, int slot_num);
578+
std::vector<uint64_t> &node_ids, int slot_num);
565579
int32_t Load_to_ssd(const std::string &path, const std::string &param);
566580
int64_t load_graph_to_memory_from_ssd(int idx, std::vector<uint64_t> &ids);
567581
int32_t make_complementary_graph(int idx, int64_t byte_size);

paddle/fluid/distributed/ps/table/graph/graph_node.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class Node {
5050
virtual void to_buffer(char *buffer, bool need_feature);
5151
virtual void recover_from_buffer(char *buffer);
5252
virtual std::string get_feature(int idx) { return std::string(""); }
53+
// Default for nodes without feature ids: leaves `res` untouched and returns 0.
virtual int get_feature_ids(std::vector<uint64_t> *res) const { return 0; }
5356
virtual int get_feature_ids(int slot_idx, std::vector<uint64_t> *res) const {
5457
return 0;
5558
}
@@ -102,6 +105,25 @@ class FeatureNode : public Node {
102105
}
103106
}
104107

108+
virtual int get_feature_ids(std::vector<uint64_t> *res) const {
109+
PADDLE_ENFORCE_NOT_NULL(res);
110+
res->clear();
111+
errno = 0;
112+
for (auto& feature_item: feature) {
113+
const char *feat_str = feature_item.c_str();
114+
auto fields = paddle::string::split_string<std::string>(feat_str, " ");
115+
char *head_ptr = NULL;
116+
for (auto &field : fields) {
117+
PADDLE_ENFORCE_EQ(field.empty(), false);
118+
uint64_t feasign = strtoull(field.c_str(), &head_ptr, 10);
119+
PADDLE_ENFORCE_EQ(field.c_str() + field.length(), head_ptr);
120+
res->push_back(feasign);
121+
}
122+
}
123+
PADDLE_ENFORCE_EQ(errno, 0);
124+
return 0;
125+
}
126+
105127
virtual int get_feature_ids(int slot_idx, std::vector<uint64_t> *res) const {
106128
PADDLE_ENFORCE_NOT_NULL(res);
107129
res->clear();

0 commit comments

Comments
 (0)