Conversation
| self._communicate_simple_fn = CommunicateSimpleFn.get_fn( | ||
| input_mode=self.layer_scatter_modes.layer_input_mode, | ||
| output_mode=self.layer_scatter_modes.attn_mode, | ||
| hidden_states_input_mode=self.layer_scatter_modes.attn_mode, |
There was a problem hiding this comment.
wondering whether it should be
| hidden_states_input_mode=self.layer_scatter_modes.attn_mode, | |
| hidden_states_input_mode=self.layer_scatter_modes.layer_input_mode, |
| output_mode=self.layer_scatter_modes.attn_mode, | ||
| hidden_states_input_mode=self.layer_scatter_modes.attn_mode, | ||
| residual_input_mode=self.layer_scatter_modes.layer_input_mode, | ||
| hidden_states_output_mode=self.layer_scatter_modes.mlp_mode, |
There was a problem hiding this comment.
| hidden_states_output_mode=self.layer_scatter_modes.mlp_mode, | |
| hidden_states_output_mode=self.layer_scatter_modes.attn_mode, |
| hidden_states_input_mode=self.layer_scatter_modes.attn_mode, | ||
| residual_input_mode=self.layer_scatter_modes.layer_input_mode, | ||
| hidden_states_output_mode=self.layer_scatter_modes.mlp_mode, | ||
| residual_output_mode=self.layer_scatter_modes.middle_residual_mode, |
There was a problem hiding this comment.
if it is like that, then maybe _communicate_with_all_reduce_and_layer_norm_fn's residual residual mode should be changed
| if ( | ||
| hidden_states_input_mode == ScatterMode.TP_ATTN_FULL | ||
| and residual_input_mode == ScatterMode.SCATTERED | ||
| and hidden_states_output_mode == ScatterMode.TP_ATTN_FULL | ||
| and residual_output_mode == ScatterMode.SCATTERED | ||
| ): |
There was a problem hiding this comment.
hmm, this branch looks like "input === output", then we should do nothing, i.e. trivial
maybe the condition is a bit wrong?
| and residual_output_mode == ScatterMode.TP_ATTN_FULL | ||
| ): | ||
| return CommunicateSimpleFn._scattered_to_tp_attn_full | ||
| return CommunicateSimpleFn._gather_hidden_states_and_residual |
There was a problem hiding this comment.
the if looks like we only gathere residual, so wondering maybe if condition is wrong (or the function name is wrong)
fzyzcjy
left a comment
There was a problem hiding this comment.
(forget to click approve for the review above)
|
@ch-wan maybe you should run "pre-commit run --all-files" to pass the lint tests |
Motivation
Fix one issue when we use DP for dense FFNs and EP (not DeepEP) for sparse FFNs. Related issue: #6297
Modifications
Checklist