File tree Expand file tree Collapse file tree
torchtitan/distributed/deepep Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -399,7 +399,8 @@ def dispatch_tokens(
399399 permuted_scores = scores_with_zero [padding_indices ]
400400
401401 if score_before_experts and permuted_scores is not None :
402- hidden_states = (hidden_states .to (torch .float32 ) * permuted_scores .reshape (- 1 , 1 )).to (hidden_states .dtype )
402+ # Avoid float32 conversion to save memory
403+ hidden_states = hidden_states * permuted_scores .to (hidden_states .dtype ).reshape (- 1 , 1 )
403404 permuted_scores_for_state = None
404405 else :
405406 permuted_scores_for_state = permuted_scores
@@ -422,7 +423,8 @@ def combine_tokens(
422423) -> torch .Tensor :
423424 """Combine tokens from experts via DeepEP."""
424425 if state .permuted_scores is not None :
425- hidden_states = (hidden_states .to (torch .float32 ) * state .permuted_scores .reshape (- 1 , 1 )).to (hidden_states .dtype )
426+ # Multiply in hidden_states' dtype (no float32 upcast) to save memory; note this is out-of-place, not in-place
427+ hidden_states = hidden_states * state .permuted_scores .to (hidden_states .dtype ).reshape (- 1 , 1 )
426428
427429 # Remove alignment padding if it was applied
428430 if state .padding_indices is not None :
You can’t perform that action at this time.
0 commit comments