@@ -254,7 +254,6 @@ int32_t CommonSparseTable::initialize_value() {
254254 }
255255
256256 auto accessor = _config.accessor ();
257-
258257 std::vector<uint64_t > feasigns;
259258
260259 for (size_t x = 0 ; x < accessor.fea_dim (); ++x) {
@@ -271,9 +270,14 @@ int32_t CommonSparseTable::initialize_value() {
271270 std::vector<uint64_t > ids (bucket_feasigns);
272271 std::copy (feasigns.begin () + buckets[x], feasigns.begin () + buckets[x + 1 ],
273272 ids.begin ());
273+
274+ std::vector<uint32_t > fres;
275+ fres.resize (ids.size (), 1 );
276+
277+ auto pull_value = PullSparseValue (ids, fres, param_dim_);
274278 std::vector<float > pulls;
275279 pulls.resize (bucket_feasigns * param_dim_);
276- pull_sparse (pulls.data (), ids. data (), bucket_feasigns );
280+ pull_sparse (pulls.data (), pull_value );
277281 }
278282
279283 return 0 ;
@@ -399,32 +403,36 @@ int32_t CommonSparseTable::pour() {
399403 return 0 ;
400404}
401405
402- int32_t CommonSparseTable::pull_sparse (float * pull_values, const uint64_t * keys,
403- size_t num ) {
406+ int32_t CommonSparseTable::pull_sparse (float * pull_values,
407+ const PullSparseValue& pull_value ) {
404408 rwlock_->RDLock ();
405409
406- std::vector<std::vector<uint64_t >> offset_bucket;
407- offset_bucket.resize (task_pool_size_);
408-
409- for (int x = 0 ; x < num; ++x) {
410- auto y = keys[x] % task_pool_size_;
411- offset_bucket[y].push_back (x);
412- }
413-
414- std::vector<std::future<int >> tasks (task_pool_size_);
410+ auto shard_num = task_pool_size_;
411+ std::vector<std::future<int >> tasks (shard_num);
415412
416- for (int shard_id = 0 ; shard_id < task_pool_size_ ; ++shard_id) {
413+ for (int shard_id = 0 ; shard_id < shard_num ; ++shard_id) {
417414 tasks[shard_id] = _shards_task_pool[shard_id]->enqueue (
418- [this , shard_id, &keys , &offset_bucket , &pull_values]() -> int {
415+ [this , shard_id, shard_num , &pull_value , &pull_values]() -> int {
419416 auto & block = shard_values_[shard_id];
420- auto & offsets = offset_bucket[shard_id];
421417
422- for (int i = 0 ; i < offsets.size (); ++i) {
423- auto offset = offsets[i];
424- auto id = keys[offset];
425- auto * value = block->Init (id);
426- std::copy_n (value + param_offset_, param_dim_,
427- pull_values + param_dim_ * offset);
418+ std::vector<int > offsets;
419+ pull_value.Fission (shard_id, shard_num, &offsets);
420+
421+ if (pull_value.is_training_ ) {
422+ for (auto & offset : offsets) {
423+ auto feasign = pull_value.feasigns_ [offset];
424+ auto frequencie = pull_value.frequencies_ [offset];
425+ auto * value = block->Init (feasign, true , frequencie);
426+ std::copy_n (value + param_offset_, param_dim_,
427+ pull_values + param_dim_ * offset);
428+ }
429+ } else {
430+ for (auto & offset : offsets) {
431+ auto feasign = pull_value.feasigns_ [offset];
432+ auto * value = block->Init (feasign, false );
433+ std::copy_n (value + param_offset_, param_dim_,
434+ pull_values + param_dim_ * offset);
435+ }
428436 }
429437
430438 return 0 ;
0 commit comments