
Commit 7f73ad5

faran928 authored and meta-codesync[bot] committed
Keep the logic for inference tensorpool forward consistent w/ set up before hetero sharding (#3553)
Summary: Pull Request resolved: #3553

Keep the logic for inference TensorPool forward consistent with the setup that existed before hetero sharding. Using the optional tensor wrapper interferes with lowering jobs, because the model split boundaries come out differently when a TensorPool and a TBE exist together. Some of the test setup is also reverted, since this change swaps the order of some graph nodes.

Differential Revision: D87326553

fbshipit-source-id: ad29e23082e89dacb60f7af31a00e7c848f57f43
1 parent 5391be5 · commit 7f73ad5
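
For context on the change: the commit removes a call to `_fx_item_unwrap_optional_tensor` from the forward path. The sketch below is a hypothetical reconstruction of what such an optional-unwrap helper typically looks like in FX-traced code; the helper name `unwrap_optional_tensor` and the `torch.fx.wrap` registration are assumptions for illustration, and the actual torchrec implementation may differ.

```python
from typing import Optional

import torch
import torch.fx


# Hypothetical sketch of an optional-tensor unwrap helper like the one this
# commit removes. The real torchrec helper (_fx_item_unwrap_optional_tensor)
# may be implemented differently.
@torch.fx.wrap
def unwrap_optional_tensor(t: Optional[torch.Tensor]) -> torch.Tensor:
    # Convince the type checker / tracer that the Optional is non-None here.
    assert t is not None, "expected a non-None tensor"
    return t
```

A helper registered this way is recorded as its own call node in the traced graph. That extra node is the kind of thing that can shift where the model split boundary falls during lowering, which matches the summary above: replacing the wrapped call with a plain assignment (`unbucketize_permute_non_opt = unbucketize_permute`) keeps the graph shaped the way it was before hetero sharding.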

File tree

1 file changed: +15 −9 lines changed

torchrec/distributed/tensor_pool.py

Lines changed: 15 additions & 9 deletions
@@ -473,9 +473,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
             self._lookup_ids_dist(ids)
         )
-        unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
-            unbucketize_permute
-        )
+        unbucketize_permute_non_opt = unbucketize_permute
 
         lookup = self._lookup_local(dist_input)
 
@@ -512,12 +510,20 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         )
 
         output = self._lookup_values_dist(lookup_list)
-
-        return index_select_view(
-            output,
-            unbucketize_permute_non_opt.to(device=output.device),
-            self._dim,
-        )
+        # When memory_capacity_per_rank is added then boundary split for the
+        # model is different. Handling device movement accordingly
+        if self._sharding_plan.memory_capacity_per_rank is None:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt,
+                self._dim,
+            )
+        else:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt.to(device=output.device),
+                self._dim,
+            )
 
     # pyre-ignore
     def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
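
To make the new return path easier to follow, here is a standalone sketch of the branch the second hunk adds. The function name `gather_output` is illustrative, `memory_capacity_per_rank` is reduced to a plain parameter, and `index_select_view` is assumed to be a row gather followed by a reshape; the real torchrec helper may differ.

```python
import torch


def gather_output(
    output: torch.Tensor,
    permute: torch.Tensor,
    dim: int,
    memory_capacity_per_rank: object = None,
) -> torch.Tensor:
    """Sketch of the branch added in this commit (names are illustrative).

    When memory_capacity_per_rank is unset (the pre-hetero-sharding setup),
    the permute indices are used as-is, so the traced graph matches the old
    single-path version. Only the hetero-sharding path moves the indices to
    the output's device before gathering.
    """
    if memory_capacity_per_rank is not None:
        permute = permute.to(device=output.device)
    # Assumed behavior of index_select_view: gather rows, then view as
    # (-1, dim).
    return output[permute].view(-1, dim)


# Example: the default path leaves the indices on their original device.
out = torch.arange(12, dtype=torch.float32).view(4, 3)
perm = torch.tensor([2, 0, 3, 1])
print(gather_output(out, perm, dim=3))
```

The design point of the branch is that the `None` case reproduces the pre-hetero-sharding forward exactly, so lowering sees an unchanged graph, while the device move is confined to the case where memory_capacity_per_rank is set.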
