Skip to content

Commit 0dddde8

Browse files
committed
Cast scores to the hidden-state dtype instead of upcasting hidden states to float32; saves ~10G memory
1 parent db455a3 commit 0dddde8

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

torchtitan/distributed/deepep/deepep.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ def dispatch_tokens(
399399
permuted_scores = scores_with_zero[padding_indices]
400400

401401
if score_before_experts and permuted_scores is not None:
402-
hidden_states = (hidden_states.to(torch.float32) * permuted_scores.reshape(-1, 1)).to(hidden_states.dtype)
402+
# Avoid float32 conversion to save memory
403+
hidden_states = hidden_states * permuted_scores.to(hidden_states.dtype).reshape(-1, 1)
403404
permuted_scores_for_state = None
404405
else:
405406
permuted_scores_for_state = permuted_scores
@@ -422,7 +423,8 @@ def combine_tokens(
422423
) -> torch.Tensor:
423424
"""Combine tokens from experts via DeepEP."""
424425
if state.permuted_scores is not None:
425-
hidden_states = (hidden_states.to(torch.float32) * state.permuted_scores.reshape(-1, 1)).to(hidden_states.dtype)
426+
# Multiply in hidden_states' dtype (out-of-place) to avoid a float32 copy and save memory
427+
hidden_states = hidden_states * state.permuted_scores.to(hidden_states.dtype).reshape(-1, 1)
426428

427429
# Remove alignment padding if it was applied
428430
if state.padding_indices is not None:

0 commit comments

Comments
 (0)