
Question about DeepSpeed-Domino #7654

@GoldenStain


I have reviewed your code thoroughly and noticed that DominoTransformerLayer.forward() makes extensive use of dist.all_reduce. To my knowledge, however, dist.all_reduce has no corresponding backward implementation. Could you explain why the results are still correct?
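
To make the question concrete, here is a tiny sketch (my own code, not taken from Domino) of what I mean: dist.all_reduce mutates its argument in place and is not recorded by autograd, so no "all-reduce backward" node ever appears in the graph and the backward pass treats the collective as an identity. It assumes a single-process gloo group just so the collective can run; on recent PyTorch versions you may also see a warning that no autograd kernel is registered for c10d::allreduce_.

import os

import torch
import torch.distributed as dist


def main():
    # Single-process "world" so the collective is callable; assumes the gloo backend.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, requires_grad=True)
    y = x * 2                 # y.grad_fn is MulBackward0
    dist.all_reduce(y)        # in-place collective; no autograd node is added
    print(y.grad_fn)          # still MulBackward0 -- the reduce left no trace
    y.sum().backward()
    print(x.grad)             # all 2.0: the gradient passed through the reduce unchanged

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

For reference, the layer in question: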

class DominoTransformerLayer(DominoModule):
    """A domino single transformer layer.
    [s, b, h] -> [s, b, h]
    """

    def __init__(self,
                 config,
                 mpu,
                 apply_rotary_pos_emb,
                 layer_number,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.causal,
                 drop_path_rate=0.):

        super(DominoTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = config.apply_residual_connection_post_layernorm

        self.llama_model = False

        self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = ShardedAttention(config,
                                               mpu,
                                               apply_rotary_pos_emb,
                                               layer_number,
                                               attention_type=AttnType.self_attn,
                                               attn_mask_type=self_attn_mask_type)

        self.hidden_dropout = config.hidden_dropout

        self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        self.output_size_c = config.ffn_hidden_size
        self.input_size_c = config.hidden_size
        self.input_size_r = config.ffn_hidden_size
        self.output_size_r = self.input_size_c

        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.TP_group = mpu.get_tensor_model_parallel_group()
        self.output_size_per_partition = self.output_size_c // tp_world_size
        self.input_size_per_partition = self.input_size_r // tp_world_size

        self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c,
                                                          self.output_size_per_partition,
                                                          mpu.get_tensor_model_parallel_group(),
                                                          config=config,
                                                          init_method=config.init_method,
                                                          bias=config.add_bias_linear)

        self.mlp_activation_func = F.gelu

        self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition,
                                                  self.output_size_r,
                                                  config=config,
                                                  init_method=config.output_layer_init_method,
                                                  bias=config.add_bias_linear,
                                                  skip_bias_add=True)

        self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout)

    def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):

        hidden_states0, hidden_states1 = hidden_states

        layernorm_output0 = self.input_layernorm(hidden_states0)
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        # Micro batch 0: attention
        attention_output0, attention_bias0 = self.self_attention(layernorm_output0,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_0,
                                                                 rotary_pos_emb=rotary_pos_emb)

        fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: attention

        # Micro batch 1: attention
        layernorm_output1 = self.input_layernorm(hidden_states1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        attention_output1, attention_bias1 = self.self_attention(layernorm_output1,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_1,
                                                                 rotary_pos_emb=rotary_pos_emb)
        fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True)

        # Micro batch 0: Residual connection.
        fwd_handle0.wait()
        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = hidden_states0

        layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0)

        layernorm_output0 = self.post_attention_layernorm(layernorm_input0)
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = layernorm_input0
        # End of Micro batch 0: Residual connection.

        # ------------ MLP ------------
        # Micro batch 0: MLP
        output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
        output0 = self.mlp_activation_func(output0)

        # Micro batch 1: Residual connection.
        fwd_handle1.wait()
        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = hidden_states1

        layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1)

        layernorm_output1 = self.post_attention_layernorm(layernorm_input1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = layernorm_input1
        # End of Micro batch 1: Residual connection.

        hidden_states0, last_mlp_bias = self.linear_fc2(output0)
        fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: MLP

        # Micro batch 1: MLP
        output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
        output1 = self.mlp_activation_func(output1)

        hidden_states1, last_mlp_bias = self.linear_fc2(output1)

        fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True)
        # End of Micro batch 1: MLP

        # ------------  End of MLP ------------

        fwd_handle0.wait()
        hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0)

        fwd_handle1.wait()
        hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)

        return hidden_states0, hidden_states1
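
For comparison, the pattern I was expecting is the one used in Megatron-style tensor parallelism, where the forward all-reduce is wrapped in a torch.autograd.Function so the backward behaviour is spelled out explicitly. The sketch below is only a generic illustration of that pattern (the name _ReduceFromTPRegion is hypothetical), not code from Domino or DeepSpeed:

import torch
import torch.distributed as dist


class _ReduceFromTPRegion(torch.autograd.Function):
    """All-reduce in forward, identity in backward (hypothetical helper)."""

    @staticmethod
    def forward(ctx, input_, group):
        # Clone so the collective does not mutate the caller's tensor in place.
        output = input_.clone()
        dist.all_reduce(output, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # If y = sum_i x_i across ranks, then dL/dx_i = dL/dy on every rank,
        # so the gradient simply passes through; no collective is needed here.
        return grad_output, None


def reduce_from_tp_region(x, group):
    return _ReduceFromTPRegion.apply(x, group)

In other words, is this kind of explicit wrapper simply not needed for the plain dist.all_reduce calls above, or does the _Wait_bwd_comm / HANDLE_DIC machinery take over that role during the backward pass?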
