@@ -45,7 +45,7 @@ int32_t GraphTable::Load_to_ssd(const std::string &path,
4545}
4646
4747paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea (
48- int ntype_id, std::vector<uint64_t > &node_ids, int slot_num) {
48+ std::vector<uint64_t > &node_ids, int slot_num) {
4949 std::vector<std::vector<uint64_t >> bags (task_pool_size_);
5050 for (auto x : node_ids) {
5151 int location = x % shard_num % task_pool_size_;
@@ -63,7 +63,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
6363 std::vector<uint64_t > feature_ids;
6464 for (size_t j = 0 ; j < bags[i].size (); j++) {
6565 // TODO use FEATURE_TABLE instead
66- Node *v = find_node (1 , ntype_id, bags[i][j]);
66+ Node *v = find_node (1 , bags[i][j]);
6767 x.node_id = bags[i][j];
6868 if (v == NULL ) {
6969 x.feature_size = 0 ;
@@ -85,10 +85,6 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
8585 }
8686 x.feature_size = total_feature_size;
8787 node_fea_array[i].push_back (x);
88- VLOG (2 ) << " node_fea_array[i].size() = ["
89- << node_fea_array[i].size () << " ]" ;
90- VLOG (2 ) << " feature_array[i].size() = [" << feature_array[i].size ()
91- << " ]" ;
9288 }
9389 }
9490 return 0 ;
@@ -102,8 +98,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
10298 tot_len += feature_array[i].size ();
10399 }
104100 VLOG (0 ) << " Loaded feature table on cpu, feature_list_size[" << tot_len
105- << " ] node_ids_size[" << node_ids.size () << " ] ntype_id[" << ntype_id
106- << " ]" ;
101+ << " ] node_ids_size[" << node_ids.size () << " ]" ;
107102 res.init_on_cpu (tot_len, (unsigned int )node_ids.size (), slot_num);
108103 unsigned int offset = 0 , ind = 0 ;
109104 for (int i = 0 ; i < task_pool_size_; i++) {
@@ -1240,6 +1235,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,
12401235
12411236 return 0 ;
12421237}
1238+
1239+ Node *GraphTable::find_node (int type_id, uint64_t id) {
1240+ size_t shard_id = id % shard_num;
1241+ if (shard_id >= shard_end || shard_id < shard_start) {
1242+ return nullptr ;
1243+ }
1244+ Node *node = nullptr ;
1245+ size_t index = shard_id - shard_start;
1246+ auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
1247+ for (auto & search_shard: search_shards) {
1248+ PADDLE_ENFORCE_NOT_NULL (search_shard[index]);
1249+ node = search_shard[index]->find_node (id);
1250+ if (node != nullptr ) {
1251+ break ;
1252+ }
1253+ }
1254+ return node;
1255+ }
12431256
12441257Node *GraphTable::find_node (int type_id, int idx, uint64_t id) {
12451258 size_t shard_id = id % shard_num;
@@ -1537,6 +1550,30 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
15371550 return std::make_pair<int32_t , std::string>(-1 , " " );
15381551}
15391552
1553+ std::vector<std::vector<uint64_t >> GraphTable::get_all_id (int type_id, int slice_num) {
1554+ std::vector<std::vector<uint64_t >> res (slice_num);
1555+ auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
1556+ std::vector<std::future<std::vector<uint64_t >>> tasks;
1557+ for (int idx = 0 ; idx < search_shards.size (); idx++) {
1558+ for (int j = 0 ; j < search_shards[idx].size (); j++) {
1559+ tasks.push_back (_shards_task_pool[j % task_pool_size_]->enqueue (
1560+ [&search_shards, idx, j]() -> std::vector<uint64_t > {
1561+ return search_shards[idx][j]->get_all_id ();
1562+ }));
1563+ }
1564+ }
1565+ for (size_t i = 0 ; i < tasks.size (); ++i) {
1566+ tasks[i].wait ();
1567+ }
1568+ for (size_t i = 0 ; i < tasks.size (); i++) {
1569+ auto ids = tasks[i].get ();
1570+ for (auto &id : ids) {
1571+ res[(uint64_t )(id) % slice_num].push_back (id);
1572+ }
1573+ }
1574+ return res;
1575+ }
1576+
15401577std::vector<std::vector<uint64_t >> GraphTable::get_all_id (int type_id, int idx,
15411578 int slice_num) {
15421579 std::vector<std::vector<uint64_t >> res (slice_num);
@@ -1559,6 +1596,28 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
15591596 }
15601597 return res;
15611598}
1599+
1600+ std::vector<std::vector<uint64_t >> GraphTable::get_all_feature_ids (int type_id, int idx,
1601+ int slice_num) {
1602+ std::vector<std::vector<uint64_t >> res (slice_num);
1603+ auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
1604+ std::vector<std::future<std::vector<uint64_t >>> tasks;
1605+ for (int i = 0 ; i < search_shards.size (); i++) {
1606+ tasks.push_back (_shards_task_pool[i % task_pool_size_]->enqueue (
1607+ [&search_shards, i]() -> std::vector<uint64_t > {
1608+ return search_shards[i]->get_all_feature_ids ();
1609+ }));
1610+ }
1611+ for (size_t i = 0 ; i < tasks.size (); ++i) {
1612+ tasks[i].wait ();
1613+ }
1614+ for (size_t i = 0 ; i < tasks.size (); i++) {
1615+ auto ids = tasks[i].get ();
1616+ for (auto &id : ids) res[id % slice_num].push_back (id);
1617+ }
1618+ return res;
1619+ }
1620+
15621621int32_t GraphTable::pull_graph_list (int type_id, int idx, int start,
15631622 int total_size,
15641623 std::unique_ptr<char []> &buffer,
0 commit comments