Commit a881b4d
Struct SparseValue && Bug Fix (#31721)

* add PullSparseValue for pull sparse
* fix bug for PullSparseValue
* add test mode in lookuptable
* revert API change
* add comment for is_training
1 parent b8b82b7 commit a881b4d

22 files changed: +232 additions, -122 deletions

paddle/fluid/distributed/fleet.cc
Lines changed: 9 additions & 37 deletions

@@ -146,41 +146,6 @@ void FleetWrapper::CreateClient2ClientConnection() {
       client2client_max_retry_);
 }
 
-std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
-    const Scope& scope, const uint64_t table_id,
-    const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
-    std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
-  fea_keys->clear();
-  fea_keys->resize(0);
-  fea_keys->reserve(MAX_FEASIGN_NUM);
-  for (auto name : var_names) {
-    Variable* var = scope.FindVar(name);
-    if (var == nullptr) {
-      continue;
-    }
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
-    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
-    int64_t* ids = tensor->data<int64_t>();
-    size_t len = tensor->numel();
-    for (auto i = 0u; i < len; ++i) {
-      if (ids[i] == 0u) {
-        continue;
-      }
-      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
-    }
-  }
-  fea_values->resize(fea_keys->size() + 1);
-  for (auto& t : *fea_values) {
-    t.resize(fea_value_dim);
-  }
-  std::vector<float*> pull_result_ptr;
-  for (auto& t : *fea_values) {
-    pull_result_ptr.push_back(t.data());
-  }
-  return pserver_ptr_->_worker_ptr->pull_sparse(
-      pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
-}
-
 void FleetWrapper::PullSparseVarsSync(
     const Scope& scope, const uint64_t table_id,
     const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
@@ -224,8 +189,10 @@ void FleetWrapper::PullSparseVarsSync(
   for (auto& t : *fea_values) {
     pull_result_ptr.push_back(t.data());
   }
+  bool training = true;
   auto status = pserver_ptr_->_worker_ptr->pull_sparse(
-      pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
+      pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(),
+      training);
   pull_sparse_status.push_back(std::move(status));
   for (auto& t : pull_sparse_status) {
     t.wait();
@@ -238,9 +205,13 @@ void FleetWrapper::PullSparseVarsSync(
   }
 }
 
+// is_training is true means training, false means inference, the behavior is
+// different on pserver
+
 void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
                                           uint64_t padding_id,
                                           platform::Place place,
+                                          bool is_training,
                                           std::vector<const LoDTensor*>* inputs,
                                           std::vector<LoDTensor*>* outputs) {
   std::vector<uint64_t> fea_keys;
@@ -279,7 +250,8 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
   }
   auto* communicator = Communicator::GetInstance();
   auto status = communicator->_worker_ptr->pull_sparse(
-      pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
+      pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(),
+      is_training);
   status.wait();
   auto ret = status.get();
   if (ret != 0) {

paddle/fluid/distributed/fleet.h
Lines changed: 4 additions & 9 deletions

@@ -84,19 +84,14 @@ class FleetWrapper {
                             int fea_dim,
                             const std::vector<std::string>& var_emb_names);
 
-  // Pull sparse variables from server in async mode
-  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
-  // Param<out>: fea_values std::future
-  std::future<int32_t> PullSparseVarsAsync(
-      const Scope& scope, const uint64_t table_id,
-      const std::vector<std::string>& var_names,
-      std::vector<uint64_t>* fea_keys,
-      std::vector<std::vector<float>>* fea_values, int fea_dim);
-
   // Pull sparse variables from server in sync mode
   // pull immediately to tensors
+  // is_training is true means training, false means inference, the behavior is
+  // different on pserver
+
   void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
                               uint64_t padding_id, platform::Place place,
+                              bool is_training,
                               std::vector<const LoDTensor*>* inputs,  // NOLINT
                               std::vector<LoDTensor*>* outputs);      // NOLINT
 
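With the flag threaded through, inference call sites can now opt out of training-side behavior explicitly. A minimal, hypothetical call-site sketch (fleet, table_id, fea_dim, padding_id, place, and the two tensor vectors are assumed to be set up as before; only the bool argument is new):

// Hedged sketch: an inference-time pull through the updated API.
// is_training = false tells the pserver to skip training-only handling.
bool is_training = false;
fleet->PullSparseToTensorSync(table_id, fea_dim, padding_id, place,
                              is_training, &input_tensors, &output_tensors);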

paddle/fluid/distributed/service/brpc_ps_client.cc
Lines changed: 14 additions & 3 deletions

@@ -768,8 +768,8 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
 
 std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
                                                size_t table_id,
-                                               const uint64_t *keys,
-                                               size_t num) {
+                                               const uint64_t *keys, size_t num,
+                                               bool is_training) {
   size_t request_call_num = _server_channels.size();
 
   auto shard_sorted_kvs = std::make_shared<
@@ -837,16 +837,27 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
     uint32_t kv_request_count = 0;
     size_t sorted_kv_size = sorted_kvs.size();
     auto &request_buffer = closure->cntl(i)->request_attachment();
+
+    request_buffer.append((void *)&is_training, sizeof(bool));
+    std::vector<uint32_t> keys_counter;
+    keys_counter.reserve(sorted_kv_size);
+
     for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
       ++kv_request_count;
+      uint32_t keys = 1;
       last_key = sorted_kvs[kv_idx].first;
       request_buffer.append((void *)&last_key, sizeof(uint64_t));
       while (kv_idx < sorted_kv_size - 1 &&
              last_key == sorted_kvs[kv_idx + 1].first) {
         ++kv_idx;
+        ++keys;
       }
+      keys_counter.push_back(keys);
     }
 
+    request_buffer.append((void *)keys_counter.data(),
+                          sizeof(uint32_t) * keys_counter.size());
+
     if (kv_request_count == 0) {
       closure->Run();
     } else {
@@ -956,7 +967,7 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
   }
 
   auto status = pull_sparse((float **)save_vec.data(), table_id,
-                            save_key.data(), save_key.size());
+                            save_key.data(), save_key.size(), true);
   status.wait();
 
   // create lod tensor
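Reading the appends off the hunk above, the request attachment each server shard now receives carries the training flag and per-key counters in addition to the de-duplicated keys. A reconstructed layout, inferred from this client code rather than a documented spec ({n} is the number of unique keys sent to that shard):

|--is_training--|-----keys-----|---counters----|
|------1B-------|---8 * {n}B---|---4 * {n}B----|

keys_counter records how many times each unique key occurred in the original request, which the table code below evidently reads back as frequencies_.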

paddle/fluid/distributed/service/brpc_ps_client.h
Lines changed: 2 additions & 1 deletion

@@ -148,7 +148,8 @@ class BrpcPsClient : public PSClient {
 
   virtual std::future<int32_t> pull_sparse(float **select_values,
                                            size_t table_id,
-                                           const uint64_t *keys, size_t num);
+                                           const uint64_t *keys, size_t num,
+                                           bool is_training);
 
   virtual std::future<int32_t> print_table_stat(uint32_t table_id);

paddle/fluid/distributed/service/brpc_ps_server.cc
Lines changed: 20 additions & 13 deletions

@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/distributed/service/brpc_ps_server.h"
 #include <thread>  // NOLINT
+#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/table.h"
 #include "paddle/fluid/framework/archive.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -337,33 +338,39 @@ int32_t BrpcPsService::pull_sparse(Table *table,
                                    brpc::Controller *cntl) {
   platform::RecordEvent record_event("PsService->pull_sparse");
   CHECK_TABLE_EXIST(table, request, response)
-  thread_local std::string push_sparse_request_buffer;
+
   auto &req_io_buffer = cntl->request_attachment();
   auto req_buffer_size = req_io_buffer.size();
+
   if (req_buffer_size < 1) {
     set_response_code(response, -1, "req attachment is empty");
     return 0;
   }
+
   if (request.params_size() < 1) {
     set_response_code(response, -1,
                       "PsRequestMessage.params is requeired at "
                       "least 1 for num of sparse_key");
     return 0;
   }
+
   uint32_t num = *(uint32_t *)(request.params(0).c_str());
-  push_sparse_request_buffer.resize(0);
-  push_sparse_request_buffer.reserve(req_buffer_size);
-  const char *data = (const char *)cntl->request_attachment().fetch(
-      const_cast<char *>(push_sparse_request_buffer.data()), req_buffer_size);
-  /*
-  Attachment Content:
-  |---keysData---|
-  |---8*{num}B---|
-  */
-  const uint64_t *keys = (const uint64_t *)data;
+  auto dim = table->value_accesor()->select_dim();
+
+  thread_local std::string req_buffer;
+  req_buffer.reserve(req_buffer_size);
+
+  const void *data = cntl->request_attachment().fetch(
+      const_cast<char *>(req_buffer.data()), req_buffer_size);
+
+  auto value = PullSparseValue(num, dim);
+
+  value.DeserializeFromBytes(const_cast<void *>(data));
+
   std::vector<float> res_data;
-  res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
-  table->pull_sparse(res_data.data(), keys, num);
+  res_data.resize(num * dim);
+  table->pull_sparse(res_data.data(), value);
+
   cntl->response_attachment().append((char *)res_data.data(),
                                      res_data.size() * sizeof(float));
   return 0;
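PullSparseValue itself lives in the newly included depends/sparse_utils.h, which is not part of this diff. A minimal sketch consistent with its usage here and in common_sparse_table.cc below (two constructors, DeserializeFromBytes, Fission, and the is_training_/feasigns_/frequencies_ members); the byte layout is assumed from the client-side serialization, so treat this as a reconstruction rather than the actual header:

#include <cstdint>
#include <vector>

namespace paddle {
namespace distributed {

struct PullSparseValue {
  // Server side: numel/dim known up front; the pointers are filled in
  // later by DeserializeFromBytes.
  explicit PullSparseValue(int numel, int dim) : numel_(numel), dim_(dim) {}

  // Local callers (e.g. initialize_value): borrow pointers from
  // caller-owned vectors, which must outlive this object.
  PullSparseValue(std::vector<uint64_t>& feasigns,     // NOLINT
                  std::vector<uint32_t>& frequencies,  // NOLINT
                  int dim)
      : numel_(static_cast<int>(feasigns.size())),
        dim_(dim),
        feasigns_(feasigns.data()),
        frequencies_(frequencies.data()) {}

  // Assumed wire layout, mirroring the client-side appends:
  // |--1B is_training--|--8B * numel_ keys--|--4B * numel_ counters--|
  void DeserializeFromBytes(void* bytes) {
    char* begin = reinterpret_cast<char*>(bytes);
    is_training_ = *reinterpret_cast<bool*>(begin);
    feasigns_ = reinterpret_cast<uint64_t*>(begin + sizeof(bool));
    frequencies_ = reinterpret_cast<uint32_t*>(begin + sizeof(bool) +
                                               sizeof(uint64_t) * numel_);
  }

  // Collect the offsets of the keys that fall on this shard (modulo split).
  void Fission(const int shard_id, const int shard_num,
               std::vector<int>* offset_shard) const {
    offset_shard->reserve(numel_ / shard_num + 1);
    for (int x = 0; x < numel_; ++x) {
      if (static_cast<int>(feasigns_[x] % shard_num) == shard_id) {
        offset_shard->push_back(x);
      }
    }
  }

  bool is_training_ = true;
  int numel_ = 0;
  int dim_ = 0;
  uint64_t* feasigns_ = nullptr;
  uint32_t* frequencies_ = nullptr;
};

}  // namespace distributed
}  // namespace paddle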

paddle/fluid/distributed/service/communicator.cc
Lines changed: 3 additions & 1 deletion

@@ -320,9 +320,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
     push_g_vec.push_back(tensor->data<float>() + i * dim);
   }
 
+  bool training = true;
+
   auto status = _worker_ptr->pull_sparse(
       (float **)push_g_vec.data(), table_id,  // NOLINT
-      sparse_push_keys.data(), sparse_push_keys.size());
+      sparse_push_keys.data(), sparse_push_keys.size(), training);
   status.wait();
   return;
 }

paddle/fluid/distributed/service/ps_client.h
Lines changed: 3 additions & 2 deletions

@@ -112,10 +112,11 @@ class PSClient {
   // the keys/values buffers must not be reused before the future completes
   // keys from multiple threads are merged, then gathered/scattered to servers
   // when results return, the buffer is traversed and the values are assigned
+  // is_training distinguishes training from inference; the server treats features and admission differently for each
   virtual std::future<int32_t> pull_sparse(float **select_values,
                                            size_t table_id,
-                                           const uint64_t *keys,
-                                           size_t num) = 0;
+                                           const uint64_t *keys, size_t num,
+                                           bool is_training) = 0;
 
   virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
 
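A hypothetical caller-side sketch of the extended interface (client, table_id, and dim are placeholders; any concrete PSClient such as BrpcPsClient is called the same way):

// Pull embeddings for inference: with is_training = false the pserver can
// skip feature admission and other training-only handling.
std::vector<uint64_t> keys = {1001, 1002, 1003};
std::vector<std::vector<float>> values(keys.size(), std::vector<float>(dim));
std::vector<float *> value_ptrs;
for (auto &v : values) value_ptrs.push_back(v.data());

auto fut = client->pull_sparse(value_ptrs.data(), table_id, keys.data(),
                               keys.size(), /*is_training=*/false);
fut.wait();               // keys/values must stay alive until here
int32_t ret = fut.get();  // 0 on success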

paddle/fluid/distributed/table/common_graph_table.h
Lines changed: 6 additions & 3 deletions

@@ -103,13 +103,16 @@ class GraphTable : public SparseTable {
 
   Node *find_node(uint64_t id);
 
-  virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
+  virtual int32_t pull_sparse(float *values,
+                              const PullSparseValue &pull_value) {
     return 0;
   }
+
   virtual int32_t push_sparse(const uint64_t *keys, const float *values,
                               size_t num) {
     return 0;
   }
+
   virtual void clear() {}
   virtual int32_t flush() { return 0; }
   virtual int32_t shrink(const std::string &param) { return 0; }
@@ -140,5 +143,5 @@ class GraphTable : public SparseTable {
 
   std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
 };
-}
-};
+}  // namespace distributed
+};  // namespace paddle

paddle/fluid/distributed/table/common_sparse_table.cc
Lines changed: 30 additions & 22 deletions

@@ -254,7 +254,6 @@ int32_t CommonSparseTable::initialize_value() {
   }
 
   auto accessor = _config.accessor();
-
   std::vector<uint64_t> feasigns;
 
   for (size_t x = 0; x < accessor.fea_dim(); ++x) {
@@ -271,9 +270,14 @@ int32_t CommonSparseTable::initialize_value() {
     std::vector<uint64_t> ids(bucket_feasigns);
     std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1],
               ids.begin());
+
+    std::vector<uint32_t> fres;
+    fres.resize(ids.size(), 1);
+
+    auto pull_value = PullSparseValue(ids, fres, param_dim_);
     std::vector<float> pulls;
     pulls.resize(bucket_feasigns * param_dim_);
-    pull_sparse(pulls.data(), ids.data(), bucket_feasigns);
+    pull_sparse(pulls.data(), pull_value);
   }
 
   return 0;
@@ -399,32 +403,36 @@ int32_t CommonSparseTable::pour() {
   return 0;
 }
 
-int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
-                                       size_t num) {
+int32_t CommonSparseTable::pull_sparse(float* pull_values,
+                                       const PullSparseValue& pull_value) {
   rwlock_->RDLock();
 
-  std::vector<std::vector<uint64_t>> offset_bucket;
-  offset_bucket.resize(task_pool_size_);
-
-  for (int x = 0; x < num; ++x) {
-    auto y = keys[x] % task_pool_size_;
-    offset_bucket[y].push_back(x);
-  }
-
-  std::vector<std::future<int>> tasks(task_pool_size_);
+  auto shard_num = task_pool_size_;
+  std::vector<std::future<int>> tasks(shard_num);
 
-  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
+  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
     tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
-        [this, shard_id, &keys, &offset_bucket, &pull_values]() -> int {
+        [this, shard_id, shard_num, &pull_value, &pull_values]() -> int {
          auto& block = shard_values_[shard_id];
-          auto& offsets = offset_bucket[shard_id];
 
-          for (int i = 0; i < offsets.size(); ++i) {
-            auto offset = offsets[i];
-            auto id = keys[offset];
-            auto* value = block->Init(id);
-            std::copy_n(value + param_offset_, param_dim_,
-                        pull_values + param_dim_ * offset);
+          std::vector<int> offsets;
+          pull_value.Fission(shard_id, shard_num, &offsets);
+
+          if (pull_value.is_training_) {
+            for (auto& offset : offsets) {
+              auto feasign = pull_value.feasigns_[offset];
+              auto frequencie = pull_value.frequencies_[offset];
+              auto* value = block->Init(feasign, true, frequencie);
+              std::copy_n(value + param_offset_, param_dim_,
+                          pull_values + param_dim_ * offset);
+            }
+          } else {
+            for (auto& offset : offsets) {
+              auto feasign = pull_value.feasigns_[offset];
+              auto* value = block->Init(feasign, false);
+              std::copy_n(value + param_offset_, param_dim_,
+                          pull_values + param_dim_ * offset);
+            }
           }
 
          return 0;
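Each shard's task no longer receives a precomputed offset bucket; it filters the full key list itself through pull_value.Fission. Assuming the same modulo sharding the deleted offset_bucket code used, a small worked example:

// feasigns_ = {10, 11, 12, 13}, shard_num = 2
// shard 0 -> offsets {0, 2}   (10 % 2 == 0, 12 % 2 == 0)
// shard 1 -> offsets {1, 3}   (11 % 2 == 1, 13 % 2 == 1)
// Training pulls additionally feed frequencies_[offset] into block->Init,
// letting the table update its frequency/admission state; inference pulls
// call Init(feasign, false) and leave that state untouched.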

paddle/fluid/distributed/table/common_sparse_table.h
Lines changed: 1 addition & 2 deletions

@@ -61,8 +61,7 @@ class CommonSparseTable : public SparseTable {
   int32_t save(const std::string& path, const std::string& param);
 
   virtual std::pair<int64_t, int64_t> print_table_stat();
-  virtual int32_t pull_sparse(float* pull_values, const uint64_t* keys,
-                              size_t num);
+  virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
 
   virtual int32_t push_sparse(const uint64_t* keys, const float* values,
                               size_t num);
