I have thoroughly reviewed your code and noticed that in DominoTransformerLayer.forward(), you make extensive use of dist.all_reduce. However, to my knowledge, it does not have a corresponding backward implementation. Could you explain why the results are still correct?
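For reference, below is the kind of wrapper I would have expected around these collectives: a torch.autograd.Function that pairs the forward communication with an explicit backward rule, similar to the copy/reduce regions used in Megatron-style tensor parallelism. The class names here (_CopyToTPRegion, _ReduceFromTPRegion) are my own illustrative names, not code from this repository; this is only a minimal sketch to clarify what I mean by "a corresponding backward implementation".

import torch
import torch.distributed as dist


class _CopyToTPRegion(torch.autograd.Function):
    """Identity in forward; all-reduce of the gradient in backward.

    Typically placed at the input of a column-parallel linear layer so that
    the input-gradient contributions from all tensor-parallel ranks are summed.
    """

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, group=ctx.group)
        return grad_output, None


class _ReduceFromTPRegion(torch.autograd.Function):
    """All-reduce in forward; identity in backward.

    Typically placed at the output of a row-parallel linear layer: the partial
    results are summed in forward, and the incoming gradient is passed through
    unchanged in backward, since every rank needs the same full gradient.
    """

    @staticmethod
    def forward(ctx, input_, group):
        dist.all_reduce(input_, group=group)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

In DominoTransformerLayer.forward(), however, dist.all_reduce is called directly on the activations rather than through such a wrapper, which is why I am unsure how the backward pass stays correct.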
class DominoTransformerLayer(DominoModule):
    """A domino single transformer layer.
    [s, b, h] -> [s, b, h]
    """

    def __init__(self,
                 config,
                 mpu,
                 apply_rotary_pos_emb,
                 layer_number,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.causal,
                 drop_path_rate=0.):
        super(DominoTransformerLayer, self).__init__()

        self.layer_number = layer_number
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm \
            = config.apply_residual_connection_post_layernorm
        self.llama_model = False

        self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = ShardedAttention(config,
                                               mpu,
                                               apply_rotary_pos_emb,
                                               layer_number,
                                               attention_type=AttnType.self_attn,
                                               attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = config.hidden_dropout

        self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2
        self.output_size_c = config.ffn_hidden_size
        self.input_size_c = config.hidden_size
        self.input_size_r = config.ffn_hidden_size
        self.output_size_r = self.input_size_c

        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.TP_group = mpu.get_tensor_model_parallel_group()
        self.output_size_per_partition = self.output_size_c // tp_world_size
        self.input_size_per_partition = self.input_size_r // tp_world_size

        self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c,
                                                          self.output_size_per_partition,
                                                          mpu.get_tensor_model_parallel_group(),
                                                          config=config,
                                                          init_method=config.init_method,
                                                          bias=config.add_bias_linear)

        self.mlp_activation_func = F.gelu

        self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition,
                                                  self.output_size_r,
                                                  config=config,
                                                  init_method=config.output_layer_init_method,
                                                  bias=config.add_bias_linear,
                                                  skip_bias_add=True)

        self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout)
    def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):

        hidden_states0, hidden_states1 = hidden_states

        layernorm_output0 = self.input_layernorm(hidden_states0)
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        # Micro batch 0: attention
        attention_output0, attention_bias0 = self.self_attention(layernorm_output0,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_0,
                                                                 rotary_pos_emb=rotary_pos_emb)
        fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: attention

        # Micro batch 1: attention
        layernorm_output1 = self.input_layernorm(hidden_states1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        attention_output1, attention_bias1 = self.self_attention(layernorm_output1,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_1,
                                                                 rotary_pos_emb=rotary_pos_emb)
        fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True)

        # Micro batch 0: Residual connection.
        fwd_handle0.wait()
        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = hidden_states0

        layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0)

        layernorm_output0 = self.post_attention_layernorm(layernorm_input0)
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = layernorm_input0
        # End of Micro batch 0: Residual connection.

        # ------------ MLP ------------
        # Micro batch 0: MLP
        output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
        output0 = self.mlp_activation_func(output0)

        # Micro batch 1: Residual connection.
        fwd_handle1.wait()
        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = hidden_states1

        layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1)

        layernorm_output1 = self.post_attention_layernorm(layernorm_input1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = layernorm_input1
        # End of Micro batch 1: Residual connection.

        hidden_states0, last_mlp_bias = self.linear_fc2(output0)
        fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: MLP

        # Micro batch 1: MLP
        output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
        output1 = self.mlp_activation_func(output1)

        hidden_states1, last_mlp_bias = self.linear_fc2(output1)
        fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True)
        # End of Micro batch 1: MLP
        # ------------ End of MLP ------------

        fwd_handle0.wait()
        hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0)

        fwd_handle1.wait()
        hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)

        return hidden_states0, hidden_states1