Skip to content

Commit 2a36cb2

Browse files
committed
psgpu:optimize build_cpu haseset; test=develop
1 parent 85642a0 commit 2a36cb2

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

paddle/fluid/framework/fleet/heter_context.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
3030
#endif
3131

32+
#include "paddle/fluid/distributed/thirdparty/round_robin.h"
3233
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
3334
#include "paddle/fluid/framework/scope.h"
3435

@@ -106,7 +107,7 @@ class HeterContext {
106107
}
107108

108109
void batch_add_keys(int shard_num,
109-
const std::unordered_set<uint64_t>& shard_keys) {
110+
const robin_hood::unordered_set<uint64_t>& shard_keys) {
110111
int idx = feature_keys_[shard_num].size();
111112
feature_keys_[shard_num].resize(feature_keys_[shard_num].size() +
112113
shard_keys.size());

paddle/fluid/framework/fleet/ps_gpu_wrapper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include <gloo/broadcast.h>
3030
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
3131
#endif
32+
#include "paddle/fluid/distributed/thirdparty/round_robin.h"
3233
#include "paddle/fluid/framework/data_set.h"
3334
#include "paddle/fluid/framework/fleet/heter_context.h"
3435
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
@@ -270,7 +271,7 @@ class PSGPUWrapper {
270271
std::vector<int> heter_devices_;
271272
std::unordered_set<std::string> gpu_ps_config_keys_;
272273
HeterObjectPool<HeterContext> gpu_task_pool_;
273-
std::vector<std::vector<std::unordered_set<uint64_t>>> thread_keys_;
274+
std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
274275
int thread_keys_thread_num_ = 37;
275276
int thread_keys_shard_num_ = 37;
276277
uint64_t max_fea_num_per_pass_ = 5000000000;

0 commit comments

Comments
 (0)