Skip to content

Commit 3e68780

Browse files
authored
Merge pull request #9 from Yelrose/develop
Add Table Name; Feature Info
2 parents 5f6c168 + 7d51520 commit 3e68780

File tree

5 files changed

+148
-5
lines changed

5 files changed

+148
-5
lines changed

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,25 @@ std::vector<std::string> GraphPyService::split(std::string& str,
3232
return res;
3333
}
3434

35+
36+
void GraphPyService::add_table_feat_conf(std::string table_name,
37+
std::string feat_name,
38+
std::string feat_dtype,
39+
int32_t feat_shape) {
40+
if(this->table_id_map.count(table_name)) {
41+
this->table_feat_conf_table_name.push_back(table_name);
42+
this->table_feat_conf_feat_name.push_back(feat_name);
43+
this->table_feat_conf_feat_dtype.push_back(feat_dtype);
44+
this->table_feat_conf_feat_shape.push_back(feat_shape);
45+
}
46+
}
47+
48+
3549
void GraphPyService::set_up(std::string ips_str, int shard_num,
3650
std::vector<std::string> node_types,
3751
std::vector<std::string> edge_types) {
3852
set_shard_num(shard_num);
53+
set_num_node_types(node_types.size());
3954
// set_client_Id(client_id);
4055
// set_rank(rank);
4156

@@ -121,7 +136,31 @@ ::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
121136
VLOG(0) << " make a new table " << tuple.second;
122137
::paddle::distributed::TableParameter* sparse_table_proto =
123138
downpour_server_proto->add_downpour_table_param();
124-
GetDownpourSparseTableProto(sparse_table_proto, tuple.second);
139+
std::vector<std::string > feat_name;
140+
std::vector<std::string > feat_dtype;
141+
std::vector<int32_t> feat_shape;
142+
for(size_t i=0; i<this->table_feat_conf_table_name.size(); i++) {
143+
if(tuple.first == table_feat_conf_table_name[i]) {
144+
feat_name.push_back(table_feat_conf_feat_name[i]);
145+
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
146+
feat_shape.push_back(table_feat_conf_feat_shape[i]);
147+
}
148+
}
149+
std::string table_type;
150+
if(tuple.second < this->num_node_types) {
151+
table_type = "node";
152+
}
153+
else {
154+
table_type = "edge";
155+
}
156+
157+
GetDownpourSparseTableProto(sparse_table_proto,
158+
tuple.second,
159+
tuple.first,
160+
table_type,
161+
feat_name,
162+
feat_dtype,
163+
feat_shape);
125164
}
126165

127166
return server_fleet_desc;
@@ -137,11 +176,38 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
137176
worker_proto->mutable_downpour_worker_param();
138177

139178
for (auto& tuple : this->table_id_map) {
179+
VLOG(0) << " make a new table " << tuple.second;
140180
::paddle::distributed::TableParameter* worker_sparse_table_proto =
141181
downpour_worker_proto->add_downpour_table_param();
142-
GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second);
182+
std::vector<std::string > feat_name;
183+
std::vector<std::string > feat_dtype;
184+
std::vector<int32_t> feat_shape;
185+
for(size_t i=0; i<this->table_feat_conf_table_name.size(); i++) {
186+
if(tuple.first == table_feat_conf_table_name[i]) {
187+
feat_name.push_back(table_feat_conf_feat_name[i]);
188+
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
189+
feat_shape.push_back(table_feat_conf_feat_shape[i]);
190+
}
191+
}
192+
std::string table_type;
193+
if(tuple.second < this->num_node_types) {
194+
table_type = "node";
195+
}
196+
else {
197+
table_type = "edge";
198+
}
199+
200+
GetDownpourSparseTableProto(worker_sparse_table_proto,
201+
tuple.second,
202+
tuple.first,
203+
table_type,
204+
feat_name,
205+
feat_dtype,
206+
feat_shape);
143207
}
144208

209+
210+
145211
::paddle::distributed::ServerParameter* server_proto =
146212
worker_fleet_desc.mutable_server_param();
147213
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
@@ -155,11 +221,38 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
155221
server_service_proto->set_server_thread_num(12);
156222

157223
for (auto& tuple : this->table_id_map) {
224+
VLOG(0) << " make a new table " << tuple.second;
158225
::paddle::distributed::TableParameter* sparse_table_proto =
159226
downpour_server_proto->add_downpour_table_param();
160-
GetDownpourSparseTableProto(sparse_table_proto, tuple.second);
227+
std::vector<std::string > feat_name;
228+
std::vector<std::string > feat_dtype;
229+
std::vector<int32_t> feat_shape;
230+
for(size_t i=0; i<this->table_feat_conf_table_name.size(); i++) {
231+
if(tuple.first == table_feat_conf_table_name[i]) {
232+
feat_name.push_back(table_feat_conf_feat_name[i]);
233+
feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
234+
feat_shape.push_back(table_feat_conf_feat_shape[i]);
235+
}
236+
}
237+
std::string table_type;
238+
if(tuple.second < this->num_node_types) {
239+
table_type = "node";
240+
}
241+
else {
242+
table_type = "edge";
243+
}
244+
245+
GetDownpourSparseTableProto(sparse_table_proto,
246+
tuple.second,
247+
tuple.first,
248+
table_type,
249+
feat_name,
250+
feat_dtype,
251+
feat_shape);
161252
}
162253

254+
255+
163256
return worker_fleet_desc;
164257
}
165258
void GraphPyClient::load_edge_file(std::string name, std::string filepath,
@@ -232,3 +325,4 @@ std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
232325
}
233326
}
234327
}
328+

paddle/fluid/distributed/service/graph_py_service.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,13 @@ class GraphPyService {
4747
protected:
4848
std::vector<std::string> server_list, port_list, host_sign_list;
4949
int server_size, shard_num;
50+
int num_node_types;
5051
std::unordered_map<std::string, uint32_t> table_id_map;
52+
std::vector<std::string> table_feat_conf_table_name;
53+
std::vector<std::string> table_feat_conf_feat_name;
54+
std::vector<std::string> table_feat_conf_feat_dtype;
55+
std::vector<int32_t> table_feat_conf_feat_shape;
56+
5157
// std::thread *server_thread, *client_thread;
5258

5359
// std::shared_ptr<paddle::distributed::PSServer> pserver_ptr;
@@ -65,25 +71,43 @@ class GraphPyService {
6571
// Stores the shard count that GetDownpourSparseTableProto later writes into
// each table proto.
void set_shard_num(int shard_num) {
  this->shard_num = shard_num;
}
6672
void GetDownpourSparseTableProto(
6773
::paddle::distributed::TableParameter* sparse_table_proto,
68-
uint32_t table_id) {
74+
uint32_t table_id,
75+
std::string table_name,
76+
std::string table_type,
77+
std::vector<std::string> feat_name,
78+
std::vector<std::string> feat_dtype,
79+
std::vector<int32_t> feat_shape) {
6980
sparse_table_proto->set_table_id(table_id);
7081
sparse_table_proto->set_table_class("GraphTable");
7182
sparse_table_proto->set_shard_num(shard_num);
7283
sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
7384
::paddle::distributed::TableAccessorParameter* accessor_proto =
7485
sparse_table_proto->mutable_accessor();
86+
7587
::paddle::distributed::CommonAccessorParameter* common_proto =
7688
sparse_table_proto->mutable_common();
7789

90+
// Set GraphTable Parameter
91+
common_proto->set_table_name(table_name);
92+
common_proto->set_name(table_type);
93+
for(size_t i = 0;i < feat_name.size();i ++) {
94+
common_proto->add_params(feat_dtype[i]);
95+
common_proto->add_dims(feat_shape[i]);
96+
common_proto->add_attributes(feat_name[i]);
97+
}
98+
7899
accessor_proto->set_accessor_class("CommMergeAccessor");
79100
}
80101

81102
// Stores the number of servers on this service object.
void set_server_size(int server_size) {
  this->server_size = server_size;
}
103+
// Stores how many node types were configured on this service object.
void set_num_node_types(int num_node_types) {
  this->num_node_types = num_node_types;
}
82104
int get_server_size(int server_size) { return server_size; }
83105
std::vector<std::string> split(std::string& str, const char pattern);
84106
void set_up(std::string ips_str, int shard_num,
85107
std::vector<std::string> node_types,
86108
std::vector<std::string> edge_types);
109+
110+
void add_table_feat_conf(std::string node_type, std::string feat_name, std::string feat_dtype, int32_t feat_shape);
87111
};
88112
class GraphPyServer : public GraphPyService {
89113
public:

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,22 @@ int32_t GraphTable::initialize() {
372372
and _shard_idx to server
373373
rank
374374
*/
375+
auto common = _config.common();
376+
377+
this->table_name = common.table_name();
378+
this->table_type = common.name();
379+
VLOG(0) << " init graph table type " << this->table_type << " table name " << this->table_name;
380+
int feat_conf_size = static_cast<int>(common.attributes().size());
381+
for(int i=0; i<feat_conf_size;i ++) {
382+
auto & f_name= common.attributes()[i];
383+
auto & f_shape = common.dims()[i];
384+
auto & f_dtype = common.params()[i];
385+
this->feat_name.push_back(f_name);
386+
this->feat_shape.push_back(f_shape);
387+
this->feat_dtype.push_back(f_dtype);
388+
VLOG(0) << "init graph table feat conf name:"<< f_name << " shape:" << f_shape << " dtype:" << f_dtype;
389+
}
390+
375391
shard_num = _config.shard_num();
376392
VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
377393
<< _shard_idx;

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ class GraphTable : public SparseTable {
123123
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
124124
const int task_pool_size_ = 11;
125125
const int random_sample_nodes_ranges = 3;
126+
127+
std::vector<std::string > feat_name;
128+
std::vector<std::string > feat_dtype;
129+
std::vector<int32_t > feat_shape;
130+
std::string table_name;
131+
std::string table_type;
132+
126133
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
127134
};
128135
}

paddle/fluid/pybind/fleet_py.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,16 @@ void BindGraphPyServer(py::module* m) {
171171
py::class_<GraphPyServer>(*m, "GraphPyServer")
172172
.def(py::init<>())
173173
.def("start_server", &GraphPyServer::start_server)
174-
.def("set_up", &GraphPyServer::set_up);
174+
.def("set_up", &GraphPyServer::set_up)
175+
.def("add_table_feat_conf", &GraphPyServer::add_table_feat_conf);
175176
}
176177
void BindGraphPyClient(py::module* m) {
177178
py::class_<GraphPyClient>(*m, "GraphPyClient")
178179
.def(py::init<>())
179180
.def("load_edge_file", &GraphPyClient::load_edge_file)
180181
.def("load_node_file", &GraphPyClient::load_node_file)
181182
.def("set_up", &GraphPyClient::set_up)
183+
.def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
182184
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
183185
.def("start_client", &GraphPyClient::start_client)
184186
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)

0 commit comments

Comments
 (0)