@@ -406,7 +406,7 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
406406 const std::vector<float *>& values,
407407 const std::vector<int64_t >& slot_lengths,
408408 const int hidden_size, const int expand_embed_dim,
409- const int skip_offset) {
409+ const int skip_offset, bool expand_only ) {
410410#define EMBEDX_CASE (i, ...) \
411411 case i: { \
412412 constexpr size_t EmbedxDim = i; \
@@ -425,33 +425,33 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
425425 PullSparseCase< \
426426 boxps::FeaturePullValueGpuShareEmbedding<EmbedxDim, ExpandDim>>( \
427427 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
428- skip_offset); \
428+ skip_offset, expand_only); \
429429 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_PCOC)) { \
430430 PullSparseCase<boxps::FeaturePullValueGpuPCOC<EmbedxDim, ExpandDim>>( \
431431 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
432- skip_offset); \
432+ skip_offset, expand_only); \
433433 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_QUANT) || \
434434 feature_type_ == static_cast <int >(boxps::FEATURE_SHOWCLK)) { \
435435 PullSparseCase<boxps::FeaturePullValueGpuQuant<EmbedxDim, ExpandDim>>( \
436436 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
437- skip_offset); \
437+ skip_offset, expand_only); \
438438 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_CONV)) { \
439439 PullSparseCase<boxps::FeaturePullValueGpuConv<EmbedxDim, ExpandDim>>( \
440440 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
441- skip_offset); \
441+ skip_offset, expand_only); \
442442 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_VARIABLE)) { \
443443 PullSparseCase<boxps::FeatureVarPullValueGpu<EmbedxDim, ExpandDim>>( \
444444 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
445- skip_offset); \
445+ skip_offset, expand_only); \
446446 } else if (EmbedxDim == 0 && \
447447 feature_type_ == static_cast <int >(boxps::FEATURE_ADAM)) { \
448448 PullSparseCase<boxps::FeatureVarPullValueGpu<EmbedxDim, ExpandDim>>( \
449449 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
450- skip_offset); \
450+ skip_offset, expand_only); \
451451 } else { \
452452 PullSparseCase<boxps::FeaturePullValueGpu<EmbedxDim, ExpandDim>>( \
453453 place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
454- skip_offset); \
454+ skip_offset, expand_only); \
455455 } \
456456 } break
457457
@@ -489,7 +489,9 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
489489 const std::vector<int64_t >& slot_lengths,
490490 const int hidden_size,
491491 const int expand_embed_dim,
492- const int batch_size, const int skip_offset) {
492+ const int batch_size,
493+ const int skip_offset,
494+ bool expand_only) {
493495#define EMBEDX_CASE (i, ...) \
494496 case i: { \
495497 constexpr size_t EmbedxDim = i; \
@@ -508,30 +510,30 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
508510 PushSparseGradCase< \
509511 boxps::FeaturePushValueGpuShareEmbedding<EmbedxDim, ExpandDim>>( \
510512 place, keys, grad_values, slot_lengths, hidden_size, \
511- expand_embed_dim, batch_size, skip_offset); \
513+ expand_embed_dim, batch_size, skip_offset, expand_only); \
512514 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_PCOC)) { \
513515 PushSparseGradCase< \
514516 boxps::FeaturePushValueGpuPCOC<EmbedxDim, ExpandDim>>( \
515517 place, keys, grad_values, slot_lengths, hidden_size, \
516- expand_embed_dim, batch_size, skip_offset); \
518+ expand_embed_dim, batch_size, skip_offset, expand_only); \
517519 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_VARIABLE)) { \
518520 PushSparseGradCase<boxps::FeatureVarPushValueGpu<EmbedxDim, ExpandDim>>( \
519521 place, keys, grad_values, slot_lengths, hidden_size, \
520- expand_embed_dim, batch_size, skip_offset); \
522+ expand_embed_dim, batch_size, skip_offset, expand_only); \
521523 } else if (feature_type_ == static_cast <int >(boxps::FEATURE_CONV)) { \
522524 PushSparseGradCase< \
523525 boxps::FeaturePushValueGpuConv<EmbedxDim, ExpandDim>>( \
524526 place, keys, grad_values, slot_lengths, hidden_size, \
525- expand_embed_dim, batch_size, skip_offset); \
527+ expand_embed_dim, batch_size, skip_offset, expand_only); \
526528 } else if (EmbedxDim == 0 && \
527529 feature_type_ == static_cast <int >(boxps::FEATURE_ADAM)) { \
528530 PushSparseGradCase<boxps::FeatureVarPushValueGpu<EmbedxDim, ExpandDim>>( \
529531 place, keys, grad_values, slot_lengths, hidden_size, \
530- expand_embed_dim, batch_size, skip_offset); \
532+ expand_embed_dim, batch_size, skip_offset, expand_only); \
531533 } else { \
532534 PushSparseGradCase<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>>( \
533535 place, keys, grad_values, slot_lengths, hidden_size, \
534- expand_embed_dim, batch_size, skip_offset); \
536+ expand_embed_dim, batch_size, skip_offset, expand_only); \
535537 } \
536538 } break
537539
0 commit comments