@@ -16,6 +16,7 @@ limitations under the License. */
1616#include < glog/logging.h>
1717
1818#include < vector>
19+ #include < thread>
1920
2021#if defined(TRACE_PROFILE) && (defined(PADDLE_WITH_XPU_KP) || defined(PADDLE_WITH_XPU))
2122// The producer side.
@@ -382,24 +383,25 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
382383 VLOG (3 ) << " Begine BoxPs PullSparse" ;
383384 xpu::ctx_guard RAII_GUARD (ctx_xpu);
384385
385- int64_t total_bytes = total_length * feature_pull_size_;
386- void * total_values_xpu =
387- dev.pull_push_tensor .mutable_data <void >(total_bytes, place);
386+
388387
389388#ifdef TRACE_PROFILE
390389 TRACE_SCOPE_START (" copy keys" , xpu_wait (ctx_xpu->xpu_stream ));
391390#endif
392391 VLOG (3 ) << " Begin copy keys, key_num[" << total_length << " ]" ;
393392 // LoDTensor& total_keys_tensor = dev.keys_tensor;
394393 uint64_t * total_keys;
395- if (use_l3_tensor) {
396- total_keys = dev.keys_tensor .mutable_data <uint64_t >(total_length * sizeof (int64_t ), l3_place);
394+ int * key2slot = nullptr ;
395+ if (FLAGS_enable_pullpush_dedup_keys && use_xpu_sparse_map_) {
396+ total_keys = dev.keys_tensor .mutable_data <uint64_t >(total_length * 3 * sizeof (int64_t ), place);
397397 } else {
398- total_keys = dev.keys_tensor .mutable_data <uint64_t >(total_length * sizeof (int64_t ), place);
398+ if (use_l3_tensor) {
399+ total_keys = dev.keys_tensor .mutable_data <uint64_t >(total_length * sizeof (int64_t ), l3_place);
400+ } else {
401+ total_keys = dev.keys_tensor .mutable_data <uint64_t >(total_length * sizeof (int64_t ), place);
402+ }
399403 }
400- int * key2slot = nullptr ;
401- key2slot =dev.keys2slot .mutable_data <int >(total_length * sizeof (int ), place);
402-
404+ key2slot = dev.keys2slot .mutable_data <int >(total_length * sizeof (int ), place);
403405 // construct slot_level lod info
404406 std::vector<int64_t > slot_lengths_lod (slot_num + 1 , 0 );
405407 for (int i = 1 ; i <= slot_num ; i++) {
@@ -429,6 +431,65 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
429431 static_cast <int >(slot_lengths.size ()),
430432 static_cast <int >(total_length), key2slot);
431433 }
434+ uint64_t * d_pull_keys = total_keys;
435+ int pull_size = total_length;
436+ int * d_merged_idx = nullptr ;
437+ int * d_merged_offsets = nullptr ;
438+ int * d_res_idx = nullptr ;
439+ std::thread thread_get_restore_idx;
440+
441+ if (FLAGS_enable_pullpush_dedup_keys && use_xpu_sparse_map_) {
442+ uint64_t * d_merged_keys = reinterpret_cast <uint64_t *>(&total_keys[total_length]);
443+ pull_size = boxps_ptr_->DedupKeysAndFillIdxXPU (device_id, total_length, total_keys,
444+ d_merged_keys, d_merged_idx, d_merged_offsets);
445+ d_pull_keys = d_merged_keys;
446+ d_res_idx = reinterpret_cast <int *>(&total_keys[2 * total_length]);
447+
448+ thread_get_restore_idx = std::thread ([&] {
449+ xpu_set_device (device_id);
450+ std::vector<int > h_idx (total_length);
451+ std::vector<int > h_offset (pull_size + 1 );
452+ xpu_memcpy (h_idx.data (),
453+ d_merged_idx,
454+ h_idx.size () * sizeof (int ),
455+ XPUMemcpyKind::XPU_DEVICE_TO_HOST);
456+ xpu_memcpy (h_offset.data (),
457+ d_merged_offsets,
458+ pull_size * sizeof (int ),
459+ XPUMemcpyKind::XPU_DEVICE_TO_HOST);
460+ h_offset[pull_size] = total_length - 1 ;
461+ std::vector<int > tmp1 (total_length);
462+
463+ for (size_t i = 0 ; i < (size_t )pull_size; i++) {
464+ if (i != 0 ) {
465+ tmp1[h_offset[i]] = tmp1[h_offset[i] - 1 ];
466+ }
467+ else {
468+ tmp1[0 ] = 0 ;
469+ }
470+ for (int j = h_offset[i] + 1 ; j < h_offset[i + 1 ]; j++) {
471+ tmp1[j] = tmp1[j - 1 ] + 1 ;
472+ }
473+ }
474+ if (h_offset[pull_size - 1 ] != h_offset[pull_size]) {
475+ tmp1[h_offset[pull_size]] = tmp1[h_offset[pull_size] - 1 ] + 1 ;
476+ } else {
477+ tmp1[h_offset[pull_size]] = tmp1[h_offset[pull_size] - 1 ];
478+ }
479+ std::vector<int > h_res_idx (total_length);
480+ for (size_t i = 0 ; i < (size_t )total_length; i++) {
481+ h_res_idx[h_idx[i]] = i - tmp1[i];
482+ }
483+
484+ xpu_memcpy (d_res_idx,
485+ h_res_idx.data (),
486+ total_length * sizeof (int ),
487+ XPUMemcpyKind::XPU_HOST_TO_DEVICE);
488+ });
489+ }
490+
491+ void * total_values_xpu = dev.pull_push_tensor .mutable_data <void >(pull_size * feature_pull_size_, place);
492+
432493 VLOG (3 ) << " Begin call PullSparseXPU in BoxPS, dev: " << device_id
433494 << " len: " << total_length;
434495#ifdef TRACE_PROFILE
@@ -437,8 +498,9 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
437498 TRACE_SCOPE_START (" PullSparseXPU" , xpu_wait (ctx_xpu->xpu_stream ));
438499#endif
439500 pull_boxps_timer.Start ();
440- boxps_ptr_->PullSparseXPU (total_keys, total_values_xpu,
441- static_cast <int >(total_length), device_id);
501+
502+ boxps_ptr_->PullSparseXPU (d_pull_keys, total_values_xpu, pull_size, device_id);
503+
442504 pull_boxps_timer.Pause ();
443505#ifdef TRACE_PROFILE
444506 TRACE_SCOPE_END (" PullSparseXPU" , xpu_wait (ctx_xpu->xpu_stream ));
@@ -467,10 +529,14 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
467529#ifdef TRACE_PROFILE
468530 TRACE_SCOPE_START (" CopyForPull" , xpu_wait (ctx_xpu->xpu_stream ));
469531#endif
532+ if (FLAGS_enable_pullpush_dedup_keys && use_xpu_sparse_map_) {
533+ thread_get_restore_idx.join ();
534+ }
535+
470536 box_wrapper_kernel_->CopyForPull (place, xpu_keys, (float **)values.data (), total_values_xpu,
471- pull_offset, slot_lengths_lod.data (), slot_num, key2slot, hidden_size,
472- expand_embed_dim, total_length, total_dims, skip_offset,
473- expand_only);
537+ pull_offset, slot_lengths_lod.data (), slot_num, key2slot, d_res_idx , hidden_size,
538+ expand_embed_dim, total_length, total_dims, skip_offset,
539+ expand_only, d_merged_idx, d_merged_offsets, pull_size );
474540#ifdef TRACE_PROFILE
475541 TRACE_SCOPE_END (" CopyForPull" , xpu_wait (ctx_xpu->xpu_stream ));
476542 TRACE_SCOPE_END (" pull copy" , xpu_wait (ctx_xpu->xpu_stream ));
0 commit comments