From 5382c8a1d08e36ce7babc03677f9fabec6cdd8f7 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sat, 14 Sep 2024 23:41:23 -0500 Subject: [PATCH 01/27] Implemented ColumnParallelLinear with Transformer-Engine --- megatron/model/transformer.py | 61 +++++----- megatron/model/transformer_engine.py | 169 ++++++++++++++++++++++++--- megatron/neox_arguments/neox_args.py | 44 +++---- 3 files changed, 200 insertions(+), 74 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index d112a7461..08e5987de 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -18,8 +18,6 @@ """Transformer.""" import math -from contextlib import nullcontext - import torch import torch.nn.functional as F import torch.nn as nn @@ -50,11 +48,6 @@ from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE -try: - from flash_attn.ops.activations import swiglu -except ImportError: - swiglu = None - # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) @@ -114,6 +107,11 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of + if neox_args.te_linear: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size elif neox_args.expansion_factor: @@ -124,12 +122,7 @@ def __init__( ffn_dim_in = ffn_dim if self.is_gated: # set activation function to be gated implementation - self.activation_func = Gated_Activation( - self.activation_func, - (swiglu is not None) - and (neox_args.activation == "swiglu") - and neox_args.use_flashattn_swiglu, - ) + self.activation_func = Gated_Activation(self.activation_func) # auto scale so gated activations has equal parameters ffn_dim = int(ffn_dim * 2 / 3) ffn_dim_in = ffn_dim // 2 @@ -142,7 +135,7 @@ def __init__( self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.linear1 = mpu.ColumnParallelLinear( + self.linear1 = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=ffn_dim, @@ -170,7 +163,10 @@ def __init__( def forward(self, hidden_states): # [s, b, intermediate_size] intermediate_parallel, bias_parallel = self.linear1(hidden_states) - + # output = self.linear1(hidden_states) + # print(output) + # import sys + # sys.exit() if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel @@ -186,10 +182,9 @@ def forward(self, hidden_states): class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func, use_swiglu=False): + def __init__(self, activation_func): super().__init__() self.activation_func = activation_func - self.use_swiglu = use_swiglu def forward(self, x, bias=None): x, gate = x.chunk(2, dim=-1) @@ -197,11 +192,8 @@ def forward(self, x, bias=None): bias_1, bias_2 = bias.chunk(2, dim=-1) x = x + bias_1 gate = gate + bias_2 - if not self.use_swiglu: - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x - else: - return swiglu(gate, x) + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x class ParallelLinear(nn.Module): @@ -217,10 +209,16 @@ def __init__( is_last_layer=False, ): super().__init__() + + if neox_args.te_linear: + from megatron.model.transformer_engine 
import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" if parallelism == "column": - self.final_linear = mpu.ColumnParallelLinear( + self.final_linear = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.padded_vocab_size, @@ -335,6 +333,11 @@ def __init__( ): super().__init__() + if neox_args.te_linear: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -388,7 +391,7 @@ def __init__( if not self.gqa: # Strided linear layer. - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -398,7 +401,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -1191,7 +1194,7 @@ def forward(self, x, attention_mask, layer_past=None): self.layer_past = presents if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): attention_output = bias_dropout_fn( attention_output, bias=attention_bias.expand_as(attention_output), @@ -1202,7 +1205,7 @@ def forward(self, x, attention_mask, layer_past=None): # mlp operator mlp_output, mlp_bias = self.mlp(x2) if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): output = bias_dropout_fn( mlp_output, bias=mlp_bias.expand_as(mlp_output), @@ -1228,7 +1231,7 @@ def forward(self, x, attention_mask, layer_past=None): if self.use_cache: attention_output, presents = attention_output self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): if attention_bias is not None: # Use special bias_dropout_fn if we have a bias term from the above attention layer attention_output = bias_dropout_fn( @@ -1267,7 +1270,7 @@ def forward(self, x, attention_mask, layer_past=None): else: raise KeyError(self.moe_type) - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): if ( self.activation == "swiglu" or self.num_experts > 1 diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 338513a97..8a2a2d165 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -1,4 +1,23 @@ +import math + import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from megatron.mpu.initialize import get_model_parallel_rank +from megatron.mpu.initialize import get_model_parallel_world_size +from megatron.mpu.initialize import get_tensor_model_parallel_group +from megatron.mpu.mappings import copy_to_model_parallel_region +from megatron.mpu.mappings import gather_from_model_parallel_region +from megatron.mpu.mappings import reduce_from_model_parallel_region +from megatron.mpu.mappings import scatter_to_model_parallel_region +from 
megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region +from megatron.mpu.mappings import gather_from_sequence_parallel_region +from megatron.mpu.random import get_cuda_rng_tracker +from megatron.mpu.utils import divide +from megatron.mpu.utils import VocabUtility +from functools import partial try: import transformer_engine as te @@ -57,14 +76,16 @@ class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ + def __init__(self, in_features, out_features, bias=True): - def __init__(self): - # TODO - return + super(TELinear, self).__init__(in_features,out_features,bias) + + + # self.linear = te.pytorch.Linear(in_features, out_features, bias=use_bias, init_method=weight, **kwargs) - def forward(self, x): - # TODO - return + + # def forward(self, x): + # return self.linear(x) class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): @@ -82,22 +103,138 @@ def forward(self, x): return -class TEColumnParallelLinear(TELinear): +class TEColumnParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `ColumnParallelLinear` layer. """ + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ - def __init__(self): - # TODO - return - - def forward(self, x): - # TODO - return - + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + mup_rescale_parameters=False, + seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. 
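+        # For illustration: output_size=4096 with a tensor-parallel world size of 4
+        # yields a 1024-column shard per rank; with parallel_mode="column" (set below),
+        # TE's Linear performs that split internally across tp_group.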
+ world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + self.parallel_mode="column" + # print("##########################") + # print(self.return_bias) + + super(TEColumnParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, + params_dtype=self.params_dtype) + + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() -class TERowParallelLinear(TELinear): + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=0, + stride=self.stride, + ) + + def forward(self, inp): + output = super(TEColumnParallelLinear, self).forward(inp) + if self.skip_bias_add: + return output + else: + return output, None + +class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `RowParallelLinear` layer. 
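
For context, a minimal sketch of the calling convention these wrappers lean on, assuming transformer_engine is installed and a CUDA device is available, and using only the te.pytorch.Linear arguments already exercised above (the sizes and dtype here are placeholders): with return_bias=True the layer hands the bias back instead of adding it, which mirrors megatron's skip_bias_add behaviour.

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(
        in_features=1024,            # placeholder sizes
        out_features=4096,
        bias=True,
        return_bias=True,            # analogous to megatron's skip_bias_add=True
        params_dtype=torch.bfloat16,
        device="cuda",
    )
    x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
    out, bias = layer(x)             # bias is returned separately for later fusion

In a NeoX config this code path would be selected via the flag added in the neox_args.py hunk that follows, e.g. a hypothetical fragment "te_linear": true.
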
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 5194047d5..dc363ce2c 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -309,11 +309,6 @@ class NeoXArgsModel(NeoXArgsTemplate): Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ - use_flashattn_swiglu: bool = False - """ - Use flash attention's version of swiglu - """ - scaled_upper_triang_masked_softmax_fusion: bool = False """ Enable fusion of query_key_value_scaling time (upper diagonal) masking and softmax. @@ -501,7 +496,16 @@ class NeoXArgsModel(NeoXArgsTemplate): """ Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + + te_linear: bool = False + """ + Use TransformerEngine for Linear, ColumnParallelLinear, and RowParallelLinear layers. + """ + te_attention: bool = False + """ + Use TransformerEngine for attention layers. + """ @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -1052,9 +1056,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" + train_impl: Literal["normal", "dpo", "rm"] = "normal" """ - Training implementation, can be one of "normal", "dpo", "kto", or "rm" + Training implementation, can be one of "normal", "dpo", or "rm" """ dpo_fp32: bool = True @@ -1062,34 +1066,16 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ - dpo_reference_free: bool = False - """ - Whether to use reference-free DPO. - """ - dpo_beta: float = 0.1 """ Beta value for DPO """ - kto_fp32: bool = True - """ - Whether to cast logits to fp32 for KTO loss calculation. - """ - - kto_desirable_weight: float = 1.0 - """ - Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. - """ - - kto_undesirable_weight: float = 1.0 - """ - Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. - """ - - kto_beta: float = 0.1 + z_loss: float = 0.0 """ - Beta value for KTO + Z-loss parameter, only implemented for RM training currently. 
+ https://arxiv.org/pdf/2204.02311 + https://arxiv.org/pdf/2309.10305 """ allow_chopped: bool = True From fa887b7f2449f29bb976ba87d303b83b84b52bb6 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 02:23:58 -0500 Subject: [PATCH 02/27] Implemented RowParallelLinear with Transformer-Engine --- megatron/model/transformer.py | 22 ++-- megatron/model/transformer_engine.py | 159 ++++++++++++++++++++++++--- megatron/model/utils.py | 13 +++ megatron/neox_arguments/neox_args.py | 9 +- 4 files changed, 172 insertions(+), 31 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 08e5987de..360122f33 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -47,6 +47,7 @@ ) from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE +from .utils import linear_implementation_router # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -107,10 +108,7 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size @@ -147,7 +145,7 @@ def __init__( bias=neox_args.use_bias_in_mlp, ) # Project back to h. - self.linear2 = mpu.RowParallelLinear( + self.linear2 = RowParallelLinear( neox_args=neox_args, input_size=ffn_dim_in, output_size=neox_args.hidden_size, @@ -210,10 +208,7 @@ def __init__( ): super().__init__() - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" @@ -247,7 +242,7 @@ def __init__( # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here # ) else: # Not using cross entropy loss for RMs - self.rm_linear = mpu.RowParallelLinear( + self.rm_linear = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=1, @@ -333,10 +328,7 @@ def __init__( ): super().__init__() - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" @@ -509,7 +501,7 @@ def __init__( self.attention_dropout = nn.Dropout(self.dropout_p) # Output. - self.dense = mpu.RowParallelLinear( + self.dense = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size, diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 8a2a2d165..5de2c3459 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -107,11 +107,6 @@ class TEColumnParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `ColumnParallelLinear` layer. 
- """ - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. Arguments: input_size: first dimension of matrix A. @@ -145,7 +140,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, - seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. + seq_dim=0, ): # Keep input parameters self.input_size = input_size @@ -186,6 +181,12 @@ def width_mult(self): ) return self.weight.infshape.width_mult() + def set_parallel_output(self, value: bool): + assert isinstance(value, bool) + self.gather_output = ( + not value + ) # if gather_output is True, parallel output is False, so we set the opposite + # Copied from Mup def _rescale_parameters(self): """Rescale parameters to convert SP initialization to μP initialization. @@ -227,8 +228,21 @@ def mup_reinitialize_weights(self, neox_args): stride=self.stride, ) - def forward(self, inp): - output = super(TEColumnParallelLinear, self).forward(inp) + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) + + if self.gather_output: + # All-gather across the partitions. + assert ( + not self.sequence_parallel + ), "sequence_parallel=True and gather_output=True are incompatible!" + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + if self.skip_bias_add: return output else: @@ -238,15 +252,132 @@ class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `RowParallelLinear` layer. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. """ + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + parallel_output=False, + mup_rescale_parameters=False, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + # Divide the weight matrix along the last dimension. 
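+        # Note: with parallel_mode="row" (set below), TE's Linear shards in_features
+        # across the tensor-parallel group and reduces the partial outputs itself,
+        # so this wrapper does no explicit scatter/gather of activations.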
+ world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + self.input_is_parallel = input_is_parallel + self.sequence_parallel = neox_args.sequence_parallel - def __init__(self): - # TODO - return + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + self.parallel_mode="row" + + # if self.input_is_parallel: + # self.input_size = divide(self.input_size, self.world_size) - def forward(self, x): - # TODO - return + super(TERowParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, + params_dtype=self.params_dtype) + + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() + + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." 
+ ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=self.keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=1, + stride=self.stride, + ) + + def set_parallel_output(self, parallel_output: bool): + assert isinstance(parallel_output, bool) + self.parallel_output = parallel_output + + def forward(self, inp, **kwargs): + # if not self.input_is_parallel: + # inp = scatter_to_model_parallel_region(inp) + + output = super(TERowParallelLinear, self).forward(inp, **kwargs) + if self.skip_bias_add: + return output + else: + return output, None class TEDotProductAttention(te.pytorch.DotProductAttention): diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 8176f1f7a..d1ec2a347 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -402,3 +402,16 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): for name, param in module_.named_parameters(): if param.requires_grad: param.register_hook(reduce_weight_grads_from_model_parallel_region) + + +def linear_implementation_router(neox_args): + if neox_args.te_columnparallel: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + if neox_args.te_rowparallel: + from megatron.model.transformer_engine import TERowParallelLinear as RowParallelLinear + else: + from megatron.mpu import RowParallelLinear + + return ColumnParallelLinear, RowParallelLinear \ No newline at end of file diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index dc363ce2c..9bf86ccd6 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -497,9 +497,14 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ - te_linear: bool = False + te_columnparallel: bool = False """ - Use TransformerEngine for Linear, ColumnParallelLinear, and RowParallelLinear layers. + Use TransformerEngine for RowParallelLinear layer. + """ + + te_rowparallel: bool = False + """ + Use TransformerEngine for ColumnParallelLinear layer. 
""" te_attention: bool = False From 0a6f1406c0137ce9b86f4d29f900c98632775c3c Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 16:25:01 -0500 Subject: [PATCH 03/27] Implemented LayerNormMLP with Transformer-Engine --- megatron/model/transformer.py | 29 +++++++-- megatron/model/transformer_engine.py | 97 ++++++++++++++++++++++------ megatron/neox_arguments/neox_args.py | 4 +- 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 360122f33..5003ef1d5 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -132,7 +132,6 @@ def __init__( ffn_dim_in = int( self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.linear1 = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -383,7 +382,7 @@ def __init__( if not self.gqa: # Strided linear layer. - self.query_key_value = ColumnParallelLinear( + self.query_key_value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -393,7 +392,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = ColumnParallelLinear( + self.query_key_value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -1045,6 +1044,17 @@ def get_mlp(**kw): **kw, ) + def get_te_lnmlp(**kw): + from megatron.model.transformer_engine import TELayerNormMLP + return TELayerNormMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + self.num_experts = ( neox_args.moe_num_experts if layer_number % neox_args.expert_interval == 0 @@ -1052,7 +1062,10 @@ def get_mlp(**kw): ) args = neox_args if self.num_experts <= 1: - self.mlp = get_mlp() + if neox_args.te_layernorm_mlp: + self.mlp = get_te_lnmlp() + else: + self.mlp = get_mlp() else: from torch import distributed as dist @@ -1171,9 +1184,15 @@ def forward(self, x, attention_mask, layer_past=None): residual = x # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied: + if self.gpt_j_tied and not neox_args.te_layernorm_mlp: x = self.input_layernorm(x) x1, x2 = x, x + elif self.gpt_j_tied and neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x else: x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 5de2c3459..7c69cee1f 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -5,6 +5,8 @@ import torch.nn.init as init from torch.nn.parameter import Parameter +from megatron.model.transformer import Gated_Activation +from megatron.model.activations import get_activation from megatron.mpu.initialize import get_model_parallel_rank from megatron.mpu.initialize import get_model_parallel_world_size from megatron.mpu.initialize import get_tensor_model_parallel_group @@ -88,19 +90,84 @@ def __init__(self, in_features, out_features, bias=True): # return self.linear(x) -class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): +class TELayerNormMLP(te.pytorch.LayerNormMLP): """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines - 
layernorm and linear layers + Wrapper for the Transformer-Engine's `LayerNormMLP` layer that combines + layernorm and followed by the MLP module, consisting of 2 successive + linear transformations, separated by the GeLU activation. """ - def __init__(self): - # TODO - return + def __init__( + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + multiple_of=256, + MOE=False, + MoE_mp_size=1, + bias=True + ): + self.activation_func, self.is_gated = get_activation(neox_args) + self.activation_type = neox_args.activation + self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of + self.bias = bias + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method - def forward(self, x): - # TODO - return + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype=neox_args.params_dtype + self.set_parallel_mode=False + if world_size > 1: + self.set_parallel_mode=True + + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation(self.activation_func) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) + ) + + if neox_args.norm in ['layernorm','te_layernorm']: + self.eps=1.0e-5 + self.normalization = 'LayerNorm' + elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + self.eps=1.0e-8 + self.normalization = 'RMSNorm' + #TODO handle case if norm is not rmsnorm or layernorm + #TODO check if activation in list ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, + #‘swiglu’, ‘qgelu’, ‘srelu’ + #TODO handle MOE and mup + + super(TELayerNormMLP, self).__init__(hidden_size=neox_args.hidden_size, ffn_hidden_size=ffn_dim, + eps=self.eps, bias=self.bias, normalization=self.normalization, activation=neox_args.activation, + init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, + device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, + return_bias=neox_args.use_bias_in_mlp, params_dtype=self.params_dtype, seq_length=self.seq_len, + micro_batch_size=self.batch_size) class TEColumnParallelLinear(te.pytorch.Linear): @@ -234,15 +301,6 @@ def forward(self, inp, **kwargs): output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) - if self.gather_output: - # All-gather across the partitions. - assert ( - not self.sequence_parallel - ), "sequence_parallel=True and gather_output=True are incompatible!" 
- output = gather_from_model_parallel_region(output_parallel) - else: - output = output_parallel - if self.skip_bias_add: return output else: @@ -305,9 +363,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype self.parallel_mode="row" - - # if self.input_is_parallel: - # self.input_size = divide(self.input_size, self.world_size) super(TERowParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 9bf86ccd6..0973cb00c 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -507,9 +507,9 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for ColumnParallelLinear layer. """ - te_attention: bool = False + te_layernorm_mlp: bool = False """ - Use TransformerEngine for attention layers. + Use TransformerEngine for LayerNormMLP layer. """ @dataclass From 5cba717a182c2702e6d1ee1c10746281d759e02c Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:10:27 -0500 Subject: [PATCH 04/27] Implemented MultiheadAttention with Transformer-Engine --- megatron/model/transformer.py | 132 +++++++++++++++++++++++++-- megatron/model/transformer_engine.py | 117 ++++++++++++++++++++---- megatron/neox_arguments/neox_args.py | 5 + 3 files changed, 229 insertions(+), 25 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 5003ef1d5..099a0a899 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -327,8 +327,6 @@ def __init__( ): super().__init__() - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -748,6 +746,106 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe ) + def te_attention( + self, query_layer, key_layer, value_layer, layer_past, attention_mask + ): + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # preallocating result tensor: [b * np, sq, sk] + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device(), + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if self.use_cache: + with torch.no_grad(): + attention_mask = attention_mask[ + ..., : attention_scores.size(3), : attention_scores.size(3) + ] + + # =========================== + # Attention probs and dropout + # =========================== + + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + attention_scores += rpe # [1, np, sq, sk] + + if self.pos_emb == "alibi": + attention_scores = self.alibi_embed(attention_scores) + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + return context_layer + def gqa_project(self, hidden_states, attention_mask, layer_past=None): # QKV projection and separation into separate Q/K/V layers for GQA, # where KV projections may be smaller than Q projection. @@ -1016,7 +1114,9 @@ def __init__( ) # Self attention. - self.attention = ParallelSelfAttention( + if neox_args.te_mha: + from megatron.model.transformer_engine import TEMultiheadAttention + self.attention = TEMultiheadAttention( neox_args=neox_args, attention_mask_func=attention_mask_func, init_method=init_method, @@ -1026,7 +1126,20 @@ def __init__( use_cache=self.use_cache, rotary=rotary, parallel_output=self.gpt_j_residual, - ) + ) + + else: + self.attention = ParallelSelfAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) # Layernorm on the output of the attention layer. 
# If GPT-J residuals are used, this is surpurfulous but leaving it in @@ -1184,14 +1297,14 @@ def forward(self, x, attention_mask, layer_past=None): residual = x # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not neox_args.te_layernorm_mlp: + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: x = self.input_layernorm(x) x1, x2 = x, x - elif self.gpt_j_tied and neox_args.te_layernorm_mlp: + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: x2 = x x = self.input_layernorm(x) x1 = x - elif neox_args.te_layernorm_mlp: + elif self.neox_args.te_layernorm_mlp: x1, x2 = self.input_layernorm(x), x else: x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) @@ -1263,7 +1376,10 @@ def forward(self, x, attention_mask, layer_past=None): ) # output = x + mlp(ln2(x)) - layernorm_output = self.post_attention_layernorm(attention_output) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output + else: + layernorm_output = self.post_attention_layernorm(attention_output) mlp_bias = torch.tensor( 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype ) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 7c69cee1f..9a8c0a506 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -78,16 +78,51 @@ class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ - def __init__(self, in_features, out_features, bias=True): + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + init_method=init.xavier_normal_, + stride=1, + skip_bias_add=False, + mup_rescale_parameters=False, + seq_dim=0, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size - super(TELinear, self).__init__(in_features,out_features,bias) - + self.skip_bias_add = skip_bias_add + self.use_bias = bias - # self.linear = te.pytorch.Linear(in_features, out_features, bias=use_bias, init_method=weight, **kwargs) + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + # print("##########################") + # print(self.return_bias) + + super(TELinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), return_bias=self.skip_bias_add, params_dtype=self.params_dtype) + + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TELinear, self).forward(inp, **kwargs) - # def forward(self, x): - # return self.linear(x) + if self.skip_bias_add: + return output + else: + return output, None class TELayerNormMLP(te.pytorch.LayerNormMLP): @@ -121,7 +156,7 @@ def __init__( self.tp_group = get_tensor_model_parallel_group() self.sequence_parallel = neox_args.sequence_parallel self.seq_len = neox_args.seq_length - self.batch_size = neox_args.train_micro_batch_size_per_gpu + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu self.params_dtype=neox_args.params_dtype self.set_parallel_mode=False if world_size > 1: @@ -166,8 +201,8 @@ def __init__( init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, 
device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, - return_bias=neox_args.use_bias_in_mlp, params_dtype=self.params_dtype, seq_length=self.seq_len, - micro_batch_size=self.batch_size) + return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, + micro_batch_size=self.micro_batch_size) class TEColumnParallelLinear(te.pytorch.Linear): @@ -435,19 +470,67 @@ def forward(self, inp, **kwargs): return output, None -class TEDotProductAttention(te.pytorch.DotProductAttention): +class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ - Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also has "flash attention" enabled. """ - def __init__(self): - # TODO - return + def __init__(self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False): - def forward(self, x): - # TODO - return + self.attention_mask_func = attention_mask_func + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method + self.layer_number = layer_number + 1 + + world_size = get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype=neox_args.params_dtype + self.set_parallel_mode=False + if world_size > 1: + self.set_parallel_mode=True + + if neox_args.norm in ['layernorm','te_layernorm']: + self.eps=1.0e-5 + self.normalization = 'LayerNorm' + elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + self.eps=1.0e-8 + self.normalization = 'RMSNorm' + + if ( + not neox_args.num_kv_heads + or neox_args.num_kv_heads == neox_args.num_attention_heads + ): + self.gqa = False + self.num_kv_heads = None + else: + self.gqa = True + self.num_kv_heads = neox_args.num_kv_heads + + super(TEMultiheadAttention, self).__init__(hidden_size=neox_args.hidden_size, num_attention_heads=neox_args.num_attention_heads, + attention_dropout=neox_args.attention_dropout, layernorm_epsilon=self.eps, init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, layer_number=self.layer_number, + window_size=neox_args.sliding_window_width, num_gqa_groups=self.num_kv_heads, input_layernorm=False, + normalization=self.normalization, bias=True, device=torch.cuda.current_device(), + set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True) + + def forward(self, hidden_states, attention_mask, layer_past=None, **kwargs): + output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, **kwargs) + return output class TEDelayedScaling(te.common.recipe.DelayedScaling): diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 0973cb00c..7ba6d1000 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -512,6 +512,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for LayerNormMLP layer. """ + te_mha: bool = False + """ + Use TransformerEngine for MultiheadAttention layer. 
+ """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): """ From 94e552cc5a1e8fb0646aa33c3fa4183510157943 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:23:11 -0500 Subject: [PATCH 05/27] Cleaned up transformer.py --- megatron/model/transformer.py | 137 +++++++--------------------------- 1 file changed, 25 insertions(+), 112 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 099a0a899..62f316f3e 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -18,6 +18,8 @@ """Transformer.""" import math +from contextlib import nullcontext + import torch import torch.nn.functional as F import torch.nn as nn @@ -47,6 +49,12 @@ ) from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + from .utils import linear_implementation_router # flags required to enable jit fusion kernels @@ -120,7 +128,12 @@ def __init__( ffn_dim_in = ffn_dim if self.is_gated: # set activation function to be gated implementation - self.activation_func = Gated_Activation(self.activation_func) + self.activation_func = Gated_Activation( + self.activation_func, + (swiglu is not None) + and (neox_args.activation == "swiglu") + and neox_args.use_flashattn_swiglu, + ) # auto scale so gated activations has equal parameters ffn_dim = int(ffn_dim * 2 / 3) ffn_dim_in = ffn_dim // 2 @@ -160,10 +173,6 @@ def __init__( def forward(self, hidden_states): # [s, b, intermediate_size] intermediate_parallel, bias_parallel = self.linear1(hidden_states) - # output = self.linear1(hidden_states) - # print(output) - # import sys - # sys.exit() if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel @@ -179,9 +188,10 @@ def forward(self, hidden_states): class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func): + def __init__(self, activation_func, use_swiglu=False): super().__init__() self.activation_func = activation_func + self.use_swiglu = use_swiglu def forward(self, x, bias=None): x, gate = x.chunk(2, dim=-1) @@ -189,8 +199,11 @@ def forward(self, x, bias=None): bias_1, bias_2 = bias.chunk(2, dim=-1) x = x + bias_1 gate = gate + bias_2 - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x + if not self.use_swiglu: + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x + else: + return swiglu(gate, x) class ParallelLinear(nn.Module): @@ -746,106 +759,6 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe ) - def te_attention( - self, query_layer, key_layer, value_layer, layer_past, attention_mask - ): - # =================================== - # Raw attention scores. 
[b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - output_size[2], output_size[0] * output_size[1], -1 - ) - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocating result tensor: [b * np, sq, sk] - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # ================================================== - # Update attention mask for inference. [b, np, sq, sk] - # ================================================== - - if self.use_cache: - with torch.no_grad(): - attention_mask = attention_mask[ - ..., : attention_scores.size(3), : attention_scores.size(3) - ] - - # =========================== - # Attention probs and dropout - # =========================== - - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - attention_scores += rpe # [1, np, sq, sk] - - if self.pos_emb == "alibi": - attention_scores = self.alibi_embed(attention_scores) - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - with mpu.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - return context_layer - def gqa_project(self, hidden_states, attention_mask, layer_past=None): # QKV projection and separation into separate Q/K/V layers for GQA, # where KV projections may be smaller than Q projection. 
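
For reference, the computation that the attention helper removed above implemented, and that TE's fused MultiheadAttention provides, is plain scaled dot-product attention. A minimal sketch (not code from this patch) that omits dropout, ALiBi, and relative-position terms:

    import math
    import torch

    def reference_attention(q, k, v, causal_mask):
        # q, k, v: [batch, heads, seq, head_dim]; causal_mask: True where masked out.
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        scores = scores.masked_fill(causal_mask, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=-1) @ v
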
@@ -1318,7 +1231,7 @@ def forward(self, x, attention_mask, layer_past=None): self.layer_past = presents if attention_bias is not None: - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): attention_output = bias_dropout_fn( attention_output, bias=attention_bias.expand_as(attention_output), @@ -1329,7 +1242,7 @@ def forward(self, x, attention_mask, layer_past=None): # mlp operator mlp_output, mlp_bias = self.mlp(x2) if mlp_bias is not None: - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): output = bias_dropout_fn( mlp_output, bias=mlp_bias.expand_as(mlp_output), @@ -1355,7 +1268,7 @@ def forward(self, x, attention_mask, layer_past=None): if self.use_cache: attention_output, presents = attention_output self.layer_past = presents - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): if attention_bias is not None: # Use special bias_dropout_fn if we have a bias term from the above attention layer attention_output = bias_dropout_fn( @@ -1397,7 +1310,7 @@ def forward(self, x, attention_mask, layer_past=None): else: raise KeyError(self.moe_type) - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): if ( self.activation == "swiglu" or self.num_experts > 1 From 40e10191c258081210664aa2cc2ae31f1abf3c56 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:25:20 -0500 Subject: [PATCH 06/27] Cleaned up neox_args --- megatron/neox_arguments/neox_args.py | 35 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 7ba6d1000..01d467791 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -309,6 +309,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ + use_flashattn_swiglu: bool = False + """ + Use flash attention's version of swiglu + """ + scaled_upper_triang_masked_softmax_fusion: bool = False """ Enable fusion of query_key_value_scaling time (upper diagonal) masking and softmax. @@ -1066,9 +1071,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm"] = "normal" + train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" """ - Training implementation, can be one of "normal", "dpo", or "rm" + Training implementation, can be one of "normal", "dpo", "kto", or "rm" """ dpo_fp32: bool = True @@ -1076,16 +1081,36 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ + dpo_reference_free: bool = False + """ + Whether to use reference-free DPO. + """ + dpo_beta: float = 0.1 """ Beta value for DPO """ + kto_fp32: bool = True + """ + Whether to cast logits to fp32 for KTO loss calculation. + """ + + kto_desirable_weight: float = 1.0 + """ + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + + kto_undesirable_weight: float = 1.0 + """ + Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + + kto_beta: float = 0.1 + z_loss: float = 0.0 """ - Z-loss parameter, only implemented for RM training currently. 
- https://arxiv.org/pdf/2204.02311 - https://arxiv.org/pdf/2309.10305 + Beta value for KTO """ allow_chopped: bool = True From 885e72c3995997688ce245bd0da84cd0b505941a Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:26:04 -0500 Subject: [PATCH 07/27] Cleaned up neox_args --- megatron/neox_arguments/neox_args.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 01d467791..76d80778d 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1107,8 +1107,6 @@ class NeoXArgsTraining(NeoXArgsTemplate): """ kto_beta: float = 0.1 - - z_loss: float = 0.0 """ Beta value for KTO """ From fe8f22a0b6db0ea33320e5b1034221f7fac5274d Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Thu, 26 Sep 2024 08:18:51 +0000 Subject: [PATCH 08/27] - Fixed TE_MHA and added rope support - Implemented delayed scaling --- megatron/model/positional_embeddings.py | 1578 ++++++++++++++++++++--- megatron/model/transformer.py | 260 ++-- megatron/model/transformer_engine.py | 127 +- megatron/neox_arguments/neox_args.py | 40 + 4 files changed, 1638 insertions(+), 367 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fcded9e96..fdf384a4f 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -1,4 +1,7 @@ -# Copyright (c) 2024, EleutherAI +# Copyright (c) 2024 EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,241 +15,1406 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +"""Transformer.""" + import math +from contextlib import nullcontext + +import torch +import torch.nn.functional as F +import torch.nn as nn +from pkg_resources import packaging +from importlib.metadata import version +from .norms import get_norm +from megatron import mpu +from megatron.model import megablocks_utils +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.activations import get_activation +from megatron.model.utils import exists, get_fusion_type +from megatron.model.positional_embeddings import ( + RotaryEmbedding, + apply_rotary_pos_emb_torch, + apply_rotary_pos_emb, + AliBi, +) +from megatron.model.fused_rope import ( + FusedRoPEFunc, + fused_apply_rotary_pos_emb_cached, +) +from megatron.model.fused_bias_dropout import ( + get_bias_dropout_add, + bias_dropout_add_fused_train, + bias_dropout_add_fused_inference, +) +from megatron.model.utils import configure_sparse_attention +from deepspeed.moe.layer import MoE + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + +from .utils import linear_implementation_router + +# flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + kv: number of key or value heads + p: number of model parallel partitions + np: n/p + kvp: kv/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters + attention_mask_func: a function that takes `unmasked-attention-scores` + with size [b, np, s, s] and an `attention-mask` and will apply + the masking. The function should return a masked score of the + same size [b, np, s, s]. + masked-attention-scores = attention_mask_func( + unmasked-attention-scores, attention-mask) +""" + + +class ParallelMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. 
+ """ -class SinusoidalPositionalEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.precision = precision - - def forward(self, x, seq_dim=1): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq) - if self.precision == torch.bfloat16: - sinusoid_inp = sinusoid_inp.float() - sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() - if self.precision == torch.bfloat16: - sin, cos = sin.bfloat16(), cos.bfloat16() - emb = torch.cat((sin, cos), dim=-1) - return emb[None, :, :] - - -class RotaryEmbedding(torch.nn.Module): def __init__( - self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + multiple_of=256, + MOE=False, + MoE_mp_size=1, ): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - self.max_seq_len = max_seq_len - self.base = base - self.dim = dim + assert ( + neox_args.intermediate_size == None or neox_args.expansion_factor == None + ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" + + self.activation_func, self.is_gated = get_activation(neox_args) + self.activation_type = neox_args.activation + self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of - # precompute cos_cached, sin_cached in fp32 - cos_cached, sin_cached, inv_freq = self._prepare_cache( - max_seq_len, precision, base + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation( + self.activation_func, + (swiglu is not None) + and (neox_args.activation == "swiglu") + and neox_args.use_flashattn_swiglu, + ) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) + ) + self.linear1 = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=ffn_dim, + gather_output=False, + init_method=init_method, + skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, + ) + # Project back to h. 
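
To make the gated ffn sizing above concrete before the down-projection defined below, a small worked example with assumed values (hidden_size 4096, multiple_of 256, a gated activation such as swiglu; none of these numbers come from the patch):

    # Assumed illustration only; these values are not taken from the patch.
    hidden_size, multiple_of = 4096, 256
    ffn_dim = 4 * hidden_size                      # 16384, the default expansion
    ffn_dim = int(ffn_dim * 2 / 3)                 # 10922, rescaled so the extra gate adds no parameters
    ffn_dim_in = ffn_dim // 2                      # 5461, width seen by the down-projection
    ffn_dim = (2 * multiple_of) * ((ffn_dim + 2 * multiple_of - 1) // (2 * multiple_of))   # 11264
    ffn_dim_in = multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of)             # 5632
    assert ffn_dim == 2 * ffn_dim_in
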
+ self.linear2 = RowParallelLinear( + neox_args=neox_args, + input_size=ffn_dim_in, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + parallel_output=parallel_output, + skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, ) - self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) - self.cos_cached = cos_cached - self.sin_cached = sin_cached + def forward(self, hidden_states): + # [s, b, intermediate_size] + intermediate_parallel, bias_parallel = self.linear1(hidden_states) + if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): + intermediate_parallel = self.activation_func( + intermediate_parallel, bias_parallel + ) + else: + intermediate_parallel = self.activation_func( + intermediate_parallel + bias_parallel + ) - def _prepare_cache(self, seq_len, precision, base): - # precompute cos_cached, sin_cached in fp32 - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + # [s, b, h] + output, output_bias = self.linear2(intermediate_parallel) + return output, output_bias - t = torch.arange(seq_len).type_as(inv_freq) - freqs = torch.einsum("i,j->ij", t, inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - cos_cached = emb.cos()[:, None, None, :] - sin_cached = emb.sin()[:, None, None, :] +class Gated_Activation(torch.nn.Module): + def __init__(self, activation_func, use_swiglu=False): + super().__init__() + self.activation_func = activation_func + self.use_swiglu = use_swiglu - return ( - cos_cached.to(precision), - sin_cached.to(precision), - inv_freq.to(precision), - ) + def forward(self, x, bias=None): + x, gate = x.chunk(2, dim=-1) + if bias is not None: + bias_1, bias_2 = bias.chunk(2, dim=-1) + x = x + bias_1 + gate = gate + bias_2 + if not self.use_swiglu: + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x + else: + return swiglu(gate, x) + + +class ParallelLinear(nn.Module): + """ + A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size + """ - def forward(self, x, seq_dim=0, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] + def __init__( + self, + neox_args, + parallel_output=True, + init_method=nn.init.xavier_normal_, + is_last_layer=False, + ): + super().__init__() - assert seq_len <= self.max_seq_len + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - if seq_len != self.max_seq_len: - # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base) - return ( - self.cos_cached[:seq_len, ...].to(x.device), - self.sin_cached[:seq_len, ...].to(x.device), + self.is_rm = neox_args.train_impl == "rm" + parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" + if parallelism == "column": + self.final_linear = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.padded_vocab_size, + bias=False, + init_method=init_method, + gather_output=not parallel_output, + skip_bias_add=False, + mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here + seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. 
Parallel comms must gather along dim=1 rather than dim=0 ) else: - return self.cos_cached.to(x.device), self.sin_cached.to(x.device) + if not self.is_rm: + print( + 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' + ) + exit() + # self.final_linear = mpu.RowParallelLinear( + # neox_args=neox_args, + # input_size=neox_args.hidden_size, + # output_size=neox_args.padded_vocab_size, + # bias=False, + # input_is_parallel=False, + # init_method=init_method, + # parallel_output=parallel_output, + # skip_bias_add=False, + # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here + # ) + else: # Not using cross entropy loss for RMs + self.rm_linear = RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=1, + bias=False, + input_is_parallel=False, + init_method=init_method, + parallel_output=False, + skip_bias_add=False, + mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here + ) + def forward(self, hidden_states): + if not self.is_rm: + return self.final_linear(hidden_states) + else: + return self.rm_linear(hidden_states) + + +class _MegablocksAdapter(nn.Module): + def __init__( + self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group + ): + super().__init__() + megablocks_utils.assert_megablocks_is_available() + args = megablocks_utils.as_megablocks_args(neox_args) + args.device = torch.cuda.current_device() + args.init_method = init_method + args.output_layer_init_method = output_layer_init_method + + # NOTE: Shard the MoE layers over the data parallel group. Expert + # parallel sharding and data parallel sharding could be decoupled + # by extending the optimizer to handle data parallel reductions for + # MoE and non-MoE parameters separately. + if args.moe_expert_model_parallelism: + args.expert_parallel_group = ep_group + + self.moe = layer_cls(args) -# rotary pos emb helpers: + def forward(self, x): + return self.moe.forward(x) -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat( - (-x2, x1), dim=x1.ndim - 1 - ) # dim=-1 triggers a bug in earlier torch versions +class MbMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.moe.MoE, + init_method, + output_layer_init_method, + ep_group, + ) -@torch.jit.script -def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): - cos, sin = ( - cos[offset : q.shape[0] + offset, ...], - sin[offset : q.shape[0] + offset, ...], - ) - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) +class dMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.dmoe.dMoE, + init_method, + output_layer_init_method, + ep_group, + ) -def apply_rotary_pos_emb_torch( - q, k, cos, sin, offset: int = 0 -): # jitting fails with bf16 - cos, sin = ( - cos[offset : q.shape[0] + offset, ...], - sin[offset : q.shape[0] + offset, ...], - ) - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) +class ParallelSelfAttention(nn.Module): + """Parallel self-attention layer abstract class. 
+ Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + """ -class AliBi(torch.nn.Module): - def __init__(self, num_heads, mp_size=1, mp_rank=1): + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False, + ): super().__init__() - # megatron splits across heads, so we need to make sure each - # head receives the correct matrix - assert mp_size <= num_heads and mp_rank <= mp_size - self.mp_size = mp_size - self.mp_rank = mp_rank - self.num_heads = num_heads - self.slice_size = num_heads // mp_size - self.cached_matrix = None - self.cached_seq_len = None - slopes = torch.Tensor(self._get_slopes(num_heads))[ - mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size - ] - self.register_buffer("slopes", slopes) - - def _get_slopes(self, n): - """ - Get slopes for Alibi positional embedding - n : int = number of heads. - For best performance, restrict n to a power of 2. - """ - - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + self._get_slopes(2 * closest_power_of_2)[0::2][ - : n - closest_power_of_2 - ] + + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + + self.fp16 = neox_args.precision == "fp16" + self.bf16 = neox_args.precision == "bfloat16" + self.attention_mask_func = attention_mask_func + self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling + self.use_cache = use_cache + self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = layer_number + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + self.hidden_size_per_attention_head = mpu.divide( + neox_args.hidden_size, neox_args.num_attention_heads + ) + self.num_attention_heads_per_partition = mpu.divide( + neox_args.num_attention_heads, world_size + ) + self.pos_emb = neox_args.pos_emb + + self.use_qk_layernorm = neox_args.use_qk_layernorm + if self.use_qk_layernorm: + norm, eps = get_norm(neox_args) + self.qk_layernorm = norm( + [ + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ], + eps=eps, + ) + + self.sliding_window_width = neox_args.sliding_window_width + + if ( + not neox_args.num_kv_heads + or neox_args.num_kv_heads == neox_args.num_attention_heads + ): + self.gqa = False + else: + self.gqa = True + if self.gqa: + self.num_kv_heads_per_partition = mpu.divide( + neox_args.num_kv_heads, world_size + ) # we do not yet clone KV heads in MQA across TP ranks... + self.kv_hidden_size = ( + neox_args.num_kv_heads * self.hidden_size_per_attention_head + ) # how large the total hidden dim for each of K and V is + else: + self.num_kv_heads_per_partition = self.num_attention_heads_per_partition + self.kv_hidden_size = neox_args.hidden_size + + if not self.gqa: + # Strided linear layer. 
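
Before the QKV projections defined below, a worked example of the GQA sizing above, with assumed values (32 query heads, 8 KV heads, hidden size 4096, tensor-parallel world size 4; illustrative only):

    # Assumed illustration only; numbers are not taken from the patch.
    hidden_size, num_attention_heads, num_kv_heads, world_size = 4096, 32, 8, 4
    head_dim = hidden_size // num_attention_heads            # 128
    num_kv_heads_per_partition = num_kv_heads // world_size  # 2 KV heads per tensor-parallel rank
    kv_hidden_size = num_kv_heads * head_dim                 # 1024 total hidden width for K (and for V)
    qkv_out = hidden_size + 2 * kv_hidden_size               # 6144 for GQA vs. 3 * 4096 = 12288 for MHA
    assert qkv_out == 6144
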
+ self.query_key_value = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=3 * neox_args.hidden_size, + gather_output=False, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + ) + else: + # QKV proj is smaller if we are using GQA / MQA + self.query_key_value = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, + gather_output=False, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + ) + + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = max(1, self.layer_number) + self.norm_factor *= coeff + + if neox_args.use_mup: + self.norm_factor = self.hidden_size_per_attention_head + + self.rpe = rpe + + if self.pos_emb == "alibi": + self.alibi_embed = AliBi( + neox_args.num_attention_heads, + neox_args.model_parallel_size, + mpu.get_model_parallel_rank(), + ) + + # TODO: this arg shouldn't need to be passed in - get from neox_args + if rotary: + if neox_args.rotary_pct == 1: + self.rotary_ndims = None + else: + assert neox_args.rotary_pct < 1 + self.rotary_ndims = int( + self.hidden_size_per_attention_head * neox_args.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else self.hidden_size_per_attention_head + ) + self.rotary_emb = RotaryEmbedding( + dim, + base=neox_args.rotary_emb_base, + max_seq_len=neox_args.seq_length, + precision=neox_args.params_dtype, + save_inv_freqs=neox_args.rotary_save_freqs_buffer, + ) + else: + self.rotary_emb = None + + self.rope_fusion = neox_args.rope_fusion + self.attention_type = neox_args.attention_config[layer_number] + self.use_flash_attention = self.attention_type == "flash" + self.use_triton = ( + self.use_flash_attention + and self.pos_emb == "alibi" + and ( + not packaging.version.Version(version("flash-attn")) + >= packaging.version.Version("2.4.0.post1") + ) + ) + self.sparse = self.attention_type not in ("global", "flash") + + if self.gqa: + assert not self.sparse + + if self.sparse: + self.sparse_attn = configure_sparse_attention( + neox_args, + self.attention_type, + self.num_attention_heads_per_partition, + mpu=mpu, ) + else: + if self.use_flash_attention: + # we now use Flash Attention 2's provided interface. + # TODO: we no longer need to use flash_triton_fn since flash cuda supports alibi. + # consider adding OpenAI's more recent Flash-2 Triton kernel in future + # from https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py + from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_func, + ) + from flash_attn.flash_attn_triton import ( + flash_attn_func as flash_attn_unpadded_unpacked_func_triton, + ) + + self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton + self.flash_qkv_fn = flash_attn_func + self.flash_varlen_qkv_fn = flash_attn_varlen_func + else: + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.fp16, + input_in_bf16=self.bf16, + fusion_type=get_fusion_type(neox_args), + mask_func=self.attention_mask_func, + softmax_in_fp32=self.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. 
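
A quick illustration of the partial-rotary arithmetic in the rotary setup above, with assumed values (head dim 128, rotary_pct 0.25; not taken from the patch):

    # Assumed illustration only; rotary_pct and the head dim are not taken from the patch.
    hidden_size_per_attention_head, rotary_pct = 128, 0.25
    rotary_ndims = int(hidden_size_per_attention_head * rotary_pct)   # 32
    # RoPE rotates only the first 32 dims of every head; the remaining 96 dims pass
    # through untouched and are concatenated back after rotation in the forward pass.
    assert rotary_ndims == 32
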
+ self.dropout_p = neox_args.attention_dropout + self.attention_dropout = nn.Dropout(self.dropout_p) + + # Output. + self.dense = RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + parallel_output=parallel_output, + bias=neox_args.use_bias_in_attn_linear, + ) + + def attention( + self, query_layer, key_layer, value_layer, layer_past, attention_mask + ): + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== - def bias(self, seq_len_q, seq_len_k, device, dtype): # [b, np, sq, sk] - # seq_len_q = x.shape[-2] - # seq_len_k = x.shape[-1] - - # Initialize the AliBi matrix to match the first provided key length; grow it exponentially - # afterwards if longer inputs are provided. This is important for inference, where we will - # encounter progressively longer samples; it should have no effect at training time. - if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: - a = self.cached_matrix - else: - target_seq_len = ( - seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 - ) - a = -torch.tril( - torch.arange(target_seq_len) - .view(target_seq_len, 1) - .repeat(1, target_seq_len) - + torch.arange(0, -target_seq_len, -1) - ) - a = a.to(device).to(dtype) - slopes = self.slopes.to(a.device).to(a.dtype) - a = a * slopes.view(self.slopes.shape[0], 1, 1) - self.cached_seq_len = target_seq_len - self.cached_matrix = a - - # If the AliBi matrix is larger than the key length, clip it. - if self.cached_seq_len > seq_len_k: - a = self.cached_matrix[:, :seq_len_k, :seq_len_k] - - if seq_len_q != seq_len_k: - # In the train case x has dimensionality [b, np, sq, sk] with sq == sk - # The number of query tokens is equal to the number of key tokens - # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) - # In this case we use the appropriate token index of the cache matrix. - # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used - assert ( - seq_len_q == 1 - ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" - a = a[:, seq_len_k - 1, :].view( - a.shape[0], 1, a.shape[2] - ) # seq_len_k - 1 points to the last token index in the current inference batch. - - return a + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # preallocating result tensor: [b * np, sq, sk] + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device(), + ) - def forward(self, x): + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if self.use_cache: + with torch.no_grad(): + attention_mask = attention_mask[ + ..., : attention_scores.size(3), : attention_scores.size(3) + ] + + # =========================== + # Attention probs and dropout + # =========================== + + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + attention_scores += rpe # [1, np, sq, sk] + + if self.pos_emb == "alibi": + attention_scores = self.alibi_embed(attention_scores) + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + return context_layer + + def flash_attention(self, query_layer, key_layer, value_layer): # [b, np, sq, sk] - seq_len_q = x.shape[-2] - seq_len_k = x.shape[-1] - - # Initialize the AliBi matrix to match the first provided key length; grow it exponentially - # afterwards if longer inputs are provided. This is important for inference, where we will - # encounter progressively longer samples; it should have no effect at training time. - if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: - a = self.cached_matrix - else: - target_seq_len = ( - seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 - ) - a = -torch.tril( - torch.arange(target_seq_len) - .view(target_seq_len, 1) - .repeat(1, target_seq_len) - + torch.arange(0, -target_seq_len, -1) - ) - a = a.to(x.device).to(x.dtype) - slopes = self.slopes.to(a.device).to(a.dtype) - a = a * slopes.view(self.slopes.shape[0], 1, 1) - self.cached_seq_len = target_seq_len - self.cached_matrix = a - - # If the AliBi matrix is larger than the key length, clip it. - if self.cached_seq_len > seq_len_k: - a = self.cached_matrix[:, :seq_len_k, :seq_len_k] - - if seq_len_q != seq_len_k: - # In the train case x has dimensionality [b, np, sq, sk] with sq == sk - # The number of query tokens is equal to the number of key tokens - # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) - # In this case we use the appropriate token index of the cache matrix. 
- # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used - assert ( - seq_len_q == 1 - ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" - a = a[:, seq_len_k - 1, :].view( - a.shape[0], 1, a.shape[2] - ) # seq_len_k - 1 points to the last token index in the current inference batch. - - return x + a + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + if self.use_flash_attention and not self.use_triton: + + # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn] + key_layer = key_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + value_layer = value_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + + # [sq, b, np, hn] -> [b, sq, np, hn] + query_layer = query_layer.transpose(0, 1).reshape( + output_size[0], output_size[2], output_size[1], -1 + ) + + # only pass in window_size or alibi_slopes kwarg + # if we use Sliding Window Attention / AliBi. + # Flash attn defaults to (-1,-1), or + # does not have this kwarg prior to v2.3.0 + extra_kwargs = ( + {"window_size": (self.sliding_window_width, -1)} + if self.sliding_window_width is not None + else {} + ) + if self.pos_emb == "alibi": + extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( + query_layer.device + ).to(torch.float32) + + if not self.training: + batch_size = output_size[0] + max_seqlen_q = output_size[2] + max_seqlen_k = output_size[3] + + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device, + ) + + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * max_seqlen_k, + step=max_seqlen_k, + dtype=torch.int32, + device=key_layer.device, + ) + + q_shape = query_layer.shape + k_shape = key_layer.shape + v_shape = value_layer.shape + is_causal = max_seqlen_q == max_seqlen_k + output = self.flash_varlen_qkv_fn( + query_layer.reshape( + (q_shape[0] * q_shape[1], q_shape[2], q_shape[3]) + ), + key_layer.reshape( + (k_shape[0] * k_shape[1], k_shape[2], k_shape[3]) + ), + value_layer.reshape( + (v_shape[0] * v_shape[1], v_shape[2], v_shape[3]) + ), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=is_causal, + **extra_kwargs, + ) + output = output.reshape(q_shape) + else: + output = self.flash_qkv_fn( + query_layer, + key_layer, + value_layer, + self.dropout_p if self.training else 0.0, + softmax_scale=None, + causal=True, + **extra_kwargs, + ) + + matmul_result = output + # [b, sq, np, hn] -> [b, np, sq, hn] + matmul_result = matmul_result.transpose(1, 2) + + else: + # we still use Triton if using AliBi with flash-attn<2.4.0.post1. + + # [sq, b, np, hn] -> [b, sq, np, hn] + sq = query_layer.size(0) + b = query_layer.size(1) + sk = key_layer.size(0) + + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) + + bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype) + bias = bias.unsqueeze(0).tile((b, 1, 1, 1)) + + matmul_result = self.flash_triton_fn( + query_layer, key_layer, value_layer, bias=bias, causal=True + ) + matmul_result = matmul_result.transpose(1, 2) + + return matmul_result + + def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): + # TODO: sparse attn dropout? 
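
For the varlen flash path above, the cu_seqlens tensors are simply cumulative start offsets of the packed sequences; a tiny illustration with assumed sizes (batch of 2 fixed-length sequences of length 4):

    import torch

    # Assumed illustration only: a fixed-length batch of 2 sequences of length 4.
    batch_size, max_seqlen = 2, 4
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * max_seqlen, step=max_seqlen, dtype=torch.int32
    )
    print(cu_seqlens)  # tensor([0, 4, 8], dtype=torch.int32): start offset of each packed sequence
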
+ # TODO: pad to block size + # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn] + query_layer, key_layer, value_layer = map( + lambda t: t.permute(1, 2, 0, 3).contiguous(), + (query_layer, key_layer, value_layer), + ) + # output shape [b, np(heads), sq, hn] + attn_mask = attention_mask.to(query_layer.dtype) * -10000 + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + else: + rpe = None + return self.sparse_attn( + query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe + ) + + def gqa_project(self, hidden_states, attention_mask, layer_past=None): + # QKV projection and separation into separate Q/K/V layers for GQA, + # where KV projections may be smaller than Q projection. + # the logic for this is explained in comments of this function + # detailing the intermediate sizes of tensors at each reshape. + + # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # First: reshape so we have seqlen, batch, and num. query heads each as separate dims + # Final dim is not exactly head dim: the first (head dim) dims are query heads, + # The last (head dim * ratio of kv to q heads) each are the "k/v heads" + # (right now we treat like we have same num. heads, but smaller head dim) + + # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))] + new_qkv_shape = ( + mixed_x_layer.shape[0], + mixed_x_layer.shape[1], + self.num_attention_heads_per_partition, + int( + self.hidden_size_per_attention_head + * ( + 1 + + 2 + * ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + ) + ), + ) + mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) + + # Next: split our fake head dim. 
(last dim) so that the first (head dim) dimensions go to Q, + # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately + split_sizes = ( + self.hidden_size_per_attention_head, + int( + ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + * self.hidden_size_per_attention_head + ), + int( + ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + * self.hidden_size_per_attention_head + ), + ) + + # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))] + (query_layer, key_layer, value_layer) = [ + x.contiguous() + for x in torch.split( + mixed_x_layer, + split_sizes, + dim=mixed_x_layer.dim() - 1, + ) + ] + + # reshape K/V to proper output shape (last dim = correct full "real" head size again) + # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn] + new_kv_shape = ( + key_layer.size(0), + key_layer.size(1), + self.num_kv_heads_per_partition, + self.hidden_size_per_attention_head, + ) + + key_layer = key_layer.view(*new_kv_shape) + + value_layer = value_layer.view(*new_kv_shape) + + # if not using Flash attention, we repeat K/V heads to match Q head counts + if not self.use_flash_attention: + key_layer = torch.repeat_interleave( + key_layer, + repeats=int( + self.num_attention_heads_per_partition + // self.num_kv_heads_per_partition + ), + dim=2, + ) + value_layer = torch.repeat_interleave( + value_layer, + repeats=int( + self.num_attention_heads_per_partition + // self.num_kv_heads_per_partition + ), + dim=2, + ) + + return query_layer, key_layer, value_layer + + def forward(self, hidden_states, attention_mask, layer_past=None): + + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + if not self.gqa: + # QKV projection for MHA. + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( + mixed_x_layer, 3 + ) + else: + # Grouped Query Attention (GQA) - specific logic for performing QKV proj + # and separating out Q, K, and V outputs. 
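
Before the GQA branch continues below, a shape walk-through of the gqa_project reshapes above, with assumed sizes (sq=3, b=1, 8 query heads, 2 KV heads per partition, head dim 128; illustrative only):

    import torch

    # Assumed illustration only: sq=3, b=1, 8 query heads, 2 KV heads per partition, head dim 128.
    sq, b, np_, kvp, hn = 3, 1, 8, 2, 128
    mixed = torch.zeros(sq, b, (np_ + 2 * kvp) * hn)                    # [3, 1, 1536] fused QKV output
    mixed = mixed.reshape(sq, b, np_, int(hn * (1 + 2 * (kvp / np_))))  # [3, 1, 8, 192]
    split_sizes = (hn, int(kvp / np_ * hn), int(kvp / np_ * hn))        # (128, 32, 32)
    q, k, v = torch.split(mixed, split_sizes, dim=-1)
    k = k.reshape(sq, b, kvp, hn)                                       # [3, 1, 2, 128]; v reshapes the same way
    assert q.shape == (sq, b, np_, hn) and k.shape == (sq, b, kvp, hn)
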
+ + # output shapes: 1 x [sq, b, np, hn], 2 x [sq, b, kvp, hn] if using flash + query_layer, key_layer, value_layer = self.gqa_project( + hidden_states, attention_mask, layer_past=layer_past + ) + # QK Normalization https://arxiv.org/abs/2302.05442 + if self.use_qk_layernorm: + query_layer = self.qk_layernorm(query_layer) + key_layer = self.qk_layernorm(key_layer) + + if exists(self.rotary_emb): + if exists(self.rotary_ndims): + # partial rotary + query_rot, query_pass = ( + query_layer[..., : self.rotary_ndims], + query_layer[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_layer[..., : self.rotary_ndims], + key_layer[..., self.rotary_ndims :], + ) + else: + # full rotary + query_rot, key_rot = query_layer, key_layer + + seq_len = key_layer.shape[0] + offset = 0 + if exists(layer_past) and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + if self.rope_fusion: + query_layer, key_layer = ( + fused_apply_rotary_pos_emb_cached(rot, cos, sin) + for rot in [query_rot, key_rot] + ) + else: + if self.bf16: + apply_rotary_fn = apply_rotary_pos_emb_torch + else: + apply_rotary_fn = apply_rotary_pos_emb + query_layer, key_layer = apply_rotary_fn( + query_rot, key_rot, cos, sin, offset=offset + ) + + if exists(self.rotary_ndims): + query_layer = torch.cat((query_layer, query_pass), dim=-1) + key_layer = torch.cat((key_layer, key_pass), dim=-1) + + + # ================================== + # Cache key and value for inference + # ================================== + + if exists(layer_past) and layer_past.numel() > 0: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) + + if self.use_cache: + present = torch.stack((key_layer, value_layer)) + + if self.use_flash_attention: + context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif not self.sparse: + context_layer = self.attention( + query_layer, key_layer, value_layer, layer_past, attention_mask + ) + else: + context_layer = self.sparse_attention( + query_layer, key_layer, value_layer, attention_mask + ) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + if self.use_cache: + output = [output, present] + + return output, bias + + +class ParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + ): + + super().__init__() + self.layer_number = layer_number + self.neox_args = neox_args + + norm, eps = get_norm(neox_args) + + # Layernorm on the input data. 
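
To make the rotary offset bookkeeping above concrete for cached (incremental) decoding, a toy illustration with assumed numbers (6 cached positions, 1 new token, rotary dim 4; a simplified sketch, not the patch's cache layout):

    import torch

    # Assumed illustration only: 6 cached positions, 1 new token, a toy rotary dim of 4.
    past_len, new_tokens, dim = 6, 1, 4
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(past_len + new_tokens).float()
    cos = torch.einsum("i,j->ij", t, inv_freq).cos()   # one row of angles per absolute position
    offset = past_len
    # The single new query/key is rotated with the angles of position 6, not position 0,
    # so it stays consistent with the keys already sitting in the cache.
    cos_for_new_token = cos[offset : offset + new_tokens]
    assert cos_for_new_token.shape == (new_tokens, dim // 2)
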
+ self.input_layernorm = norm(neox_args.hidden_size, eps=eps) + self.use_cache = use_cache + + self.hidden_dropout = neox_args.hidden_dropout + self.bias_dropout_fusion = neox_args.bias_dropout_fusion + self.gpt_j_residual = neox_args.gpt_j_residual + self.gpt_j_tied = neox_args.gpt_j_tied + self.moe_type = neox_args.moe_type + self.activation = neox_args.activation + + if self.gpt_j_residual: + # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. + # the reduction we use is a simple allreduce for pure Tensor Parallel, + # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) + self.reduce = ( + mpu.mappings.reduce_from_model_parallel_region + if not neox_args.sequence_parallel + else mpu.mappings.reduce_scatter_to_sequence_parallel_region + ) + + # Self attention. + if neox_args.te_mha or neox_args.fp8_mha: + from megatron.model.transformer_engine import TEMultiheadAttention + self.attention = TEMultiheadAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) + + else: + self.attention = ParallelSelfAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) + + # Layernorm on the output of the attention layer. + # If GPT-J residuals are used, this is surpurfulous but leaving it in + # leads to cleaner code + self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) + + # MLP + def get_mlp(**kw): + return ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + + def get_te_lnmlp(**kw): + from megatron.model.transformer_engine import TELayerNormMLP + return TELayerNormMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + + self.num_experts = ( + neox_args.moe_num_experts + if layer_number % neox_args.expert_interval == 0 + else 1 + ) + args = neox_args + if self.num_experts <= 1: + if neox_args.te_layernorm_mlp: + self.mlp = get_te_lnmlp() + else: + self.mlp = get_mlp() + else: + from torch import distributed as dist + + if self.num_experts > dist.get_world_size(): + moe_mp_size = 1 + else: + moe_mp_size = dist.get_world_size() // self.num_experts + + if neox_args.moe_type == "deepspeed": + self.mlp = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + elif neox_args.moe_type == "megablocks": + + def integrate_megablocks_with_ds_expert_parallelism(): + # We make megablocks work with DS 
parallelism. + # + # We fool DS into accepting these MoE parameters as its own DS MoE params, + # which makes things work with the underlying expert parallelism, + # including TED parallelism. + # + # Effectively, we want to: + # + # - Make DS's data parallel gradient all-reduction skip these params. + # - But make these params participate in the expert parallel all-reduction! + # + # Further background: + # + # Normally, with the original megablocks demo codebase, it + # only supports 1 copy of any expert throughout + # the network, since it uses EP group = DP group. + # + # First, we trigger DS initialization of the MoE expert parallel groups and internal state. + throwaway = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + throwaway.set_deepspeed_parallelism() + + ep_group = throwaway.deepspeed_moe.ep_group + if args.moe_token_dropping: + self.mlp = MbMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + else: + self.mlp = dMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + + # Next, we trick DS into seeing these as its own MoE params. + for param in self.mlp.parameters(): + if getattr(param, "expert_model_parallel", None) is not None: + # is_moe_param looks for this attr. + param.allreduce = False + param.group_name = throwaway.expert_group_name + + integrate_megablocks_with_ds_expert_parallelism() + + else: + raise KeyError(neox_args.moe_type) + + self.layer_past = None # used to cache k/v pairs in inference + + def _get_bias_dropout(self): + if self.bias_dropout_fusion: + fn = ( + bias_dropout_add_fused_train + if self.training + else bias_dropout_add_fused_inference + ) + else: + fn = get_bias_dropout_add(self.training) + return fn + + def forward(self, x, attention_mask, layer_past=None): + layer_past = layer_past if layer_past is not None else self.layer_past + bias_dropout_fn = self._get_bias_dropout() + moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) + # x: [b, s, h] + + + #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. + if self.neox_args.fp8_mha: + from megatron.model.transformer_engine import TEDelayedScaling + + fp8_recipe = TEDelayedScaling( + neox_args=self.neox_args + ) + fp8_context = fp8_recipe.get_context() + else: + from contextlib import nullcontext + fp8_context = nullcontext() + + with fp8_context: + if self.gpt_j_residual: + # pseudocode: + # x = x + attn(ln(x)) + mlp(ln(x)) + # this means we can avoid doing the allreduce in the attn / mlp outputs + # to save communication time (we can do a single allreduce after we add mlp / attn outputs). + # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. 
This is non-desirable, but + # we preserve the functionality for backwards compatibility + + residual = x + # applies the correct normalization depending on if the norms are tied + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: + x = self.input_layernorm(x) + x1, x2 = x, x + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif self.neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x + else: + x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) + + # attention operator + attention_output, attention_bias = self.attention( + x1, attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + + if attention_bias is not None: + with torch.enable_grad() if not self.eval else nullcontext(): + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(attention_output), + residual=None, + prob=self.hidden_dropout, + ) + + # mlp operator + mlp_output, mlp_bias = self.mlp(x2) + if mlp_bias is not None: + with torch.enable_grad() if not self.eval else nullcontext(): + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(mlp_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + else: + output = mlp_output + + # output = (x + attn(ln(x)) + mlp(ln(x)) + output = residual + self.reduce(output) + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + + residual = x + + # x = x + attn(ln1(x)) + attention_output, attention_bias = self.attention( + self.input_layernorm(x), attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + with torch.enable_grad() if not self.eval else nullcontext(): + if attention_bias is not None: + # Use special bias_dropout_fn if we have a bias term from the above attention layer + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(residual), + residual=residual, + prob=self.hidden_dropout, + ) + else: + # Otherwise just apply dropout + residual + attention_output = ( + torch.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=self.training, + ) + + residual + ) + + # output = x + mlp(ln2(x)) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output + else: + layernorm_output = self.post_attention_layernorm(attention_output) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = ( + None # deepspeed.moe.layer.MoE.forward ignores the bias term + ) + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) + + with torch.enable_grad() if not self.eval else nullcontext(): + if ( + self.activation == "swiglu" + or self.num_experts > 1 + and self.moe_type == "deepspeed" + ): + # No dropout either + assert mlp_bias is None + output = mlp_output + attention_output + else: + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(attention_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + + return output, moe_loss + + +class ParallelTransformerLayerPipe(ParallelTransformerLayer): + """Extends ParallelTransformerLayer to forward 
attention_mask through the pipeline.""" + + def forward(self, args): + assert ( + len(args) == 2 + ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" + hidden_states, attention_mask = args + # we are returning just [hidden_states, mask] + output, moe_loss = super().forward(hidden_states, attention_mask) + # auxiliary output + self.last_moe_loss = moe_loss + return output, attention_mask + + +class ParallelLinearPipe(ParallelLinear): + """Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" + + def forward(self, args): + assert isinstance( + args, torch.Tensor + ), "ParallelLinearPipe expects a single argument - hidden_states" + hidden_state = args + logits, bias = super().forward(hidden_state) + return logits + + +class NormPipe(nn.Module): + """Just a helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" + + def __init__(self, norm_class, hidden_size, eps): + super().__init__() + self.norm = norm_class(hidden_size, eps=eps) + + def forward(self, args): + assert not isinstance( + args, tuple + ), "NormPipe should only receive a single tensor as input" + return self.norm(args) + + +def parallel_lm_logits( + input_, + word_embeddings_weight, + parallel_output, + seq_parallel=False, + seq_dim=1, + bias=None, +): + """LM logits using word embedding weights.""" + # Parallel logits. + if seq_parallel: + # if using Sequence Parallelism, our logits are sharded along the sequence dimension. + # gather them here. (backward pass: reduce-scatter) + input_parallel = mpu.gather_from_sequence_parallel_region( + input_, seq_dim=seq_dim + ) + else: + # Set up backprop all-reduce. + input_parallel = mpu.copy_to_model_parallel_region(input_) + + # Matrix multiply. + if bias is None: + logits_parallel = F.linear(input_parallel, word_embeddings_weight) + else: + logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + + # Gather if needed. + if parallel_output: + return logits_parallel + + return mpu.gather_from_model_parallel_region(logits_parallel) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 62f316f3e..fdf384a4f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -340,6 +340,8 @@ def __init__( ): super().__init__() + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -393,7 +395,7 @@ def __init__( if not self.gqa: # Strided linear layer. 
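
The linear_implementation_router used above is imported from megatron/model/utils.py, which is not included in this excerpt. A minimal sketch of the idea, under the assumption that a single boolean flag (called te_linear here) selects the TransformerEngine-backed classes; the flag name and TERowParallelLinear are assumptions, not confirmed by this patch:

    # Hypothetical sketch; not the actual megatron/model/utils.py implementation.
    def linear_implementation_router(neox_args):
        """Return (ColumnParallelLinear, RowParallelLinear) classes, TE-backed when requested."""
        if getattr(neox_args, "te_linear", False):  # flag name assumed
            from megatron.model.transformer_engine import (  # class names assumed
                TEColumnParallelLinear,
                TERowParallelLinear,
            )
            return TEColumnParallelLinear, TERowParallelLinear
        from megatron.mpu import ColumnParallelLinear, RowParallelLinear
        return ColumnParallelLinear, RowParallelLinear
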
- self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -403,7 +405,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -412,6 +414,7 @@ def __init__( bias=neox_args.use_bias_in_attn_linear, ) + coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: @@ -857,19 +860,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): return query_layer, key_layer, value_layer def forward(self, hidden_states, attention_mask, layer_past=None): - + # hidden_states: [sq, b, h] # ===================== # Query, Key, and Value # ===================== - if not self.gqa: # QKV projection for MHA. # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -889,7 +890,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer, key_layer, value_layer = self.gqa_project( hidden_states, attention_mask, layer_past=layer_past ) - # QK Normalization https://arxiv.org/abs/2302.05442 if self.use_qk_layernorm: query_layer = self.qk_layernorm(query_layer) @@ -934,6 +934,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer = torch.cat((query_layer, query_pass), dim=-1) key_layer = torch.cat((key_layer, key_pass), dim=-1) + # ================================== # Cache key and value for inference # ================================== @@ -1027,7 +1028,7 @@ def __init__( ) # Self attention. - if neox_args.te_mha: + if neox_args.te_mha or neox_args.fp8_mha: from megatron.model.transformer_engine import TEMultiheadAttention self.attention = TEMultiheadAttention( neox_args=neox_args, @@ -1200,134 +1201,149 @@ def forward(self, x, attention_mask, layer_past=None): bias_dropout_fn = self._get_bias_dropout() moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) # x: [b, s, h] - if self.gpt_j_residual: - # pseudocode: - # x = x + attn(ln(x)) + mlp(ln(x)) - # this means we can avoid doing the allreduce in the attn / mlp outputs - # to save communication time (we can do a single allreduce after we add mlp / attn outputs). - # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but - # we preserve the functionality for backwards compatibility - - residual = x - # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: - x = self.input_layernorm(x) - x1, x2 = x, x - elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: - x2 = x - x = self.input_layernorm(x) - x1 = x - elif self.neox_args.te_layernorm_mlp: - x1, x2 = self.input_layernorm(x), x - else: - x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - - # attention operator - attention_output, attention_bias = self.attention( - x1, attention_mask, layer_past=layer_past + + + #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. 
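
The fp8_mha branch that follows builds a TEDelayedScaling recipe wrapper (defined in megatron/model/transformer_engine.py, not shown in this excerpt) and runs the block under its context. For orientation, a hedged sketch of what a plain TransformerEngine delayed-scaling setup usually looks like; every hyperparameter value below is an assumption, not taken from this patch:

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    fp8_recipe = DelayedScaling(
        margin=0,
        fp8_format=Format.HYBRID,   # E4M3 in the forward pass, E5M2 for gradients
        amax_history_len=1024,      # window of recent amax values used to derive the scales
        amax_compute_algo="max",
    )
    # TE modules executed inside fp8_autocast run their GEMMs in FP8 with delayed scaling, e.g.:
    # with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    #     out = te_module(inp)
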
+ if self.neox_args.fp8_mha: + from megatron.model.transformer_engine import TEDelayedScaling + + fp8_recipe = TEDelayedScaling( + neox_args=self.neox_args ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - - if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(attention_output), - residual=None, - prob=self.hidden_dropout, - ) - - # mlp operator - mlp_output, mlp_bias = self.mlp(x2) - if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(mlp_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - else: - output = mlp_output - - # output = (x + attn(ln(x)) + mlp(ln(x)) - output = residual + self.reduce(output) + fp8_context = fp8_recipe.get_context() else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) + from contextlib import nullcontext + fp8_context = nullcontext() + + with fp8_context: + if self.gpt_j_residual: + # pseudocode: + # x = x + attn(ln(x)) + mlp(ln(x)) + # this means we can avoid doing the allreduce in the attn / mlp outputs + # to save communication time (we can do a single allreduce after we add mlp / attn outputs). + # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but + # we preserve the functionality for backwards compatibility + + residual = x + # applies the correct normalization depending on if the norms are tied + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: + x = self.input_layernorm(x) + x1, x2 = x, x + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif self.neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x + else: + x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - residual = x + # attention operator + attention_output, attention_bias = self.attention( + x1, attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents - # x = x + attn(ln1(x)) - attention_output, attention_bias = self.attention( - self.input_layernorm(x), attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): if attention_bias is not None: - # Use special bias_dropout_fn if we have a bias term from the above attention layer - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(residual), - residual=residual, - prob=self.hidden_dropout, - ) - else: - # Otherwise just apply dropout + residual - attention_output = ( - torch.nn.functional.dropout( + with torch.enable_grad() if not self.eval else nullcontext(): + attention_output = bias_dropout_fn( attention_output, - p=self.hidden_dropout, - training=self.training, + bias=attention_bias.expand_as(attention_output), + residual=None, + prob=self.hidden_dropout, ) - + residual - ) - # output = x + mlp(ln2(x)) - if self.neox_args.te_layernorm_mlp: - layernorm_output = attention_output - else: - layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) + # mlp operator + mlp_output, mlp_bias = self.mlp(x2) + if mlp_bias is not 
None: + with torch.enable_grad() if not self.eval else nullcontext(): + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(mlp_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + else: + output = mlp_output - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) + # output = (x + attn(ln(x)) + mlp(ln(x)) + output = residual + self.reduce(output) else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + + residual = x + + # x = x + attn(ln1(x)) + attention_output, attention_bias = self.attention( + self.input_layernorm(x), attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + with torch.enable_grad() if not self.eval else nullcontext(): + if attention_bias is not None: + # Use special bias_dropout_fn if we have a bias term from the above attention layer + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(residual), + residual=residual, + prob=self.hidden_dropout, + ) + else: + # Otherwise just apply dropout + residual + attention_output = ( + torch.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=self.training, + ) + + residual + ) + + # output = x + mlp(ln2(x)) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output else: - raise KeyError(self.moe_type) - - with torch.enable_grad() if not self.eval else nullcontext(): - if ( - self.activation == "swiglu" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): - # No dropout either - assert mlp_bias is None - output = mlp_output + attention_output + layernorm_output = self.post_attention_layernorm(attention_output) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) else: - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(attention_output), - residual=attention_output, - prob=self.hidden_dropout, - ) + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = ( + None # deepspeed.moe.layer.MoE.forward ignores the bias term + ) + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) + + with torch.enable_grad() if not self.eval else nullcontext(): + if ( + self.activation == "swiglu" + or self.num_experts > 1 + and self.moe_type == "deepspeed" + ): + # No dropout either + assert mlp_bias is None + output = mlp_output + attention_output + else: + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(attention_output), + residual=attention_output, + prob=self.hidden_dropout, + ) - return output, moe_loss + return output, moe_loss class ParallelTransformerLayerPipe(ParallelTransformerLayer): diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 9a8c0a506..07559bcfb 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -16,10 +16,13 @@ from megatron.mpu.mappings import scatter_to_model_parallel_region from megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region from 
megatron.mpu.mappings import gather_from_sequence_parallel_region +from megatron.mpu.layers import _initialize_affine_weight_gpu, _initialize_affine_weight_cpu from megatron.mpu.random import get_cuda_rng_tracker from megatron.mpu.utils import divide from megatron.mpu.utils import VocabUtility from functools import partial +from megatron.model.positional_embeddings import RotaryEmbedding +from megatron import mpu try: import transformer_engine as te @@ -90,7 +93,6 @@ def __init__( mup_rescale_parameters=False, seq_dim=0, ): - # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -106,9 +108,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype - # print("##########################") - # print(self.return_bias) - super(TELinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, device=torch.cuda.current_device(), return_bias=self.skip_bias_add, params_dtype=self.params_dtype) @@ -145,7 +144,6 @@ def __init__( ): self.activation_func, self.is_gated = get_activation(neox_args) self.activation_type = neox_args.activation - self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of self.bias = bias self.init_method = init_method @@ -188,16 +186,17 @@ def __init__( if neox_args.norm in ['layernorm','te_layernorm']: self.eps=1.0e-5 self.normalization = 'LayerNorm' - elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + elif neox_args.norm in ['rmsnorm','te_rmsnorm']: self.eps=1.0e-8 self.normalization = 'RMSNorm' - #TODO handle case if norm is not rmsnorm or layernorm - #TODO check if activation in list ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, - #‘swiglu’, ‘qgelu’, ‘srelu’ - #TODO handle MOE and mup + else: + raise ValueError("Only LayerNorm and RMSNorm are supported with TransformerEngine") + + if self.activation_type not in ["gelu", "geglu", "relu", "reglu", "squared_relu","swiglu", "qgelu", "srelu"]: + raise ValueError("Only gelu, geglu, relu, reglu, squared_relu, swiglu, qgelu, and srelu are supported with TransformerEngine") super(TELayerNormMLP, self).__init__(hidden_size=neox_args.hidden_size, ffn_hidden_size=ffn_dim, - eps=self.eps, bias=self.bias, normalization=self.normalization, activation=neox_args.activation, + eps=self.eps, bias=self.bias, normalization=self.normalization, activation=self.activation_type, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, @@ -265,8 +264,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype self.parallel_mode="column" - # print("##########################") - # print(self.return_bias) super(TEColumnParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, @@ -283,12 +280,6 @@ def width_mult(self): ) return self.weight.infshape.width_mult() - def set_parallel_output(self, value: bool): - assert isinstance(value, bool) - self.gather_output = ( - not value - ) # if gather_output is True, parallel output is False, so we set the opposite - # Copied from Mup def _rescale_parameters(self): """Rescale parameters to convert SP initialization to μP initialization. 
@@ -308,7 +299,7 @@ def _rescale_parameters(self): self.bias.data *= self.width_mult() ** 0.5 self.weight.data *= self.width_mult() ** 0.5 self._has_rescaled_params = True - + def mup_reinitialize_weights(self, neox_args): if neox_args.use_cpu_initialization: self.master_weight = _initialize_affine_weight_cpu( @@ -316,26 +307,25 @@ def mup_reinitialize_weights(self, neox_args): self.weight, self.output_size, self.input_size, - self.output_size_per_partition, - 0, + self.input_size_per_partition, + 1, partial(self.init_method, use_mup=True), stride=self.stride, - return_master_weight=keep_master_weight_for_test, + return_master_weight=self.keep_master_weight_for_test, ) else: _initialize_affine_weight_gpu( self.weight, partial(self.init_method, use_mup=True), - partition_dim=0, + partition_dim=1, stride=self.stride, ) - + def forward(self, inp, **kwargs): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) - if self.skip_bias_add: return output else: @@ -455,21 +445,17 @@ def mup_reinitialize_weights(self, neox_args): stride=self.stride, ) - def set_parallel_output(self, parallel_output: bool): - assert isinstance(parallel_output, bool) - self.parallel_output = parallel_output - def forward(self, inp, **kwargs): - # if not self.input_is_parallel: - # inp = scatter_to_model_parallel_region(inp) + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() output = super(TERowParallelLinear, self).forward(inp, **kwargs) + if self.skip_bias_add: return output else: return output, None - class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also @@ -487,6 +473,7 @@ def __init__(self, use_cache=False, parallel_output=False): + self.neox_args = neox_args self.attention_mask_func = attention_mask_func self.init_method = init_method self.output_layer_init_method = output_layer_init_method @@ -524,12 +511,43 @@ def __init__(self, attention_dropout=neox_args.attention_dropout, layernorm_epsilon=self.eps, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, layer_number=self.layer_number, window_size=neox_args.sliding_window_width, num_gqa_groups=self.num_kv_heads, input_layernorm=False, - normalization=self.normalization, bias=True, device=torch.cuda.current_device(), + normalization=self.normalization, bias=True, device=torch.cuda.current_device(),get_rng_state_tracker=get_cuda_rng_tracker, set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, - tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True) + tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True, qkv_format="sbhd", fuse_qkv_params=True) + + + + if neox_args.pos_emb == "rotary": + self.hidden_size_per_attention_head = mpu.divide( + neox_args.hidden_size, neox_args.num_attention_heads) + + if neox_args.rotary_pct == 1: + self.rotary_ndims = None + else: + assert neox_args.rotary_pct < 1 + self.rotary_ndims = int( + self.hidden_size_per_attention_head * neox_args.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else self.hidden_size_per_attention_head + ) + self.rotary_embeddings = RotaryEmbedding( + dim, + base=neox_args.rotary_emb_base, + max_seq_len=neox_args.seq_length, + precision=neox_args.params_dtype, + save_inv_freqs=neox_args.rotary_save_freqs_buffer, + return_embeddings=True + ) + + def 
forward(self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs):
+        if self.neox_args.pos_emb == "rotary":
+            rope_emb = self.rotary_embeddings(hidden_states)
+
+        output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=rope_emb, **kwargs)
 
-    def forward(self, hidden_states, attention_mask, layer_past=None, **kwargs):
-        output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, **kwargs)
         return output
@@ -537,7 +555,36 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling):
     """
     Wrapper for the Transformer-Engine's `DelayedScaling` layer.
     """
+    # TODO: Test with H100
+    def __init__(
+        self,
+        neox_args):
+
+        self.neox_args = neox_args
+        self.tp_group = get_tensor_model_parallel_group()
+
+        if neox_args.fp8_format == "e4m3":
+            fp8_format = te.common.recipe.Format.E4M3
+        elif neox_args.fp8_format == "hybrid":
+            fp8_format = te.common.recipe.Format.HYBRID
+        else:
+            raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
 
-    def __init__(self):
-        # TODO
-        return
+        override_linear_precision = (False, False, not neox_args.fp8_wgrad)
+
+        super().__init__(
+            margin=neox_args.fp8_margin,
+            fp8_format=fp8_format,
+            amax_compute_algo=neox_args.fp8_amax_compute_algo,
+            amax_history_len=neox_args.fp8_amax_history_len,
+            override_linear_precision=override_linear_precision,
+            fp8_mha=neox_args.fp8_mha,
+        )
+
+    def get_context(self):
+        # Return an fp8_autocast context scoped to the tensor-parallel group.
+        fp8_group = None
+        if self.tp_group:
+            fp8_group = self.tp_group
+        fp8_context = te.pytorch.fp8_autocast(enabled=True, fp8_recipe=self, fp8_group=fp8_group)
+
+        return fp8_context
\ No newline at end of file
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 76d80778d..6dac733e4 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -522,6 +522,40 @@ class NeoXArgsModel(NeoXArgsTemplate):
     Use TransformerEngine for MultiheadAttention layer.
     """
 
+    fp8_format: Literal["e4m3", "hybrid"] = "hybrid"
+    """
+    Controls the FP8 data format used during forward and backward pass by TransformerEngine.
+    Hybrid uses E4M3 during forward pass, E5M2 during backward pass.
+    """
+
+    fp8_wgrad: bool = True
+    """
+    When set to False, override FP8 config options and do the wgrad computation
+    in higher precision.
+    """
+
+    fp8_amax_history_len: int = 1
+    """
+    The length of the amax history window used for scaling factor computation.
+    """
+
+    fp8_amax_compute_algo: str = "most_recent"
+    """
+    Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
+    predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
+    always chooses the most recently seen value.
+    """
+
+    fp8_margin: int = 0
+    """
+    Margin for the scaling factor computation.
+    """
+
+    fp8_mha: bool = False
+    """
+    When set to True, use the FP8 implementation of Multi Head Attention.
+    """
+
 @dataclass
 class NeoXArgsOptimizer(NeoXArgsTemplate):
     """

From ee42a316bc365b3f0405d1aa64bfaaf23a43e71b Mon Sep 17 00:00:00 2001
From: aurelion-source
Date: Thu, 26 Sep 2024 08:25:49 +0000
Subject: [PATCH 09/27] Fixed mixed files.
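
A minimal sketch, assuming transformer_engine is installed and an FP8-capable GPU (e.g. H100) is available, of how the fp8_* options introduced above translate into TransformerEngine's DelayedScaling recipe and fp8_autocast context; the layer sizes and values below are illustrative only.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Mirror the neox_args defaults above: hybrid format, margin 0,
    # history length 1, "most_recent" amax selection, wgrad kept in FP8.
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,   # E4M3 forward, E5M2 backward
        amax_history_len=1,
        amax_compute_algo="most_recent",
        override_linear_precision=(False, False, False),
    )

    # Any Transformer-Engine module executed under fp8_autocast runs its
    # GEMMs in FP8 using the recipe's scaling factors.
    layer = te.Linear(1024, 1024, bias=True).cuda()
    x = torch.randn(16, 1024, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)

TEDelayedScaling.get_context() builds this same kind of context, additionally passing the tensor-parallel group as fp8_group so amax statistics are reduced across model-parallel ranks.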
--- megatron/model/positional_embeddings.py | 1577 +++-------------------- 1 file changed, 207 insertions(+), 1370 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fdf384a4f..38d7b12c6 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -1,7 +1,4 @@ -# Copyright (c) 2024 EleutherAI -# This file is based on code by the authors denoted below and has been modified from its original version. -# -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, EleutherAI # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,1406 +12,246 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Transformer.""" - -import math -from contextlib import nullcontext - import torch -import torch.nn.functional as F -import torch.nn as nn -from pkg_resources import packaging -from importlib.metadata import version - -from .norms import get_norm -from megatron import mpu -from megatron.model import megablocks_utils -from megatron.model.fused_softmax import FusedScaleMaskSoftmax -from megatron.model.activations import get_activation -from megatron.model.utils import exists, get_fusion_type -from megatron.model.positional_embeddings import ( - RotaryEmbedding, - apply_rotary_pos_emb_torch, - apply_rotary_pos_emb, - AliBi, -) -from megatron.model.fused_rope import ( - FusedRoPEFunc, - fused_apply_rotary_pos_emb_cached, -) -from megatron.model.fused_bias_dropout import ( - get_bias_dropout_add, - bias_dropout_add_fused_train, - bias_dropout_add_fused_inference, -) -from megatron.model.utils import configure_sparse_attention -from deepspeed.moe.layer import MoE - -try: - from flash_attn.ops.activations import swiglu -except ImportError: - swiglu = None - -from .utils import linear_implementation_router - -# flags required to enable jit fusion kernels -torch._C._jit_set_profiling_mode(False) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_override_can_fuse_on_cpu(True) -torch._C._jit_override_can_fuse_on_gpu(True) - -""" We use the following notation throughout this file: - h: hidden size - n: number of attention heads - kv: number of key or value heads - p: number of model parallel partitions - np: n/p - kvp: kv/p - hp: h/p - hn: h/n - b: batch size - s: sequence length - l: number of layers - Transformer takes input of size [s, b, h] and returns a - tensor of the same size. We use the following arguments: - hyperparameters: transformer hyperparameters - attention_mask_func: a function that takes `unmasked-attention-scores` - with size [b, np, s, s] and an `attention-mask` and will apply - the masking. The function should return a masked score of the - same size [b, np, s, s]. - masked-attention-scores = attention_mask_func( - unmasked-attention-scores, attention-mask) -""" - - -class ParallelMLP(nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. At the end, dropout is also - applied. 
- """ - - def __init__( - self, - neox_args, - init_method, - output_layer_init_method, - parallel_output=False, - multiple_of=256, - MOE=False, - MoE_mp_size=1, - ): - super().__init__() - assert ( - neox_args.intermediate_size == None or neox_args.expansion_factor == None - ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - - self.activation_func, self.is_gated = get_activation(neox_args) - self.activation_type = neox_args.activation - self.bias_gelu_fusion = neox_args.bias_gelu_fusion - self.multiple_of = multiple_of - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - if neox_args.intermediate_size: - ffn_dim = neox_args.intermediate_size - elif neox_args.expansion_factor: - ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) - else: - # 4h is default for ffn_dim - ffn_dim = 4 * neox_args.hidden_size - ffn_dim_in = ffn_dim - if self.is_gated: - # set activation function to be gated implementation - self.activation_func = Gated_Activation( - self.activation_func, - (swiglu is not None) - and (neox_args.activation == "swiglu") - and neox_args.use_flashattn_swiglu, - ) - # auto scale so gated activations has equal parameters - ffn_dim = int(ffn_dim * 2 / 3) - ffn_dim_in = ffn_dim // 2 - # set multiple - ffn_dim = int( - (2 * self.multiple_of) - * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) - ) - ffn_dim_in = int( - self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) - ) - self.linear1 = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ffn_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - bias=neox_args.use_bias_in_mlp, - ) - # Project back to h. 
- self.linear2 = RowParallelLinear( - neox_args=neox_args, - input_size=ffn_dim_in, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - parallel_output=parallel_output, - skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - bias=neox_args.use_bias_in_mlp, - ) - - def forward(self, hidden_states): - # [s, b, intermediate_size] - intermediate_parallel, bias_parallel = self.linear1(hidden_states) - if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): - intermediate_parallel = self.activation_func( - intermediate_parallel, bias_parallel - ) - else: - intermediate_parallel = self.activation_func( - intermediate_parallel + bias_parallel - ) - - # [s, b, h] - output, output_bias = self.linear2(intermediate_parallel) - return output, output_bias - - -class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func, use_swiglu=False): - super().__init__() - self.activation_func = activation_func - self.use_swiglu = use_swiglu - - def forward(self, x, bias=None): - x, gate = x.chunk(2, dim=-1) - if bias is not None: - bias_1, bias_2 = bias.chunk(2, dim=-1) - x = x + bias_1 - gate = gate + bias_2 - if not self.use_swiglu: - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x - else: - return swiglu(gate, x) - +import math -class ParallelLinear(nn.Module): - """ - A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size - """ - def __init__( - self, - neox_args, - parallel_output=True, - init_method=nn.init.xavier_normal_, - is_last_layer=False, - ): +class SinusoidalPositionalEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half): super().__init__() - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - self.is_rm = neox_args.train_impl == "rm" - parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" - if parallelism == "column": - self.final_linear = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.padded_vocab_size, - bias=False, - init_method=init_method, - gather_output=not parallel_output, - skip_bias_add=False, - mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here - seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0 - ) - else: - if not self.is_rm: - print( - 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' 
- ) - exit() - # self.final_linear = mpu.RowParallelLinear( - # neox_args=neox_args, - # input_size=neox_args.hidden_size, - # output_size=neox_args.padded_vocab_size, - # bias=False, - # input_is_parallel=False, - # init_method=init_method, - # parallel_output=parallel_output, - # skip_bias_add=False, - # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here - # ) - else: # Not using cross entropy loss for RMs - self.rm_linear = RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=1, - bias=False, - input_is_parallel=False, - init_method=init_method, - parallel_output=False, - skip_bias_add=False, - mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here - ) - - def forward(self, hidden_states): - if not self.is_rm: - return self.final_linear(hidden_states) - else: - return self.rm_linear(hidden_states) - - -class _MegablocksAdapter(nn.Module): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.precision = precision + + def forward(self, x, seq_dim=1): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq) + if self.precision == torch.bfloat16: + sinusoid_inp = sinusoid_inp.float() + sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() + if self.precision == torch.bfloat16: + sin, cos = sin.bfloat16(), cos.bfloat16() + emb = torch.cat((sin, cos), dim=-1) + return emb[None, :, :] + + +class RotaryEmbedding(torch.nn.Module): def __init__( - self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group + self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False, return_embeddings=False ): super().__init__() - megablocks_utils.assert_megablocks_is_available() - args = megablocks_utils.as_megablocks_args(neox_args) - args.device = torch.cuda.current_device() - args.init_method = init_method - args.output_layer_init_method = output_layer_init_method - - # NOTE: Shard the MoE layers over the data parallel group. Expert - # parallel sharding and data parallel sharding could be decoupled - # by extending the optimizer to handle data parallel reductions for - # MoE and non-MoE parameters separately. 
- if args.moe_expert_model_parallelism: - args.expert_parallel_group = ep_group - - self.moe = layer_cls(args) - - def forward(self, x): - return self.moe.forward(x) - - -class MbMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.moe.MoE, - init_method, - output_layer_init_method, - ep_group, + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + self.max_seq_len = max_seq_len + self.base = base + self.dim = dim + self.return_embeddings = return_embeddings + + # precompute cos_cached, sin_cached in fp32 + cos_cached, sin_cached, inv_freq = self._prepare_cache( + max_seq_len, precision, base ) + self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) + self.cos_cached = cos_cached + self.sin_cached = sin_cached -class dMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.dmoe.dMoE, - init_method, - output_layer_init_method, - ep_group, - ) + def _prepare_cache(self, seq_len, precision, base): + # precompute cos_cached, sin_cached in fp32 + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + t = torch.arange(seq_len).type_as(inv_freq) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) -class ParallelSelfAttention(nn.Module): - """Parallel self-attention layer abstract class. + self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) - Self-attention layer takes input with size [b, s, h] - and returns output of the same size. - """ + cos_cached = emb.cos()[:, None, None, :] + sin_cached = emb.sin()[:, None, None, :] - def __init__( - self, - neox_args, - attention_mask_func, - init_method, - output_layer_init_method, - layer_number, - rpe=None, - rotary=False, - use_cache=False, - parallel_output=False, - ): - super().__init__() - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - self.fp16 = neox_args.precision == "fp16" - self.bf16 = neox_args.precision == "bfloat16" - self.attention_mask_func = attention_mask_func - self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling - self.use_cache = use_cache - self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = layer_number - # Per attention head and per partition values. 
- world_size = mpu.get_model_parallel_world_size() - self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) - self.hidden_size_per_attention_head = mpu.divide( - neox_args.hidden_size, neox_args.num_attention_heads + return ( + cos_cached.to(precision), + sin_cached.to(precision), + inv_freq.to(precision), ) - self.num_attention_heads_per_partition = mpu.divide( - neox_args.num_attention_heads, world_size - ) - self.pos_emb = neox_args.pos_emb - self.use_qk_layernorm = neox_args.use_qk_layernorm - if self.use_qk_layernorm: - norm, eps = get_norm(neox_args) - self.qk_layernorm = norm( - [ - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ], - eps=eps, - ) + def forward(self, x, seq_dim=0, seq_len=None): + if self.return_embeddings: + return self.emb.to(self.precision).to(x.device) + if seq_len is None: + seq_len = x.shape[seq_dim] - self.sliding_window_width = neox_args.sliding_window_width + assert seq_len <= self.max_seq_len - if ( - not neox_args.num_kv_heads - or neox_args.num_kv_heads == neox_args.num_attention_heads - ): - self.gqa = False - else: - self.gqa = True - if self.gqa: - self.num_kv_heads_per_partition = mpu.divide( - neox_args.num_kv_heads, world_size - ) # we do not yet clone KV heads in MQA across TP ranks... - self.kv_hidden_size = ( - neox_args.num_kv_heads * self.hidden_size_per_attention_head - ) # how large the total hidden dim for each of K and V is - else: - self.num_kv_heads_per_partition = self.num_attention_heads_per_partition - self.kv_hidden_size = neox_args.hidden_size - - if not self.gqa: - # Strided linear layer. - self.query_key_value = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=3 * neox_args.hidden_size, - gather_output=False, - init_method=init_method, - bias=neox_args.use_bias_in_attn_linear, + if seq_len != self.max_seq_len: + # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base) + return ( + self.cos_cached[:seq_len, ...].to(x.device), + self.sin_cached[:seq_len, ...].to(x.device), ) else: - # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, - gather_output=False, - init_method=init_method, - bias=neox_args.use_bias_in_attn_linear, - ) - - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = max(1, self.layer_number) - self.norm_factor *= coeff - - if neox_args.use_mup: - self.norm_factor = self.hidden_size_per_attention_head - - self.rpe = rpe - - if self.pos_emb == "alibi": - self.alibi_embed = AliBi( - neox_args.num_attention_heads, - neox_args.model_parallel_size, - mpu.get_model_parallel_rank(), - ) + return self.cos_cached.to(x.device), self.sin_cached.to(x.device) - # TODO: this arg shouldn't need to be passed in - get from neox_args - if rotary: - if neox_args.rotary_pct == 1: - self.rotary_ndims = None - else: - assert neox_args.rotary_pct < 1 - self.rotary_ndims = int( - self.hidden_size_per_attention_head * neox_args.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else self.hidden_size_per_attention_head - ) - self.rotary_emb = RotaryEmbedding( - dim, - base=neox_args.rotary_emb_base, - max_seq_len=neox_args.seq_length, - precision=neox_args.params_dtype, - save_inv_freqs=neox_args.rotary_save_freqs_buffer, - ) - else: - self.rotary_emb = None 
- self.rope_fusion = neox_args.rope_fusion - self.attention_type = neox_args.attention_config[layer_number] - self.use_flash_attention = self.attention_type == "flash" - self.use_triton = ( - self.use_flash_attention - and self.pos_emb == "alibi" - and ( - not packaging.version.Version(version("flash-attn")) - >= packaging.version.Version("2.4.0.post1") - ) - ) - self.sparse = self.attention_type not in ("global", "flash") +# rotary pos emb helpers: - if self.gqa: - assert not self.sparse - if self.sparse: - self.sparse_attn = configure_sparse_attention( - neox_args, - self.attention_type, - self.num_attention_heads_per_partition, - mpu=mpu, - ) - else: - if self.use_flash_attention: - # we now use Flash Attention 2's provided interface. - # TODO: we no longer need to use flash_triton_fn since flash cuda supports alibi. - # consider adding OpenAI's more recent Flash-2 Triton kernel in future - # from https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py - from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_func, - ) - from flash_attn.flash_attn_triton import ( - flash_attn_func as flash_attn_unpadded_unpacked_func_triton, - ) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions - self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton - self.flash_qkv_fn = flash_attn_func - self.flash_varlen_qkv_fn = flash_attn_varlen_func - else: - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.fp16, - input_in_bf16=self.bf16, - fusion_type=get_fusion_type(neox_args), - mask_func=self.attention_mask_func, - softmax_in_fp32=self.attention_softmax_in_fp32, - scale=coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.dropout_p = neox_args.attention_dropout - self.attention_dropout = nn.Dropout(self.dropout_p) - - # Output. - self.dense = RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - bias=neox_args.use_bias_in_attn_linear, - ) - - def attention( - self, query_layer, key_layer, value_layer, layer_past, attention_mask - ): - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - output_size[2], output_size[0] * output_size[1], -1 - ) - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocating result tensor: [b * np, sq, sk] - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # ================================================== - # Update attention mask for inference. [b, np, sq, sk] - # ================================================== - - if self.use_cache: - with torch.no_grad(): - attention_mask = attention_mask[ - ..., : attention_scores.size(3), : attention_scores.size(3) - ] - # =========================== - # Attention probs and dropout - # =========================== +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - attention_scores += rpe # [1, np, sq, sk] - if self.pos_emb == "alibi": - attention_scores = self.alibi_embed(attention_scores) - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - with mpu.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - return context_layer - - def flash_attention(self, query_layer, key_layer, value_layer): - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - - if self.use_flash_attention and not self.use_triton: - - # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn] - key_layer = key_layer.transpose(0, 1).reshape( - output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 - ) - value_layer = value_layer.transpose(0, 1).reshape( - output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 - ) - - # [sq, b, np, hn] -> [b, sq, np, hn] - query_layer = query_layer.transpose(0, 1).reshape( - output_size[0], output_size[2], output_size[1], -1 - ) - - # only pass in window_size or alibi_slopes kwarg - # if we use Sliding Window Attention / AliBi. 
- # Flash attn defaults to (-1,-1), or - # does not have this kwarg prior to v2.3.0 - extra_kwargs = ( - {"window_size": (self.sliding_window_width, -1)} - if self.sliding_window_width is not None - else {} - ) - if self.pos_emb == "alibi": - extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( - query_layer.device - ).to(torch.float32) - - if not self.training: - batch_size = output_size[0] - max_seqlen_q = output_size[2] - max_seqlen_k = output_size[3] - - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * max_seqlen_q, - step=max_seqlen_q, - dtype=torch.int32, - device=query_layer.device, - ) - - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * max_seqlen_k, - step=max_seqlen_k, - dtype=torch.int32, - device=key_layer.device, - ) - - q_shape = query_layer.shape - k_shape = key_layer.shape - v_shape = value_layer.shape - is_causal = max_seqlen_q == max_seqlen_k - output = self.flash_varlen_qkv_fn( - query_layer.reshape( - (q_shape[0] * q_shape[1], q_shape[2], q_shape[3]) - ), - key_layer.reshape( - (k_shape[0] * k_shape[1], k_shape[2], k_shape[3]) - ), - value_layer.reshape( - (v_shape[0] * v_shape[1], v_shape[2], v_shape[3]) - ), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale=None, - causal=is_causal, - **extra_kwargs, - ) - output = output.reshape(q_shape) - else: - output = self.flash_qkv_fn( - query_layer, - key_layer, - value_layer, - self.dropout_p if self.training else 0.0, - softmax_scale=None, - causal=True, - **extra_kwargs, - ) - - matmul_result = output - # [b, sq, np, hn] -> [b, np, sq, hn] - matmul_result = matmul_result.transpose(1, 2) - - else: - # we still use Triton if using AliBi with flash-attn<2.4.0.post1. +def apply_rotary_pos_emb_torch( + q, k, cos, sin, offset: int = 0 +): # jitting fails with bf16 + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - # [sq, b, np, hn] -> [b, sq, np, hn] - sq = query_layer.size(0) - b = query_layer.size(1) - sk = key_layer.size(0) - - query_layer = query_layer.transpose(0, 1) - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype) - bias = bias.unsqueeze(0).tile((b, 1, 1, 1)) - - matmul_result = self.flash_triton_fn( - query_layer, key_layer, value_layer, bias=bias, causal=True - ) - matmul_result = matmul_result.transpose(1, 2) - - return matmul_result - - def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): - # TODO: sparse attn dropout? - # TODO: pad to block size - # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn] - query_layer, key_layer, value_layer = map( - lambda t: t.permute(1, 2, 0, 3).contiguous(), - (query_layer, key_layer, value_layer), - ) - # output shape [b, np(heads), sq, hn] - attn_mask = attention_mask.to(query_layer.dtype) * -10000 - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - else: - rpe = None - return self.sparse_attn( - query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe - ) - - def gqa_project(self, hidden_states, attention_mask, layer_past=None): - # QKV projection and separation into separate Q/K/V layers for GQA, - # where KV projections may be smaller than Q projection. - # the logic for this is explained in comments of this function - # detailing the intermediate sizes of tensors at each reshape. 
- - # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # First: reshape so we have seqlen, batch, and num. query heads each as separate dims - # Final dim is not exactly head dim: the first (head dim) dims are query heads, - # The last (head dim * ratio of kv to q heads) each are the "k/v heads" - # (right now we treat like we have same num. heads, but smaller head dim) - - # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))] - new_qkv_shape = ( - mixed_x_layer.shape[0], - mixed_x_layer.shape[1], - self.num_attention_heads_per_partition, - int( - self.hidden_size_per_attention_head - * ( - 1 - + 2 - * ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - ) - ), - ) - mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) - - # Next: split our fake head dim. (last dim) so that the first (head dim) dimensions go to Q, - # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately - split_sizes = ( - self.hidden_size_per_attention_head, - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), - ) - - # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))] - (query_layer, key_layer, value_layer) = [ - x.contiguous() - for x in torch.split( - mixed_x_layer, - split_sizes, - dim=mixed_x_layer.dim() - 1, - ) - ] - - # reshape K/V to proper output shape (last dim = correct full "real" head size again) - # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn] - new_kv_shape = ( - key_layer.size(0), - key_layer.size(1), - self.num_kv_heads_per_partition, - self.hidden_size_per_attention_head, - ) - - key_layer = key_layer.view(*new_kv_shape) - - value_layer = value_layer.view(*new_kv_shape) - - # if not using Flash attention, we repeat K/V heads to match Q head counts - if not self.use_flash_attention: - key_layer = torch.repeat_interleave( - key_layer, - repeats=int( - self.num_attention_heads_per_partition - // self.num_kv_heads_per_partition - ), - dim=2, - ) - value_layer = torch.repeat_interleave( - value_layer, - repeats=int( - self.num_attention_heads_per_partition - // self.num_kv_heads_per_partition - ), - dim=2, - ) - - return query_layer, key_layer, value_layer - - def forward(self, hidden_states, attention_mask, layer_past=None): - - # hidden_states: [sq, b, h] - - # ===================== - # Query, Key, and Value - # ===================== - if not self.gqa: - # QKV projection for MHA. - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( - mixed_x_layer, 3 - ) - else: - # Grouped Query Attention (GQA) - specific logic for performing QKV proj - # and separating out Q, K, and V outputs. 
- - # output shapes: 1 x [sq, b, np, hn], 2 x [sq, b, kvp, hn] if using flash - query_layer, key_layer, value_layer = self.gqa_project( - hidden_states, attention_mask, layer_past=layer_past - ) - # QK Normalization https://arxiv.org/abs/2302.05442 - if self.use_qk_layernorm: - query_layer = self.qk_layernorm(query_layer) - key_layer = self.qk_layernorm(key_layer) - - if exists(self.rotary_emb): - if exists(self.rotary_ndims): - # partial rotary - query_rot, query_pass = ( - query_layer[..., : self.rotary_ndims], - query_layer[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_layer[..., : self.rotary_ndims], - key_layer[..., self.rotary_ndims :], - ) - else: - # full rotary - query_rot, key_rot = query_layer, key_layer - - seq_len = key_layer.shape[0] - offset = 0 - if exists(layer_past) and layer_past.numel() > 0: - offset = layer_past[0].shape[0] - seq_len += offset - cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) - if self.rope_fusion: - query_layer, key_layer = ( - fused_apply_rotary_pos_emb_cached(rot, cos, sin) - for rot in [query_rot, key_rot] - ) - else: - if self.bf16: - apply_rotary_fn = apply_rotary_pos_emb_torch - else: - apply_rotary_fn = apply_rotary_pos_emb - query_layer, key_layer = apply_rotary_fn( - query_rot, key_rot, cos, sin, offset=offset - ) - - if exists(self.rotary_ndims): - query_layer = torch.cat((query_layer, query_pass), dim=-1) - key_layer = torch.cat((key_layer, key_pass), dim=-1) - - - # ================================== - # Cache key and value for inference - # ================================== - - if exists(layer_past) and layer_past.numel() > 0: - past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) - value_layer = torch.cat( - (past_value.type_as(value_layer), value_layer), dim=0 - ) - - if self.use_cache: - present = torch.stack((key_layer, value_layer)) - - if self.use_flash_attention: - context_layer = self.flash_attention(query_layer, key_layer, value_layer) - elif not self.sparse: - context_layer = self.attention( - query_layer, key_layer, value_layer, layer_past, attention_mask - ) - else: - context_layer = self.sparse_attention( - query_layer, key_layer, value_layer, attention_mask - ) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + ( - self.hidden_size_per_partition, - ) - context_layer = context_layer.view(*new_context_layer_shape) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.dense(context_layer) - - if self.use_cache: - output = [output, present] - - return output, bias - - -class ParallelTransformerLayer(nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [b, s, h] and returns an - output of the same size. - """ - - def __init__( - self, - neox_args, - attention_mask_func, - init_method, - output_layer_init_method, - layer_number, - rpe=None, - rotary=False, - use_cache=False, - ): +class AliBi(torch.nn.Module): + def __init__(self, num_heads, mp_size=1, mp_rank=1): super().__init__() - self.layer_number = layer_number - self.neox_args = neox_args - - norm, eps = get_norm(neox_args) - - # Layernorm on the input data. 
- self.input_layernorm = norm(neox_args.hidden_size, eps=eps) - self.use_cache = use_cache - - self.hidden_dropout = neox_args.hidden_dropout - self.bias_dropout_fusion = neox_args.bias_dropout_fusion - self.gpt_j_residual = neox_args.gpt_j_residual - self.gpt_j_tied = neox_args.gpt_j_tied - self.moe_type = neox_args.moe_type - self.activation = neox_args.activation - - if self.gpt_j_residual: - # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. - # the reduction we use is a simple allreduce for pure Tensor Parallel, - # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) - self.reduce = ( - mpu.mappings.reduce_from_model_parallel_region - if not neox_args.sequence_parallel - else mpu.mappings.reduce_scatter_to_sequence_parallel_region - ) - - # Self attention. - if neox_args.te_mha or neox_args.fp8_mha: - from megatron.model.transformer_engine import TEMultiheadAttention - self.attention = TEMultiheadAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, - ) - + # megatron splits across heads, so we need to make sure each + # head receives the correct matrix + assert mp_size <= num_heads and mp_rank <= mp_size + self.mp_size = mp_size + self.mp_rank = mp_rank + self.num_heads = num_heads + self.slice_size = num_heads // mp_size + self.cached_matrix = None + self.cached_seq_len = None + slopes = torch.Tensor(self._get_slopes(num_heads))[ + mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size + ] + self.register_buffer("slopes", slopes) + + def _get_slopes(self, n): + """ + Get slopes for Alibi positional embedding + n : int = number of heads. + For best performance, restrict n to a power of 2. + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) else: - self.attention = ParallelSelfAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, - ) - - # Layernorm on the output of the attention layer. 
- # If GPT-J residuals are used, this is surpurfulous but leaving it in - # leads to cleaner code - self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) - - # MLP - def get_mlp(**kw): - return ParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - multiple_of=neox_args.mlp_multiple_of, - **kw, - ) - - def get_te_lnmlp(**kw): - from megatron.model.transformer_engine import TELayerNormMLP - return TELayerNormMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - multiple_of=neox_args.mlp_multiple_of, - **kw, + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] ) - self.num_experts = ( - neox_args.moe_num_experts - if layer_number % neox_args.expert_interval == 0 - else 1 - ) - args = neox_args - if self.num_experts <= 1: - if neox_args.te_layernorm_mlp: - self.mlp = get_te_lnmlp() - else: - self.mlp = get_mlp() - else: - from torch import distributed as dist - - if self.num_experts > dist.get_world_size(): - moe_mp_size = 1 - else: - moe_mp_size = dist.get_world_size() // self.num_experts - - if neox_args.moe_type == "deepspeed": - self.mlp = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - elif neox_args.moe_type == "megablocks": - - def integrate_megablocks_with_ds_expert_parallelism(): - # We make megablocks work with DS parallelism. - # - # We fool DS into accepting these MoE parameters as its own DS MoE params, - # which makes things work with the underlying expert parallelism, - # including TED parallelism. - # - # Effectively, we want to: - # - # - Make DS's data parallel gradient all-reduction skip these params. - # - But make these params participate in the expert parallel all-reduction! - # - # Further background: - # - # Normally, with the original megablocks demo codebase, it - # only supports 1 copy of any expert throughout - # the network, since it uses EP group = DP group. - # - # First, we trigger DS initialization of the MoE expert parallel groups and internal state. 
- throwaway = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - throwaway.set_deepspeed_parallelism() - - ep_group = throwaway.deepspeed_moe.ep_group - if args.moe_token_dropping: - self.mlp = MbMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - else: - self.mlp = dMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - - # Next, we trick DS into seeing these as its own MoE params. - for param in self.mlp.parameters(): - if getattr(param, "expert_model_parallel", None) is not None: - # is_moe_param looks for this attr. - param.allreduce = False - param.group_name = throwaway.expert_group_name - - integrate_megablocks_with_ds_expert_parallelism() - - else: - raise KeyError(neox_args.moe_type) - - self.layer_past = None # used to cache k/v pairs in inference - - def _get_bias_dropout(self): - if self.bias_dropout_fusion: - fn = ( - bias_dropout_add_fused_train - if self.training - else bias_dropout_add_fused_inference - ) + def bias(self, seq_len_q, seq_len_k, device, dtype): + # [b, np, sq, sk] + # seq_len_q = x.shape[-2] + # seq_len_k = x.shape[-1] + + # Initialize the AliBi matrix to match the first provided key length; grow it exponentially + # afterwards if longer inputs are provided. This is important for inference, where we will + # encounter progressively longer samples; it should have no effect at training time. + if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: + a = self.cached_matrix else: - fn = get_bias_dropout_add(self.training) - return fn + target_seq_len = ( + seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 + ) + a = -torch.tril( + torch.arange(target_seq_len) + .view(target_seq_len, 1) + .repeat(1, target_seq_len) + + torch.arange(0, -target_seq_len, -1) + ) + a = a.to(device).to(dtype) + slopes = self.slopes.to(a.device).to(a.dtype) + a = a * slopes.view(self.slopes.shape[0], 1, 1) + self.cached_seq_len = target_seq_len + self.cached_matrix = a + + # If the AliBi matrix is larger than the key length, clip it. + if self.cached_seq_len > seq_len_k: + a = self.cached_matrix[:, :seq_len_k, :seq_len_k] + + if seq_len_q != seq_len_k: + # In the train case x has dimensionality [b, np, sq, sk] with sq == sk + # The number of query tokens is equal to the number of key tokens + # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) + # In this case we use the appropriate token index of the cache matrix. + # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used + assert ( + seq_len_q == 1 + ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" + a = a[:, seq_len_k - 1, :].view( + a.shape[0], 1, a.shape[2] + ) # seq_len_k - 1 points to the last token index in the current inference batch. 
+ + return a - def forward(self, x, attention_mask, layer_past=None): - layer_past = layer_past if layer_past is not None else self.layer_past - bias_dropout_fn = self._get_bias_dropout() - moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) - # x: [b, s, h] - - - #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. - if self.neox_args.fp8_mha: - from megatron.model.transformer_engine import TEDelayedScaling - - fp8_recipe = TEDelayedScaling( - neox_args=self.neox_args - ) - fp8_context = fp8_recipe.get_context() + def forward(self, x): + # [b, np, sq, sk] + seq_len_q = x.shape[-2] + seq_len_k = x.shape[-1] + + # Initialize the AliBi matrix to match the first provided key length; grow it exponentially + # afterwards if longer inputs are provided. This is important for inference, where we will + # encounter progressively longer samples; it should have no effect at training time. + if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: + a = self.cached_matrix else: - from contextlib import nullcontext - fp8_context = nullcontext() - - with fp8_context: - if self.gpt_j_residual: - # pseudocode: - # x = x + attn(ln(x)) + mlp(ln(x)) - # this means we can avoid doing the allreduce in the attn / mlp outputs - # to save communication time (we can do a single allreduce after we add mlp / attn outputs). - # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but - # we preserve the functionality for backwards compatibility - - residual = x - # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: - x = self.input_layernorm(x) - x1, x2 = x, x - elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: - x2 = x - x = self.input_layernorm(x) - x1 = x - elif self.neox_args.te_layernorm_mlp: - x1, x2 = self.input_layernorm(x), x - else: - x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - - # attention operator - attention_output, attention_bias = self.attention( - x1, attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - - if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(attention_output), - residual=None, - prob=self.hidden_dropout, - ) - - # mlp operator - mlp_output, mlp_bias = self.mlp(x2) - if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(mlp_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - else: - output = mlp_output - - # output = (x + attn(ln(x)) + mlp(ln(x)) - output = residual + self.reduce(output) - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - - residual = x - - # x = x + attn(ln1(x)) - attention_output, attention_bias = self.attention( - self.input_layernorm(x), attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): - if attention_bias is not None: - # Use special bias_dropout_fn if we have a bias term from the above attention layer - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(residual), - residual=residual, - prob=self.hidden_dropout, - ) - else: - # 
Otherwise just apply dropout + residual - attention_output = ( - torch.nn.functional.dropout( - attention_output, - p=self.hidden_dropout, - training=self.training, - ) - + residual - ) - - # output = x + mlp(ln2(x)) - if self.neox_args.te_layernorm_mlp: - layernorm_output = attention_output - else: - layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) - - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - raise KeyError(self.moe_type) - - with torch.enable_grad() if not self.eval else nullcontext(): - if ( - self.activation == "swiglu" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): - # No dropout either - assert mlp_bias is None - output = mlp_output + attention_output - else: - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(attention_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - - return output, moe_loss - - -class ParallelTransformerLayerPipe(ParallelTransformerLayer): - """Extends ParallelTransformerLayer to forward attention_mask through the pipeline.""" - - def forward(self, args): - assert ( - len(args) == 2 - ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" - hidden_states, attention_mask = args - # we are returning just [hidden_states, mask] - output, moe_loss = super().forward(hidden_states, attention_mask) - # auxiliary output - self.last_moe_loss = moe_loss - return output, attention_mask - - -class ParallelLinearPipe(ParallelLinear): - """Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" - - def forward(self, args): - assert isinstance( - args, torch.Tensor - ), "ParallelLinearPipe expects a single argument - hidden_states" - hidden_state = args - logits, bias = super().forward(hidden_state) - return logits - - -class NormPipe(nn.Module): - """Just a helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" - - def __init__(self, norm_class, hidden_size, eps): - super().__init__() - self.norm = norm_class(hidden_size, eps=eps) - - def forward(self, args): - assert not isinstance( - args, tuple - ), "NormPipe should only receive a single tensor as input" - return self.norm(args) - - -def parallel_lm_logits( - input_, - word_embeddings_weight, - parallel_output, - seq_parallel=False, - seq_dim=1, - bias=None, -): - """LM logits using word embedding weights.""" - # Parallel logits. - if seq_parallel: - # if using Sequence Parallelism, our logits are sharded along the sequence dimension. - # gather them here. (backward pass: reduce-scatter) - input_parallel = mpu.gather_from_sequence_parallel_region( - input_, seq_dim=seq_dim - ) - else: - # Set up backprop all-reduce. - input_parallel = mpu.copy_to_model_parallel_region(input_) - - # Matrix multiply. - if bias is None: - logits_parallel = F.linear(input_parallel, word_embeddings_weight) - else: - logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) - - # Gather if needed. 
- if parallel_output: - return logits_parallel - - return mpu.gather_from_model_parallel_region(logits_parallel) + target_seq_len = ( + seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 + ) + a = -torch.tril( + torch.arange(target_seq_len) + .view(target_seq_len, 1) + .repeat(1, target_seq_len) + + torch.arange(0, -target_seq_len, -1) + ) + a = a.to(x.device).to(x.dtype) + slopes = self.slopes.to(a.device).to(a.dtype) + a = a * slopes.view(self.slopes.shape[0], 1, 1) + self.cached_seq_len = target_seq_len + self.cached_matrix = a + + # If the AliBi matrix is larger than the key length, clip it. + if self.cached_seq_len > seq_len_k: + a = self.cached_matrix[:, :seq_len_k, :seq_len_k] + + if seq_len_q != seq_len_k: + # In the train case x has dimensionality [b, np, sq, sk] with sq == sk + # The number of query tokens is equal to the number of key tokens + # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) + # In this case we use the appropriate token index of the cache matrix. + # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used + assert ( + seq_len_q == 1 + ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" + a = a[:, seq_len_k - 1, :].view( + a.shape[0], 1, a.shape[2] + ) # seq_len_k - 1 points to the last token index in the current inference batch. + + return x + a \ No newline at end of file From 8961dd70196623c3ac8a3a7eeb7242afb83dd917 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sat, 14 Sep 2024 23:41:23 -0500 Subject: [PATCH 10/27] Implemented ColumnParallelLinear with Transformer-Engine --- megatron/model/transformer.py | 61 +++++----- megatron/model/transformer_engine.py | 169 ++++++++++++++++++++++++--- megatron/neox_arguments/neox_args.py | 31 ++--- 3 files changed, 203 insertions(+), 58 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index d112a7461..08e5987de 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -18,8 +18,6 @@ """Transformer.""" import math -from contextlib import nullcontext - import torch import torch.nn.functional as F import torch.nn as nn @@ -50,11 +48,6 @@ from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE -try: - from flash_attn.ops.activations import swiglu -except ImportError: - swiglu = None - # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) @@ -114,6 +107,11 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of + if neox_args.te_linear: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size elif neox_args.expansion_factor: @@ -124,12 +122,7 @@ def __init__( ffn_dim_in = ffn_dim if self.is_gated: # set activation function to be gated implementation - self.activation_func = Gated_Activation( - self.activation_func, - (swiglu is not None) - and (neox_args.activation == "swiglu") - and neox_args.use_flashattn_swiglu, - ) + self.activation_func = Gated_Activation(self.activation_func) # auto scale so gated activations has equal parameters ffn_dim = int(ffn_dim * 2 / 3) ffn_dim_in = ffn_dim // 2 @@ -142,7 +135,7 @@ def __init__( self.multiple_of * 
((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.linear1 = mpu.ColumnParallelLinear( + self.linear1 = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=ffn_dim, @@ -170,7 +163,10 @@ def __init__( def forward(self, hidden_states): # [s, b, intermediate_size] intermediate_parallel, bias_parallel = self.linear1(hidden_states) - + # output = self.linear1(hidden_states) + # print(output) + # import sys + # sys.exit() if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel @@ -186,10 +182,9 @@ def forward(self, hidden_states): class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func, use_swiglu=False): + def __init__(self, activation_func): super().__init__() self.activation_func = activation_func - self.use_swiglu = use_swiglu def forward(self, x, bias=None): x, gate = x.chunk(2, dim=-1) @@ -197,11 +192,8 @@ def forward(self, x, bias=None): bias_1, bias_2 = bias.chunk(2, dim=-1) x = x + bias_1 gate = gate + bias_2 - if not self.use_swiglu: - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x - else: - return swiglu(gate, x) + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x class ParallelLinear(nn.Module): @@ -217,10 +209,16 @@ def __init__( is_last_layer=False, ): super().__init__() + + if neox_args.te_linear: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" if parallelism == "column": - self.final_linear = mpu.ColumnParallelLinear( + self.final_linear = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.padded_vocab_size, @@ -335,6 +333,11 @@ def __init__( ): super().__init__() + if neox_args.te_linear: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -388,7 +391,7 @@ def __init__( if not self.gqa: # Strided linear layer. 
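# Worked example of the ffn_dim rounding in ParallelMLP above (illustrative
# numbers only): with hidden_size = 4096, a gated activation, and
# mlp_multiple_of = 256, the 4*h intermediate size is scaled by 2/3 so the
# gated MLP keeps roughly the same parameter count, then padded up to the
# configured multiple.
hidden_size, multiple_of = 4096, 256
ffn_dim = 4 * hidden_size                                                              # 16384
ffn_dim_in = ffn_dim
ffn_dim = int(ffn_dim * 2 / 3)                                                         # 10922
ffn_dim_in = ffn_dim // 2                                                              # 5461
ffn_dim = (2 * multiple_of) * ((ffn_dim + 2 * multiple_of - 1) // (2 * multiple_of))   # 11264
ffn_dim_in = multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of)             # 5632
assert ffn_dim == 2 * ffn_dim_in  # the doubled width feeds the [value | gate] split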
- self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -398,7 +401,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -1191,7 +1194,7 @@ def forward(self, x, attention_mask, layer_past=None): self.layer_past = presents if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): attention_output = bias_dropout_fn( attention_output, bias=attention_bias.expand_as(attention_output), @@ -1202,7 +1205,7 @@ def forward(self, x, attention_mask, layer_past=None): # mlp operator mlp_output, mlp_bias = self.mlp(x2) if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): output = bias_dropout_fn( mlp_output, bias=mlp_bias.expand_as(mlp_output), @@ -1228,7 +1231,7 @@ def forward(self, x, attention_mask, layer_past=None): if self.use_cache: attention_output, presents = attention_output self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): if attention_bias is not None: # Use special bias_dropout_fn if we have a bias term from the above attention layer attention_output = bias_dropout_fn( @@ -1267,7 +1270,7 @@ def forward(self, x, attention_mask, layer_past=None): else: raise KeyError(self.moe_type) - with torch.enable_grad() if not self.eval else nullcontext(): + with torch.enable_grad(): if ( self.activation == "swiglu" or self.num_experts > 1 diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 338513a97..8a2a2d165 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -1,4 +1,23 @@ +import math + import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from megatron.mpu.initialize import get_model_parallel_rank +from megatron.mpu.initialize import get_model_parallel_world_size +from megatron.mpu.initialize import get_tensor_model_parallel_group +from megatron.mpu.mappings import copy_to_model_parallel_region +from megatron.mpu.mappings import gather_from_model_parallel_region +from megatron.mpu.mappings import reduce_from_model_parallel_region +from megatron.mpu.mappings import scatter_to_model_parallel_region +from megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region +from megatron.mpu.mappings import gather_from_sequence_parallel_region +from megatron.mpu.random import get_cuda_rng_tracker +from megatron.mpu.utils import divide +from megatron.mpu.utils import VocabUtility +from functools import partial try: import transformer_engine as te @@ -57,14 +76,16 @@ class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. 
""" + def __init__(self, in_features, out_features, bias=True): - def __init__(self): - # TODO - return + super(TELinear, self).__init__(in_features,out_features,bias) + + + # self.linear = te.pytorch.Linear(in_features, out_features, bias=use_bias, init_method=weight, **kwargs) - def forward(self, x): - # TODO - return + + # def forward(self, x): + # return self.linear(x) class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): @@ -82,22 +103,138 @@ def forward(self, x): return -class TEColumnParallelLinear(TELinear): +class TEColumnParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `ColumnParallelLinear` layer. """ + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ - def __init__(self): - # TODO - return - - def forward(self, x): - # TODO - return - + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + mup_rescale_parameters=False, + seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + self.parallel_mode="column" + # print("##########################") + # print(self.return_bias) + + super(TEColumnParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, + params_dtype=self.params_dtype) + + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). 
If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() -class TERowParallelLinear(TELinear): + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=0, + stride=self.stride, + ) + + def forward(self, inp): + output = super(TEColumnParallelLinear, self).forward(inp) + if self.skip_bias_add: + return output + else: + return output, None + +class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `RowParallelLinear` layer. diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 0dbdc8be0..cfa51fcfa 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -309,11 +309,6 @@ class NeoXArgsModel(NeoXArgsTemplate): Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ - use_flashattn_swiglu: bool = False - """ - Use flash attention's version of swiglu - """ - scaled_upper_triang_masked_softmax_fusion: bool = False """ Enable fusion of query_key_value_scaling time (upper diagonal) masking and softmax. @@ -501,7 +496,16 @@ class NeoXArgsModel(NeoXArgsTemplate): """ Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + + te_linear: bool = False + """ + Use TransformerEngine for Linear, ColumnParallelLinear, and RowParallelLinear layers. + """ + te_attention: bool = False + """ + Use TransformerEngine for attention layers. 
+ """ @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -1052,9 +1056,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" + train_impl: Literal["normal", "dpo", "rm"] = "normal" """ - Training implementation, can be one of "normal", "dpo", "kto", or "rm" + Training implementation, can be one of "normal", "dpo", or "rm" """ dpo_fp32: bool = True @@ -1062,16 +1066,12 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ - dpo_reference_free: bool = False - """ - Whether to use reference-free DPO. - """ - dpo_beta: float = 0.1 """ Beta value for DPO """ +<<<<<<< HEAD kto_fp32: bool = True """ Whether to cast logits to fp32 for KTO loss calculation. @@ -1095,8 +1095,13 @@ class NeoXArgsTraining(NeoXArgsTemplate): """ kto_beta: float = 0.1 +======= + z_loss: float = 0.0 +>>>>>>> Implemented ColumnParallelLinear with Transformer-Engine """ - Beta value for KTO + Z-loss parameter, only implemented for RM training currently. + https://arxiv.org/pdf/2204.02311 + https://arxiv.org/pdf/2309.10305 """ allow_chopped: bool = True From 5162d54874fc84d4406c295693ba96ecbe3ada86 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 02:23:58 -0500 Subject: [PATCH 11/27] Implemented RowParallelLinear with Transformer-Engine --- megatron/model/transformer.py | 22 ++-- megatron/model/transformer_engine.py | 159 ++++++++++++++++++++++++--- megatron/model/utils.py | 13 +++ megatron/neox_arguments/neox_args.py | 9 +- 4 files changed, 172 insertions(+), 31 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 08e5987de..360122f33 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -47,6 +47,7 @@ ) from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE +from .utils import linear_implementation_router # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -107,10 +108,7 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size @@ -147,7 +145,7 @@ def __init__( bias=neox_args.use_bias_in_mlp, ) # Project back to h. 
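# Minimal sketch of the Transformer-Engine linear wrapping used by the
# TE*ParallelLinear classes in this series (illustrative only; it assumes
# Transformer Engine is installed and a CUDA device is available, and it leaves
# tensor parallelism at the default tp_size of 1). The keyword arguments shown
# are among those the wrappers forward to te.pytorch.Linear.
import torch
import transformer_engine.pytorch as te

linear = te.Linear(
    in_features=1024,
    out_features=4096,
    bias=True,
    parallel_mode="column",  # shard the output dim, as megatron's ColumnParallelLinear does
    return_bias=True,        # return (output, bias) instead of fusing the bias add,
                             # mirroring megatron's skip_bias_add convention
    device="cuda",
)
out, bias = linear(torch.randn(8, 2, 1024, device="cuda"))  # [sq, b, h] input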
- self.linear2 = mpu.RowParallelLinear( + self.linear2 = RowParallelLinear( neox_args=neox_args, input_size=ffn_dim_in, output_size=neox_args.hidden_size, @@ -210,10 +208,7 @@ def __init__( ): super().__init__() - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" @@ -247,7 +242,7 @@ def __init__( # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here # ) else: # Not using cross entropy loss for RMs - self.rm_linear = mpu.RowParallelLinear( + self.rm_linear = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=1, @@ -333,10 +328,7 @@ def __init__( ): super().__init__() - if neox_args.te_linear: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear - else: - from megatron.mpu import ColumnParallelLinear + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" @@ -509,7 +501,7 @@ def __init__( self.attention_dropout = nn.Dropout(self.dropout_p) # Output. - self.dense = mpu.RowParallelLinear( + self.dense = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size, diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 8a2a2d165..5de2c3459 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -107,11 +107,6 @@ class TEColumnParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `ColumnParallelLinear` layer. - """ - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. Arguments: input_size: first dimension of matrix A. @@ -145,7 +140,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, - seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. + seq_dim=0, ): # Keep input parameters self.input_size = input_size @@ -186,6 +181,12 @@ def width_mult(self): ) return self.weight.infshape.width_mult() + def set_parallel_output(self, value: bool): + assert isinstance(value, bool) + self.gather_output = ( + not value + ) # if gather_output is True, parallel output is False, so we set the opposite + # Copied from Mup def _rescale_parameters(self): """Rescale parameters to convert SP initialization to μP initialization. @@ -227,8 +228,21 @@ def mup_reinitialize_weights(self, neox_args): stride=self.stride, ) - def forward(self, inp): - output = super(TEColumnParallelLinear, self).forward(inp) + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) + + if self.gather_output: + # All-gather across the partitions. + assert ( + not self.sequence_parallel + ), "sequence_parallel=True and gather_output=True are incompatible!" 
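# Usage sketch for the linear_implementation_router helper wired in above
# (illustrative; the helper itself is added in this patch under
# megatron/model/utils.py, and the te_columnparallel / te_rowparallel flags it
# reads are added to neox_args in the same patch). A bare namespace stands in
# for the full NeoXArgs object, since only those two flags are consulted;
# importing the module still requires the patched repo and its dependencies.
from types import SimpleNamespace
from megatron.model.utils import linear_implementation_router

flags = SimpleNamespace(te_columnparallel=True, te_rowparallel=False)
ColumnParallelLinear, RowParallelLinear = linear_implementation_router(flags)
# ColumnParallelLinear -> TEColumnParallelLinear (Transformer-Engine backed)
# RowParallelLinear    -> megatron.mpu.RowParallelLinear (stock implementation)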
+ output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + if self.skip_bias_add: return output else: @@ -238,15 +252,132 @@ class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `RowParallelLinear` layer. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. """ + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + parallel_output=False, + mup_rescale_parameters=False, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + # Divide the weight matrix along the last dimension. + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + self.input_is_parallel = input_is_parallel + self.sequence_parallel = neox_args.sequence_parallel - def __init__(self): - # TODO - return + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + self.parallel_mode="row" + + # if self.input_is_parallel: + # self.input_size = divide(self.input_size, self.world_size) - def forward(self, x): - # TODO - return + super(TERowParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, + params_dtype=self.params_dtype) + + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() + + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. 
" + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=self.keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=1, + stride=self.stride, + ) + + def set_parallel_output(self, parallel_output: bool): + assert isinstance(parallel_output, bool) + self.parallel_output = parallel_output + + def forward(self, inp, **kwargs): + # if not self.input_is_parallel: + # inp = scatter_to_model_parallel_region(inp) + + output = super(TERowParallelLinear, self).forward(inp, **kwargs) + if self.skip_bias_add: + return output + else: + return output, None class TEDotProductAttention(te.pytorch.DotProductAttention): diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 8176f1f7a..d1ec2a347 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -402,3 +402,16 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): for name, param in module_.named_parameters(): if param.requires_grad: param.register_hook(reduce_weight_grads_from_model_parallel_region) + + +def linear_implementation_router(neox_args): + if neox_args.te_columnparallel: + from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + else: + from megatron.mpu import ColumnParallelLinear + if neox_args.te_rowparallel: + from megatron.model.transformer_engine import TERowParallelLinear as RowParallelLinear + else: + from megatron.mpu import RowParallelLinear + + return ColumnParallelLinear, RowParallelLinear \ No newline at end of file diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index cfa51fcfa..d295bed7b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -497,9 +497,14 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ - te_linear: bool = False + te_columnparallel: bool = False """ - Use TransformerEngine for Linear, ColumnParallelLinear, and RowParallelLinear layers. + Use TransformerEngine for RowParallelLinear layer. + """ + + te_rowparallel: bool = False + """ + Use TransformerEngine for ColumnParallelLinear layer. 
""" te_attention: bool = False From 3cad89c907fe7b283d37cb5d9466b81c81c356df Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 16:25:01 -0500 Subject: [PATCH 12/27] Implemented LayerNormMLP with Transformer-Engine --- megatron/model/transformer.py | 29 +++++++-- megatron/model/transformer_engine.py | 97 ++++++++++++++++++++++------ megatron/neox_arguments/neox_args.py | 4 +- 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 360122f33..5003ef1d5 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -132,7 +132,6 @@ def __init__( ffn_dim_in = int( self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.linear1 = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -383,7 +382,7 @@ def __init__( if not self.gqa: # Strided linear layer. - self.query_key_value = ColumnParallelLinear( + self.query_key_value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -393,7 +392,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = ColumnParallelLinear( + self.query_key_value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -1045,6 +1044,17 @@ def get_mlp(**kw): **kw, ) + def get_te_lnmlp(**kw): + from megatron.model.transformer_engine import TELayerNormMLP + return TELayerNormMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + self.num_experts = ( neox_args.moe_num_experts if layer_number % neox_args.expert_interval == 0 @@ -1052,7 +1062,10 @@ def get_mlp(**kw): ) args = neox_args if self.num_experts <= 1: - self.mlp = get_mlp() + if neox_args.te_layernorm_mlp: + self.mlp = get_te_lnmlp() + else: + self.mlp = get_mlp() else: from torch import distributed as dist @@ -1171,9 +1184,15 @@ def forward(self, x, attention_mask, layer_past=None): residual = x # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied: + if self.gpt_j_tied and not neox_args.te_layernorm_mlp: x = self.input_layernorm(x) x1, x2 = x, x + elif self.gpt_j_tied and neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x else: x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 5de2c3459..7c69cee1f 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -5,6 +5,8 @@ import torch.nn.init as init from torch.nn.parameter import Parameter +from megatron.model.transformer import Gated_Activation +from megatron.model.activations import get_activation from megatron.mpu.initialize import get_model_parallel_rank from megatron.mpu.initialize import get_model_parallel_world_size from megatron.mpu.initialize import get_tensor_model_parallel_group @@ -88,19 +90,84 @@ def __init__(self, in_features, out_features, bias=True): # return self.linear(x) -class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): +class TELayerNormMLP(te.pytorch.LayerNormMLP): """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines - 
layernorm and linear layers + Wrapper for the Transformer-Engine's `LayerNormMLP` layer that combines + layernorm and followed by the MLP module, consisting of 2 successive + linear transformations, separated by the GeLU activation. """ - def __init__(self): - # TODO - return + def __init__( + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + multiple_of=256, + MOE=False, + MoE_mp_size=1, + bias=True + ): + self.activation_func, self.is_gated = get_activation(neox_args) + self.activation_type = neox_args.activation + self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of + self.bias = bias + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method - def forward(self, x): - # TODO - return + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype=neox_args.params_dtype + self.set_parallel_mode=False + if world_size > 1: + self.set_parallel_mode=True + + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation(self.activation_func) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) + ) + + if neox_args.norm in ['layernorm','te_layernorm']: + self.eps=1.0e-5 + self.normalization = 'LayerNorm' + elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + self.eps=1.0e-8 + self.normalization = 'RMSNorm' + #TODO handle case if norm is not rmsnorm or layernorm + #TODO check if activation in list ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, + #‘swiglu’, ‘qgelu’, ‘srelu’ + #TODO handle MOE and mup + + super(TELayerNormMLP, self).__init__(hidden_size=neox_args.hidden_size, ffn_hidden_size=ffn_dim, + eps=self.eps, bias=self.bias, normalization=self.normalization, activation=neox_args.activation, + init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, + device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, + return_bias=neox_args.use_bias_in_mlp, params_dtype=self.params_dtype, seq_length=self.seq_len, + micro_batch_size=self.batch_size) class TEColumnParallelLinear(te.pytorch.Linear): @@ -234,15 +301,6 @@ def forward(self, inp, **kwargs): output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) - if self.gather_output: - # All-gather across the partitions. - assert ( - not self.sequence_parallel - ), "sequence_parallel=True and gather_output=True are incompatible!" 
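# Minimal instantiation sketch for the fused TELayerNormMLP wrapper above
# (illustrative; assumes Transformer Engine and a CUDA device, no tensor
# parallelism). The point of the module is that layernorm, linear1, the
# activation, and linear2 collapse into one fused block, which is why the
# transformer layer skips its separate post_attention_layernorm when
# te_layernorm_mlp is enabled. Keyword arguments are a subset of those the
# wrapper forwards.
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    eps=1.0e-5,
    normalization="LayerNorm",
    activation="gelu",
    return_bias=True,  # hand the bias back for the fused bias-dropout-add path
    device="cuda",
)
out, bias = mlp(torch.randn(8, 2, 1024, device="cuda"))  # [sq, b, h] in and out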
- output = gather_from_model_parallel_region(output_parallel) - else: - output = output_parallel - if self.skip_bias_add: return output else: @@ -305,9 +363,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype self.parallel_mode="row" - - # if self.input_is_parallel: - # self.input_size = divide(self.input_size, self.world_size) super(TERowParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index d295bed7b..7484212b1 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -507,9 +507,9 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for ColumnParallelLinear layer. """ - te_attention: bool = False + te_layernorm_mlp: bool = False """ - Use TransformerEngine for attention layers. + Use TransformerEngine for LayerNormMLP layer. """ @dataclass From eedb6c24d808a482aec3918a9f74ec491674c061 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:10:27 -0500 Subject: [PATCH 13/27] Implemented MultiheadAttention with Transformer-Engine --- megatron/model/transformer.py | 132 +++++++++++++++++++++++++-- megatron/model/transformer_engine.py | 117 ++++++++++++++++++++---- megatron/neox_arguments/neox_args.py | 5 + 3 files changed, 229 insertions(+), 25 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 5003ef1d5..099a0a899 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -327,8 +327,6 @@ def __init__( ): super().__init__() - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -748,6 +746,106 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe ) + def te_attention( + self, query_layer, key_layer, value_layer, layer_past, attention_mask + ): + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # preallocating result tensor: [b * np, sq, sk] + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device(), + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if self.use_cache: + with torch.no_grad(): + attention_mask = attention_mask[ + ..., : attention_scores.size(3), : attention_scores.size(3) + ] + + # =========================== + # Attention probs and dropout + # =========================== + + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + attention_scores += rpe # [1, np, sq, sk] + + if self.pos_emb == "alibi": + attention_scores = self.alibi_embed(attention_scores) + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + return context_layer + def gqa_project(self, hidden_states, attention_mask, layer_past=None): # QKV projection and separation into separate Q/K/V layers for GQA, # where KV projections may be smaller than Q projection. @@ -1016,7 +1114,9 @@ def __init__( ) # Self attention. - self.attention = ParallelSelfAttention( + if neox_args.te_mha: + from megatron.model.transformer_engine import TEMultiheadAttention + self.attention = TEMultiheadAttention( neox_args=neox_args, attention_mask_func=attention_mask_func, init_method=init_method, @@ -1026,7 +1126,20 @@ def __init__( use_cache=self.use_cache, rotary=rotary, parallel_output=self.gpt_j_residual, - ) + ) + + else: + self.attention = ParallelSelfAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) # Layernorm on the output of the attention layer. 
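# Pure-torch shape check for the attention-score computation above
# (illustrative, no model state): batch and heads are folded together so the
# Q*K^T products run as one batched GEMM, then the result is viewed back to
# [b, np, sq, sk]. With beta=0.0 the preallocated tensor is ignored, so
# torch.empty is safe here.
import torch

b, np_heads, sq, sk, hn = 2, 4, 8, 8, 16
query = torch.randn(sq, b, np_heads, hn)  # [sq, b, np, hn]
key = torch.randn(sk, b, np_heads, hn)    # [sk, b, np, hn]

q = query.view(sq, b * np_heads, hn)
k = key.view(sk, b * np_heads, hn)
scores = torch.baddbmm(
    torch.empty(b * np_heads, sq, sk),
    q.transpose(0, 1),                    # [b*np, sq, hn]
    k.transpose(0, 1).transpose(1, 2),    # [b*np, hn, sk]
    beta=0.0,
    alpha=1.0 / hn**0.5,
)
assert scores.view(b, np_heads, sq, sk).shape == (b, np_heads, sq, sk)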
# If GPT-J residuals are used, this is surpurfulous but leaving it in @@ -1184,14 +1297,14 @@ def forward(self, x, attention_mask, layer_past=None): residual = x # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not neox_args.te_layernorm_mlp: + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: x = self.input_layernorm(x) x1, x2 = x, x - elif self.gpt_j_tied and neox_args.te_layernorm_mlp: + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: x2 = x x = self.input_layernorm(x) x1 = x - elif neox_args.te_layernorm_mlp: + elif self.neox_args.te_layernorm_mlp: x1, x2 = self.input_layernorm(x), x else: x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) @@ -1263,7 +1376,10 @@ def forward(self, x, attention_mask, layer_past=None): ) # output = x + mlp(ln2(x)) - layernorm_output = self.post_attention_layernorm(attention_output) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output + else: + layernorm_output = self.post_attention_layernorm(attention_output) mlp_bias = torch.tensor( 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype ) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 7c69cee1f..9a8c0a506 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -78,16 +78,51 @@ class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ - def __init__(self, in_features, out_features, bias=True): + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + init_method=init.xavier_normal_, + stride=1, + skip_bias_add=False, + mup_rescale_parameters=False, + seq_dim=0, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size - super(TELinear, self).__init__(in_features,out_features,bias) - + self.skip_bias_add = skip_bias_add + self.use_bias = bias - # self.linear = te.pytorch.Linear(in_features, out_features, bias=use_bias, init_method=weight, **kwargs) + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype=neox_args.params_dtype + # print("##########################") + # print(self.return_bias) + + super(TELinear, self).__init__(in_features=self.input_size, out_features=self.output_size, + bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), return_bias=self.skip_bias_add, params_dtype=self.params_dtype) + + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TELinear, self).forward(inp, **kwargs) - # def forward(self, x): - # return self.linear(x) + if self.skip_bias_add: + return output + else: + return output, None class TELayerNormMLP(te.pytorch.LayerNormMLP): @@ -121,7 +156,7 @@ def __init__( self.tp_group = get_tensor_model_parallel_group() self.sequence_parallel = neox_args.sequence_parallel self.seq_len = neox_args.seq_length - self.batch_size = neox_args.train_micro_batch_size_per_gpu + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu self.params_dtype=neox_args.params_dtype self.set_parallel_mode=False if world_size > 1: @@ -166,8 +201,8 @@ def __init__( init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, 
device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, - return_bias=neox_args.use_bias_in_mlp, params_dtype=self.params_dtype, seq_length=self.seq_len, - micro_batch_size=self.batch_size) + return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, + micro_batch_size=self.micro_batch_size) class TEColumnParallelLinear(te.pytorch.Linear): @@ -435,19 +470,67 @@ def forward(self, inp, **kwargs): return output, None -class TEDotProductAttention(te.pytorch.DotProductAttention): +class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ - Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also has "flash attention" enabled. """ - def __init__(self): - # TODO - return + def __init__(self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False): - def forward(self, x): - # TODO - return + self.attention_mask_func = attention_mask_func + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method + self.layer_number = layer_number + 1 + + world_size = get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype=neox_args.params_dtype + self.set_parallel_mode=False + if world_size > 1: + self.set_parallel_mode=True + + if neox_args.norm in ['layernorm','te_layernorm']: + self.eps=1.0e-5 + self.normalization = 'LayerNorm' + elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + self.eps=1.0e-8 + self.normalization = 'RMSNorm' + + if ( + not neox_args.num_kv_heads + or neox_args.num_kv_heads == neox_args.num_attention_heads + ): + self.gqa = False + self.num_kv_heads = None + else: + self.gqa = True + self.num_kv_heads = neox_args.num_kv_heads + + super(TEMultiheadAttention, self).__init__(hidden_size=neox_args.hidden_size, num_attention_heads=neox_args.num_attention_heads, + attention_dropout=neox_args.attention_dropout, layernorm_epsilon=self.eps, init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, layer_number=self.layer_number, + window_size=neox_args.sliding_window_width, num_gqa_groups=self.num_kv_heads, input_layernorm=False, + normalization=self.normalization, bias=True, device=torch.cuda.current_device(), + set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, + tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True) + + def forward(self, hidden_states, attention_mask, layer_past=None, **kwargs): + output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, **kwargs) + return output class TEDelayedScaling(te.common.recipe.DelayedScaling): diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 7484212b1..b6adf727e 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -512,6 +512,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for LayerNormMLP layer. """ + te_mha: bool = False + """ + Use TransformerEngine for MultiheadAttention layer. 
+ """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): """ From 36ad680ecf1e88f5c7c1ab9434d5d59998ecd5eb Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:23:11 -0500 Subject: [PATCH 14/27] Cleaned up transformer.py --- megatron/model/transformer.py | 137 +++++++--------------------------- 1 file changed, 25 insertions(+), 112 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 099a0a899..62f316f3e 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -18,6 +18,8 @@ """Transformer.""" import math +from contextlib import nullcontext + import torch import torch.nn.functional as F import torch.nn as nn @@ -47,6 +49,12 @@ ) from megatron.model.utils import configure_sparse_attention from deepspeed.moe.layer import MoE + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + from .utils import linear_implementation_router # flags required to enable jit fusion kernels @@ -120,7 +128,12 @@ def __init__( ffn_dim_in = ffn_dim if self.is_gated: # set activation function to be gated implementation - self.activation_func = Gated_Activation(self.activation_func) + self.activation_func = Gated_Activation( + self.activation_func, + (swiglu is not None) + and (neox_args.activation == "swiglu") + and neox_args.use_flashattn_swiglu, + ) # auto scale so gated activations has equal parameters ffn_dim = int(ffn_dim * 2 / 3) ffn_dim_in = ffn_dim // 2 @@ -160,10 +173,6 @@ def __init__( def forward(self, hidden_states): # [s, b, intermediate_size] intermediate_parallel, bias_parallel = self.linear1(hidden_states) - # output = self.linear1(hidden_states) - # print(output) - # import sys - # sys.exit() if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel @@ -179,9 +188,10 @@ def forward(self, hidden_states): class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func): + def __init__(self, activation_func, use_swiglu=False): super().__init__() self.activation_func = activation_func + self.use_swiglu = use_swiglu def forward(self, x, bias=None): x, gate = x.chunk(2, dim=-1) @@ -189,8 +199,11 @@ def forward(self, x, bias=None): bias_1, bias_2 = bias.chunk(2, dim=-1) x = x + bias_1 gate = gate + bias_2 - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x + if not self.use_swiglu: + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x + else: + return swiglu(gate, x) class ParallelLinear(nn.Module): @@ -746,106 +759,6 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe ) - def te_attention( - self, query_layer, key_layer, value_layer, layer_past, attention_mask - ): - # =================================== - # Raw attention scores. 
[b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - output_size[2], output_size[0] * output_size[1], -1 - ) - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocating result tensor: [b * np, sq, sk] - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # ================================================== - # Update attention mask for inference. [b, np, sq, sk] - # ================================================== - - if self.use_cache: - with torch.no_grad(): - attention_mask = attention_mask[ - ..., : attention_scores.size(3), : attention_scores.size(3) - ] - - # =========================== - # Attention probs and dropout - # =========================== - - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - attention_scores += rpe # [1, np, sq, sk] - - if self.pos_emb == "alibi": - attention_scores = self.alibi_embed(attention_scores) - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - with mpu.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - return context_layer - def gqa_project(self, hidden_states, attention_mask, layer_past=None): # QKV projection and separation into separate Q/K/V layers for GQA, # where KV projections may be smaller than Q projection. 
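# Minimal sketch of the gated-activation path restored above (illustrative):
# linear1 produces a doubled-width [value | gate] tensor, the activation is
# applied to the gate half, and the two halves are multiplied. The optional
# flash-attn swiglu kernel is intended to fuse this same silu(gate) * value
# computation.
import torch
import torch.nn.functional as F

fused = torch.randn(2, 8)             # stand-in for the linear1 output
value, gate = fused.chunk(2, dim=-1)  # split the last dimension in half
out = F.silu(gate) * value            # swiglu-style gating
assert out.shape == (2, 4)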
@@ -1318,7 +1231,7 @@ def forward(self, x, attention_mask, layer_past=None): self.layer_past = presents if attention_bias is not None: - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): attention_output = bias_dropout_fn( attention_output, bias=attention_bias.expand_as(attention_output), @@ -1329,7 +1242,7 @@ def forward(self, x, attention_mask, layer_past=None): # mlp operator mlp_output, mlp_bias = self.mlp(x2) if mlp_bias is not None: - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): output = bias_dropout_fn( mlp_output, bias=mlp_bias.expand_as(mlp_output), @@ -1355,7 +1268,7 @@ def forward(self, x, attention_mask, layer_past=None): if self.use_cache: attention_output, presents = attention_output self.layer_past = presents - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): if attention_bias is not None: # Use special bias_dropout_fn if we have a bias term from the above attention layer attention_output = bias_dropout_fn( @@ -1397,7 +1310,7 @@ def forward(self, x, attention_mask, layer_past=None): else: raise KeyError(self.moe_type) - with torch.enable_grad(): + with torch.enable_grad() if not self.eval else nullcontext(): if ( self.activation == "swiglu" or self.num_experts > 1 From 6963103dc62e4b8c455ba5616c34c337f83c501f Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:25:20 -0500 Subject: [PATCH 15/27] Cleaned up neox_args --- megatron/neox_arguments/neox_args.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index b6adf727e..f7de09872 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -309,6 +309,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ + use_flashattn_swiglu: bool = False + """ + Use flash attention's version of swiglu + """ + scaled_upper_triang_masked_softmax_fusion: bool = False """ Enable fusion of query_key_value_scaling time (upper diagonal) masking and softmax. @@ -1066,9 +1071,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm"] = "normal" + train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" """ - Training implementation, can be one of "normal", "dpo", or "rm" + Training implementation, can be one of "normal", "dpo", "kto", or "rm" """ dpo_fp32: bool = True @@ -1076,12 +1081,16 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ + dpo_reference_free: bool = False + """ + Whether to use reference-free DPO. + """ + dpo_beta: float = 0.1 """ Beta value for DPO """ -<<<<<<< HEAD kto_fp32: bool = True """ Whether to cast logits to fp32 for KTO loss calculation. @@ -1105,13 +1114,8 @@ class NeoXArgsTraining(NeoXArgsTemplate): """ kto_beta: float = 0.1 -======= - z_loss: float = 0.0 ->>>>>>> Implemented ColumnParallelLinear with Transformer-Engine """ - Z-loss parameter, only implemented for RM training currently. 
- https://arxiv.org/pdf/2204.02311 - https://arxiv.org/pdf/2309.10305 + Beta value for KTO """ allow_chopped: bool = True From afc9c9293c509a09345ef506090306d1638bc5df Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Sun, 15 Sep 2024 19:26:04 -0500 Subject: [PATCH 16/27] Cleaned up neox_args --- megatron/neox_arguments/neox_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index f7de09872..b74556bad 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1118,6 +1118,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): Beta value for KTO """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. From a0e7acd79492e92a691ff29e44db2e059f680635 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Thu, 26 Sep 2024 08:18:51 +0000 Subject: [PATCH 17/27] - Fixed TE_MHA and added rope support - Implemented delayed scaling --- megatron/model/positional_embeddings.py | 1578 ++++++++++++++++++++--- megatron/model/transformer.py | 260 ++-- megatron/model/transformer_engine.py | 127 +- megatron/neox_arguments/neox_args.py | 40 + 4 files changed, 1638 insertions(+), 367 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fcded9e96..fdf384a4f 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -1,4 +1,7 @@ -# Copyright (c) 2024, EleutherAI +# Copyright (c) 2024 EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,241 +15,1406 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +"""Transformer.""" + import math +from contextlib import nullcontext + +import torch +import torch.nn.functional as F +import torch.nn as nn +from pkg_resources import packaging +from importlib.metadata import version +from .norms import get_norm +from megatron import mpu +from megatron.model import megablocks_utils +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.activations import get_activation +from megatron.model.utils import exists, get_fusion_type +from megatron.model.positional_embeddings import ( + RotaryEmbedding, + apply_rotary_pos_emb_torch, + apply_rotary_pos_emb, + AliBi, +) +from megatron.model.fused_rope import ( + FusedRoPEFunc, + fused_apply_rotary_pos_emb_cached, +) +from megatron.model.fused_bias_dropout import ( + get_bias_dropout_add, + bias_dropout_add_fused_train, + bias_dropout_add_fused_inference, +) +from megatron.model.utils import configure_sparse_attention +from deepspeed.moe.layer import MoE + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + +from .utils import linear_implementation_router + +# flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + kv: number of key or value heads + p: number of model parallel partitions + np: n/p + kvp: kv/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters + attention_mask_func: a function that takes `unmasked-attention-scores` + with size [b, np, s, s] and an `attention-mask` and will apply + the masking. The function should return a masked score of the + same size [b, np, s, s]. + masked-attention-scores = attention_mask_func( + unmasked-attention-scores, attention-mask) +""" + + +class ParallelMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. 
+ """ -class SinusoidalPositionalEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.precision = precision - - def forward(self, x, seq_dim=1): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq) - if self.precision == torch.bfloat16: - sinusoid_inp = sinusoid_inp.float() - sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() - if self.precision == torch.bfloat16: - sin, cos = sin.bfloat16(), cos.bfloat16() - emb = torch.cat((sin, cos), dim=-1) - return emb[None, :, :] - - -class RotaryEmbedding(torch.nn.Module): def __init__( - self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + multiple_of=256, + MOE=False, + MoE_mp_size=1, ): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - self.max_seq_len = max_seq_len - self.base = base - self.dim = dim + assert ( + neox_args.intermediate_size == None or neox_args.expansion_factor == None + ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" + + self.activation_func, self.is_gated = get_activation(neox_args) + self.activation_type = neox_args.activation + self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of - # precompute cos_cached, sin_cached in fp32 - cos_cached, sin_cached, inv_freq = self._prepare_cache( - max_seq_len, precision, base + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation( + self.activation_func, + (swiglu is not None) + and (neox_args.activation == "swiglu") + and neox_args.use_flashattn_swiglu, + ) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) + ) + self.linear1 = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=ffn_dim, + gather_output=False, + init_method=init_method, + skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, + ) + # Project back to h. 
+ self.linear2 = RowParallelLinear( + neox_args=neox_args, + input_size=ffn_dim_in, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + parallel_output=parallel_output, + skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, ) - self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) - self.cos_cached = cos_cached - self.sin_cached = sin_cached + def forward(self, hidden_states): + # [s, b, intermediate_size] + intermediate_parallel, bias_parallel = self.linear1(hidden_states) + if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): + intermediate_parallel = self.activation_func( + intermediate_parallel, bias_parallel + ) + else: + intermediate_parallel = self.activation_func( + intermediate_parallel + bias_parallel + ) - def _prepare_cache(self, seq_len, precision, base): - # precompute cos_cached, sin_cached in fp32 - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + # [s, b, h] + output, output_bias = self.linear2(intermediate_parallel) + return output, output_bias - t = torch.arange(seq_len).type_as(inv_freq) - freqs = torch.einsum("i,j->ij", t, inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - cos_cached = emb.cos()[:, None, None, :] - sin_cached = emb.sin()[:, None, None, :] +class Gated_Activation(torch.nn.Module): + def __init__(self, activation_func, use_swiglu=False): + super().__init__() + self.activation_func = activation_func + self.use_swiglu = use_swiglu - return ( - cos_cached.to(precision), - sin_cached.to(precision), - inv_freq.to(precision), - ) + def forward(self, x, bias=None): + x, gate = x.chunk(2, dim=-1) + if bias is not None: + bias_1, bias_2 = bias.chunk(2, dim=-1) + x = x + bias_1 + gate = gate + bias_2 + if not self.use_swiglu: + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x + else: + return swiglu(gate, x) + + +class ParallelLinear(nn.Module): + """ + A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size + """ - def forward(self, x, seq_dim=0, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] + def __init__( + self, + neox_args, + parallel_output=True, + init_method=nn.init.xavier_normal_, + is_last_layer=False, + ): + super().__init__() - assert seq_len <= self.max_seq_len + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - if seq_len != self.max_seq_len: - # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base) - return ( - self.cos_cached[:seq_len, ...].to(x.device), - self.sin_cached[:seq_len, ...].to(x.device), + self.is_rm = neox_args.train_impl == "rm" + parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" + if parallelism == "column": + self.final_linear = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.padded_vocab_size, + bias=False, + init_method=init_method, + gather_output=not parallel_output, + skip_bias_add=False, + mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here + seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. 
Parallel comms must gather along dim=1 rather than dim=0 ) else: - return self.cos_cached.to(x.device), self.sin_cached.to(x.device) + if not self.is_rm: + print( + 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' + ) + exit() + # self.final_linear = mpu.RowParallelLinear( + # neox_args=neox_args, + # input_size=neox_args.hidden_size, + # output_size=neox_args.padded_vocab_size, + # bias=False, + # input_is_parallel=False, + # init_method=init_method, + # parallel_output=parallel_output, + # skip_bias_add=False, + # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here + # ) + else: # Not using cross entropy loss for RMs + self.rm_linear = RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=1, + bias=False, + input_is_parallel=False, + init_method=init_method, + parallel_output=False, + skip_bias_add=False, + mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here + ) + def forward(self, hidden_states): + if not self.is_rm: + return self.final_linear(hidden_states) + else: + return self.rm_linear(hidden_states) + + +class _MegablocksAdapter(nn.Module): + def __init__( + self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group + ): + super().__init__() + megablocks_utils.assert_megablocks_is_available() + args = megablocks_utils.as_megablocks_args(neox_args) + args.device = torch.cuda.current_device() + args.init_method = init_method + args.output_layer_init_method = output_layer_init_method + + # NOTE: Shard the MoE layers over the data parallel group. Expert + # parallel sharding and data parallel sharding could be decoupled + # by extending the optimizer to handle data parallel reductions for + # MoE and non-MoE parameters separately. + if args.moe_expert_model_parallelism: + args.expert_parallel_group = ep_group + + self.moe = layer_cls(args) -# rotary pos emb helpers: + def forward(self, x): + return self.moe.forward(x) -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat( - (-x2, x1), dim=x1.ndim - 1 - ) # dim=-1 triggers a bug in earlier torch versions +class MbMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.moe.MoE, + init_method, + output_layer_init_method, + ep_group, + ) -@torch.jit.script -def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): - cos, sin = ( - cos[offset : q.shape[0] + offset, ...], - sin[offset : q.shape[0] + offset, ...], - ) - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) +class dMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.dmoe.dMoE, + init_method, + output_layer_init_method, + ep_group, + ) -def apply_rotary_pos_emb_torch( - q, k, cos, sin, offset: int = 0 -): # jitting fails with bf16 - cos, sin = ( - cos[offset : q.shape[0] + offset, ...], - sin[offset : q.shape[0] + offset, ...], - ) - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) +class ParallelSelfAttention(nn.Module): + """Parallel self-attention layer abstract class. 
+ Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + """ -class AliBi(torch.nn.Module): - def __init__(self, num_heads, mp_size=1, mp_rank=1): + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False, + ): super().__init__() - # megatron splits across heads, so we need to make sure each - # head receives the correct matrix - assert mp_size <= num_heads and mp_rank <= mp_size - self.mp_size = mp_size - self.mp_rank = mp_rank - self.num_heads = num_heads - self.slice_size = num_heads // mp_size - self.cached_matrix = None - self.cached_seq_len = None - slopes = torch.Tensor(self._get_slopes(num_heads))[ - mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size - ] - self.register_buffer("slopes", slopes) - - def _get_slopes(self, n): - """ - Get slopes for Alibi positional embedding - n : int = number of heads. - For best performance, restrict n to a power of 2. - """ - - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + self._get_slopes(2 * closest_power_of_2)[0::2][ - : n - closest_power_of_2 - ] + + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + + self.fp16 = neox_args.precision == "fp16" + self.bf16 = neox_args.precision == "bfloat16" + self.attention_mask_func = attention_mask_func + self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling + self.use_cache = use_cache + self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = layer_number + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + self.hidden_size_per_attention_head = mpu.divide( + neox_args.hidden_size, neox_args.num_attention_heads + ) + self.num_attention_heads_per_partition = mpu.divide( + neox_args.num_attention_heads, world_size + ) + self.pos_emb = neox_args.pos_emb + + self.use_qk_layernorm = neox_args.use_qk_layernorm + if self.use_qk_layernorm: + norm, eps = get_norm(neox_args) + self.qk_layernorm = norm( + [ + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ], + eps=eps, + ) + + self.sliding_window_width = neox_args.sliding_window_width + + if ( + not neox_args.num_kv_heads + or neox_args.num_kv_heads == neox_args.num_attention_heads + ): + self.gqa = False + else: + self.gqa = True + if self.gqa: + self.num_kv_heads_per_partition = mpu.divide( + neox_args.num_kv_heads, world_size + ) # we do not yet clone KV heads in MQA across TP ranks... + self.kv_hidden_size = ( + neox_args.num_kv_heads * self.hidden_size_per_attention_head + ) # how large the total hidden dim for each of K and V is + else: + self.num_kv_heads_per_partition = self.num_attention_heads_per_partition + self.kv_hidden_size = neox_args.hidden_size + + if not self.gqa: + # Strided linear layer. 
+ self.query_key_value = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=3 * neox_args.hidden_size, + gather_output=False, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + ) + else: + # QKV proj is smaller if we are using GQA / MQA + self.query_key_value = ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, + gather_output=False, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + ) + + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = max(1, self.layer_number) + self.norm_factor *= coeff + + if neox_args.use_mup: + self.norm_factor = self.hidden_size_per_attention_head + + self.rpe = rpe + + if self.pos_emb == "alibi": + self.alibi_embed = AliBi( + neox_args.num_attention_heads, + neox_args.model_parallel_size, + mpu.get_model_parallel_rank(), + ) + + # TODO: this arg shouldn't need to be passed in - get from neox_args + if rotary: + if neox_args.rotary_pct == 1: + self.rotary_ndims = None + else: + assert neox_args.rotary_pct < 1 + self.rotary_ndims = int( + self.hidden_size_per_attention_head * neox_args.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else self.hidden_size_per_attention_head + ) + self.rotary_emb = RotaryEmbedding( + dim, + base=neox_args.rotary_emb_base, + max_seq_len=neox_args.seq_length, + precision=neox_args.params_dtype, + save_inv_freqs=neox_args.rotary_save_freqs_buffer, + ) + else: + self.rotary_emb = None + + self.rope_fusion = neox_args.rope_fusion + self.attention_type = neox_args.attention_config[layer_number] + self.use_flash_attention = self.attention_type == "flash" + self.use_triton = ( + self.use_flash_attention + and self.pos_emb == "alibi" + and ( + not packaging.version.Version(version("flash-attn")) + >= packaging.version.Version("2.4.0.post1") + ) + ) + self.sparse = self.attention_type not in ("global", "flash") + + if self.gqa: + assert not self.sparse + + if self.sparse: + self.sparse_attn = configure_sparse_attention( + neox_args, + self.attention_type, + self.num_attention_heads_per_partition, + mpu=mpu, ) + else: + if self.use_flash_attention: + # we now use Flash Attention 2's provided interface. + # TODO: we no longer need to use flash_triton_fn since flash cuda supports alibi. + # consider adding OpenAI's more recent Flash-2 Triton kernel in future + # from https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py + from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_func, + ) + from flash_attn.flash_attn_triton import ( + flash_attn_func as flash_attn_unpadded_unpacked_func_triton, + ) + + self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton + self.flash_qkv_fn = flash_attn_func + self.flash_varlen_qkv_fn = flash_attn_varlen_func + else: + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.fp16, + input_in_bf16=self.bf16, + fusion_type=get_fusion_type(neox_args), + mask_func=self.attention_mask_func, + softmax_in_fp32=self.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. 
+ self.dropout_p = neox_args.attention_dropout + self.attention_dropout = nn.Dropout(self.dropout_p) + + # Output. + self.dense = RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + parallel_output=parallel_output, + bias=neox_args.use_bias_in_attn_linear, + ) + + def attention( + self, query_layer, key_layer, value_layer, layer_past, attention_mask + ): + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== - def bias(self, seq_len_q, seq_len_k, device, dtype): # [b, np, sq, sk] - # seq_len_q = x.shape[-2] - # seq_len_k = x.shape[-1] - - # Initialize the AliBi matrix to match the first provided key length; grow it exponentially - # afterwards if longer inputs are provided. This is important for inference, where we will - # encounter progressively longer samples; it should have no effect at training time. - if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: - a = self.cached_matrix - else: - target_seq_len = ( - seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 - ) - a = -torch.tril( - torch.arange(target_seq_len) - .view(target_seq_len, 1) - .repeat(1, target_seq_len) - + torch.arange(0, -target_seq_len, -1) - ) - a = a.to(device).to(dtype) - slopes = self.slopes.to(a.device).to(a.dtype) - a = a * slopes.view(self.slopes.shape[0], 1, 1) - self.cached_seq_len = target_seq_len - self.cached_matrix = a - - # If the AliBi matrix is larger than the key length, clip it. - if self.cached_seq_len > seq_len_k: - a = self.cached_matrix[:, :seq_len_k, :seq_len_k] - - if seq_len_q != seq_len_k: - # In the train case x has dimensionality [b, np, sq, sk] with sq == sk - # The number of query tokens is equal to the number of key tokens - # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) - # In this case we use the appropriate token index of the cache matrix. - # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used - assert ( - seq_len_q == 1 - ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" - a = a[:, seq_len_k - 1, :].view( - a.shape[0], 1, a.shape[2] - ) # seq_len_k - 1 points to the last token index in the current inference batch. - - return a + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # preallocating result tensor: [b * np, sq, sk] + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device(), + ) - def forward(self, x): + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if self.use_cache: + with torch.no_grad(): + attention_mask = attention_mask[ + ..., : attention_scores.size(3), : attention_scores.size(3) + ] + + # =========================== + # Attention probs and dropout + # =========================== + + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + attention_scores += rpe # [1, np, sq, sk] + + if self.pos_emb == "alibi": + attention_scores = self.alibi_embed(attention_scores) + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + return context_layer + + def flash_attention(self, query_layer, key_layer, value_layer): # [b, np, sq, sk] - seq_len_q = x.shape[-2] - seq_len_k = x.shape[-1] - - # Initialize the AliBi matrix to match the first provided key length; grow it exponentially - # afterwards if longer inputs are provided. This is important for inference, where we will - # encounter progressively longer samples; it should have no effect at training time. - if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: - a = self.cached_matrix - else: - target_seq_len = ( - seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 - ) - a = -torch.tril( - torch.arange(target_seq_len) - .view(target_seq_len, 1) - .repeat(1, target_seq_len) - + torch.arange(0, -target_seq_len, -1) - ) - a = a.to(x.device).to(x.dtype) - slopes = self.slopes.to(a.device).to(a.dtype) - a = a * slopes.view(self.slopes.shape[0], 1, 1) - self.cached_seq_len = target_seq_len - self.cached_matrix = a - - # If the AliBi matrix is larger than the key length, clip it. - if self.cached_seq_len > seq_len_k: - a = self.cached_matrix[:, :seq_len_k, :seq_len_k] - - if seq_len_q != seq_len_k: - # In the train case x has dimensionality [b, np, sq, sk] with sq == sk - # The number of query tokens is equal to the number of key tokens - # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) - # In this case we use the appropriate token index of the cache matrix. 
- # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used - assert ( - seq_len_q == 1 - ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" - a = a[:, seq_len_k - 1, :].view( - a.shape[0], 1, a.shape[2] - ) # seq_len_k - 1 points to the last token index in the current inference batch. - - return x + a + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + if self.use_flash_attention and not self.use_triton: + + # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn] + key_layer = key_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + value_layer = value_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + + # [sq, b, np, hn] -> [b, sq, np, hn] + query_layer = query_layer.transpose(0, 1).reshape( + output_size[0], output_size[2], output_size[1], -1 + ) + + # only pass in window_size or alibi_slopes kwarg + # if we use Sliding Window Attention / AliBi. + # Flash attn defaults to (-1,-1), or + # does not have this kwarg prior to v2.3.0 + extra_kwargs = ( + {"window_size": (self.sliding_window_width, -1)} + if self.sliding_window_width is not None + else {} + ) + if self.pos_emb == "alibi": + extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( + query_layer.device + ).to(torch.float32) + + if not self.training: + batch_size = output_size[0] + max_seqlen_q = output_size[2] + max_seqlen_k = output_size[3] + + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device, + ) + + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * max_seqlen_k, + step=max_seqlen_k, + dtype=torch.int32, + device=key_layer.device, + ) + + q_shape = query_layer.shape + k_shape = key_layer.shape + v_shape = value_layer.shape + is_causal = max_seqlen_q == max_seqlen_k + output = self.flash_varlen_qkv_fn( + query_layer.reshape( + (q_shape[0] * q_shape[1], q_shape[2], q_shape[3]) + ), + key_layer.reshape( + (k_shape[0] * k_shape[1], k_shape[2], k_shape[3]) + ), + value_layer.reshape( + (v_shape[0] * v_shape[1], v_shape[2], v_shape[3]) + ), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=is_causal, + **extra_kwargs, + ) + output = output.reshape(q_shape) + else: + output = self.flash_qkv_fn( + query_layer, + key_layer, + value_layer, + self.dropout_p if self.training else 0.0, + softmax_scale=None, + causal=True, + **extra_kwargs, + ) + + matmul_result = output + # [b, sq, np, hn] -> [b, np, sq, hn] + matmul_result = matmul_result.transpose(1, 2) + + else: + # we still use Triton if using AliBi with flash-attn<2.4.0.post1. + + # [sq, b, np, hn] -> [b, sq, np, hn] + sq = query_layer.size(0) + b = query_layer.size(1) + sk = key_layer.size(0) + + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) + + bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype) + bias = bias.unsqueeze(0).tile((b, 1, 1, 1)) + + matmul_result = self.flash_triton_fn( + query_layer, key_layer, value_layer, bias=bias, causal=True + ) + matmul_result = matmul_result.transpose(1, 2) + + return matmul_result + + def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): + # TODO: sparse attn dropout? 
+ # TODO: pad to block size + # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn] + query_layer, key_layer, value_layer = map( + lambda t: t.permute(1, 2, 0, 3).contiguous(), + (query_layer, key_layer, value_layer), + ) + # output shape [b, np(heads), sq, hn] + attn_mask = attention_mask.to(query_layer.dtype) * -10000 + if exists(self.rpe): + rpe = self.rpe(query_layer.size(0), key_layer.size(0)) + else: + rpe = None + return self.sparse_attn( + query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe + ) + + def gqa_project(self, hidden_states, attention_mask, layer_past=None): + # QKV projection and separation into separate Q/K/V layers for GQA, + # where KV projections may be smaller than Q projection. + # the logic for this is explained in comments of this function + # detailing the intermediate sizes of tensors at each reshape. + + # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # First: reshape so we have seqlen, batch, and num. query heads each as separate dims + # Final dim is not exactly head dim: the first (head dim) dims are query heads, + # The last (head dim * ratio of kv to q heads) each are the "k/v heads" + # (right now we treat like we have same num. heads, but smaller head dim) + + # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))] + new_qkv_shape = ( + mixed_x_layer.shape[0], + mixed_x_layer.shape[1], + self.num_attention_heads_per_partition, + int( + self.hidden_size_per_attention_head + * ( + 1 + + 2 + * ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + ) + ), + ) + mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) + + # Next: split our fake head dim. 
(last dim) so that the first (head dim) dimensions go to Q, + # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately + split_sizes = ( + self.hidden_size_per_attention_head, + int( + ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + * self.hidden_size_per_attention_head + ), + int( + ( + self.num_kv_heads_per_partition + / self.num_attention_heads_per_partition + ) + * self.hidden_size_per_attention_head + ), + ) + + # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))] + (query_layer, key_layer, value_layer) = [ + x.contiguous() + for x in torch.split( + mixed_x_layer, + split_sizes, + dim=mixed_x_layer.dim() - 1, + ) + ] + + # reshape K/V to proper output shape (last dim = correct full "real" head size again) + # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn] + new_kv_shape = ( + key_layer.size(0), + key_layer.size(1), + self.num_kv_heads_per_partition, + self.hidden_size_per_attention_head, + ) + + key_layer = key_layer.view(*new_kv_shape) + + value_layer = value_layer.view(*new_kv_shape) + + # if not using Flash attention, we repeat K/V heads to match Q head counts + if not self.use_flash_attention: + key_layer = torch.repeat_interleave( + key_layer, + repeats=int( + self.num_attention_heads_per_partition + // self.num_kv_heads_per_partition + ), + dim=2, + ) + value_layer = torch.repeat_interleave( + value_layer, + repeats=int( + self.num_attention_heads_per_partition + // self.num_kv_heads_per_partition + ), + dim=2, + ) + + return query_layer, key_layer, value_layer + + def forward(self, hidden_states, attention_mask, layer_past=None): + + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + if not self.gqa: + # QKV projection for MHA. + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( + mixed_x_layer, 3 + ) + else: + # Grouped Query Attention (GQA) - specific logic for performing QKV proj + # and separating out Q, K, and V outputs. 
+ + # output shapes: 1 x [sq, b, np, hn], 2 x [sq, b, kvp, hn] if using flash + query_layer, key_layer, value_layer = self.gqa_project( + hidden_states, attention_mask, layer_past=layer_past + ) + # QK Normalization https://arxiv.org/abs/2302.05442 + if self.use_qk_layernorm: + query_layer = self.qk_layernorm(query_layer) + key_layer = self.qk_layernorm(key_layer) + + if exists(self.rotary_emb): + if exists(self.rotary_ndims): + # partial rotary + query_rot, query_pass = ( + query_layer[..., : self.rotary_ndims], + query_layer[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_layer[..., : self.rotary_ndims], + key_layer[..., self.rotary_ndims :], + ) + else: + # full rotary + query_rot, key_rot = query_layer, key_layer + + seq_len = key_layer.shape[0] + offset = 0 + if exists(layer_past) and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + if self.rope_fusion: + query_layer, key_layer = ( + fused_apply_rotary_pos_emb_cached(rot, cos, sin) + for rot in [query_rot, key_rot] + ) + else: + if self.bf16: + apply_rotary_fn = apply_rotary_pos_emb_torch + else: + apply_rotary_fn = apply_rotary_pos_emb + query_layer, key_layer = apply_rotary_fn( + query_rot, key_rot, cos, sin, offset=offset + ) + + if exists(self.rotary_ndims): + query_layer = torch.cat((query_layer, query_pass), dim=-1) + key_layer = torch.cat((key_layer, key_pass), dim=-1) + + + # ================================== + # Cache key and value for inference + # ================================== + + if exists(layer_past) and layer_past.numel() > 0: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) + + if self.use_cache: + present = torch.stack((key_layer, value_layer)) + + if self.use_flash_attention: + context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif not self.sparse: + context_layer = self.attention( + query_layer, key_layer, value_layer, layer_past, attention_mask + ) + else: + context_layer = self.sparse_attention( + query_layer, key_layer, value_layer, attention_mask + ) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + if self.use_cache: + output = [output, present] + + return output, bias + + +class ParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + ): + + super().__init__() + self.layer_number = layer_number + self.neox_args = neox_args + + norm, eps = get_norm(neox_args) + + # Layernorm on the input data. 
+ self.input_layernorm = norm(neox_args.hidden_size, eps=eps) + self.use_cache = use_cache + + self.hidden_dropout = neox_args.hidden_dropout + self.bias_dropout_fusion = neox_args.bias_dropout_fusion + self.gpt_j_residual = neox_args.gpt_j_residual + self.gpt_j_tied = neox_args.gpt_j_tied + self.moe_type = neox_args.moe_type + self.activation = neox_args.activation + + if self.gpt_j_residual: + # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. + # the reduction we use is a simple allreduce for pure Tensor Parallel, + # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) + self.reduce = ( + mpu.mappings.reduce_from_model_parallel_region + if not neox_args.sequence_parallel + else mpu.mappings.reduce_scatter_to_sequence_parallel_region + ) + + # Self attention. + if neox_args.te_mha or neox_args.fp8_mha: + from megatron.model.transformer_engine import TEMultiheadAttention + self.attention = TEMultiheadAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) + + else: + self.attention = ParallelSelfAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) + + # Layernorm on the output of the attention layer. + # If GPT-J residuals are used, this is surpurfulous but leaving it in + # leads to cleaner code + self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) + + # MLP + def get_mlp(**kw): + return ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + + def get_te_lnmlp(**kw): + from megatron.model.transformer_engine import TELayerNormMLP + return TELayerNormMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + + self.num_experts = ( + neox_args.moe_num_experts + if layer_number % neox_args.expert_interval == 0 + else 1 + ) + args = neox_args + if self.num_experts <= 1: + if neox_args.te_layernorm_mlp: + self.mlp = get_te_lnmlp() + else: + self.mlp = get_mlp() + else: + from torch import distributed as dist + + if self.num_experts > dist.get_world_size(): + moe_mp_size = 1 + else: + moe_mp_size = dist.get_world_size() // self.num_experts + + if neox_args.moe_type == "deepspeed": + self.mlp = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + elif neox_args.moe_type == "megablocks": + + def integrate_megablocks_with_ds_expert_parallelism(): + # We make megablocks work with DS 
parallelism. + # + # We fool DS into accepting these MoE parameters as its own DS MoE params, + # which makes things work with the underlying expert parallelism, + # including TED parallelism. + # + # Effectively, we want to: + # + # - Make DS's data parallel gradient all-reduction skip these params. + # - But make these params participate in the expert parallel all-reduction! + # + # Further background: + # + # Normally, with the original megablocks demo codebase, it + # only supports 1 copy of any expert throughout + # the network, since it uses EP group = DP group. + # + # First, we trigger DS initialization of the MoE expert parallel groups and internal state. + throwaway = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + throwaway.set_deepspeed_parallelism() + + ep_group = throwaway.deepspeed_moe.ep_group + if args.moe_token_dropping: + self.mlp = MbMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + else: + self.mlp = dMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + + # Next, we trick DS into seeing these as its own MoE params. + for param in self.mlp.parameters(): + if getattr(param, "expert_model_parallel", None) is not None: + # is_moe_param looks for this attr. + param.allreduce = False + param.group_name = throwaway.expert_group_name + + integrate_megablocks_with_ds_expert_parallelism() + + else: + raise KeyError(neox_args.moe_type) + + self.layer_past = None # used to cache k/v pairs in inference + + def _get_bias_dropout(self): + if self.bias_dropout_fusion: + fn = ( + bias_dropout_add_fused_train + if self.training + else bias_dropout_add_fused_inference + ) + else: + fn = get_bias_dropout_add(self.training) + return fn + + def forward(self, x, attention_mask, layer_past=None): + layer_past = layer_past if layer_past is not None else self.layer_past + bias_dropout_fn = self._get_bias_dropout() + moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) + # x: [b, s, h] + + + #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. + if self.neox_args.fp8_mha: + from megatron.model.transformer_engine import TEDelayedScaling + + fp8_recipe = TEDelayedScaling( + neox_args=self.neox_args + ) + fp8_context = fp8_recipe.get_context() + else: + from contextlib import nullcontext + fp8_context = nullcontext() + + with fp8_context: + if self.gpt_j_residual: + # pseudocode: + # x = x + attn(ln(x)) + mlp(ln(x)) + # this means we can avoid doing the allreduce in the attn / mlp outputs + # to save communication time (we can do a single allreduce after we add mlp / attn outputs). + # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. 
This is non-desirable, but + # we preserve the functionality for backwards compatibility + + residual = x + # applies the correct normalization depending on if the norms are tied + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: + x = self.input_layernorm(x) + x1, x2 = x, x + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif self.neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x + else: + x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) + + # attention operator + attention_output, attention_bias = self.attention( + x1, attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + + if attention_bias is not None: + with torch.enable_grad() if not self.eval else nullcontext(): + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(attention_output), + residual=None, + prob=self.hidden_dropout, + ) + + # mlp operator + mlp_output, mlp_bias = self.mlp(x2) + if mlp_bias is not None: + with torch.enable_grad() if not self.eval else nullcontext(): + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(mlp_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + else: + output = mlp_output + + # output = (x + attn(ln(x)) + mlp(ln(x)) + output = residual + self.reduce(output) + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + + residual = x + + # x = x + attn(ln1(x)) + attention_output, attention_bias = self.attention( + self.input_layernorm(x), attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + with torch.enable_grad() if not self.eval else nullcontext(): + if attention_bias is not None: + # Use special bias_dropout_fn if we have a bias term from the above attention layer + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(residual), + residual=residual, + prob=self.hidden_dropout, + ) + else: + # Otherwise just apply dropout + residual + attention_output = ( + torch.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=self.training, + ) + + residual + ) + + # output = x + mlp(ln2(x)) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output + else: + layernorm_output = self.post_attention_layernorm(attention_output) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = ( + None # deepspeed.moe.layer.MoE.forward ignores the bias term + ) + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) + + with torch.enable_grad() if not self.eval else nullcontext(): + if ( + self.activation == "swiglu" + or self.num_experts > 1 + and self.moe_type == "deepspeed" + ): + # No dropout either + assert mlp_bias is None + output = mlp_output + attention_output + else: + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(attention_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + + return output, moe_loss + + +class ParallelTransformerLayerPipe(ParallelTransformerLayer): + """Extends ParallelTransformerLayer to forward 
attention_mask through the pipeline.""" + + def forward(self, args): + assert ( + len(args) == 2 + ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" + hidden_states, attention_mask = args + # we are returning just [hidden_states, mask] + output, moe_loss = super().forward(hidden_states, attention_mask) + # auxiliary output + self.last_moe_loss = moe_loss + return output, attention_mask + + +class ParallelLinearPipe(ParallelLinear): + """Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" + + def forward(self, args): + assert isinstance( + args, torch.Tensor + ), "ParallelLinearPipe expects a single argument - hidden_states" + hidden_state = args + logits, bias = super().forward(hidden_state) + return logits + + +class NormPipe(nn.Module): + """Just a helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" + + def __init__(self, norm_class, hidden_size, eps): + super().__init__() + self.norm = norm_class(hidden_size, eps=eps) + + def forward(self, args): + assert not isinstance( + args, tuple + ), "NormPipe should only receive a single tensor as input" + return self.norm(args) + + +def parallel_lm_logits( + input_, + word_embeddings_weight, + parallel_output, + seq_parallel=False, + seq_dim=1, + bias=None, +): + """LM logits using word embedding weights.""" + # Parallel logits. + if seq_parallel: + # if using Sequence Parallelism, our logits are sharded along the sequence dimension. + # gather them here. (backward pass: reduce-scatter) + input_parallel = mpu.gather_from_sequence_parallel_region( + input_, seq_dim=seq_dim + ) + else: + # Set up backprop all-reduce. + input_parallel = mpu.copy_to_model_parallel_region(input_) + + # Matrix multiply. + if bias is None: + logits_parallel = F.linear(input_parallel, word_embeddings_weight) + else: + logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + + # Gather if needed. + if parallel_output: + return logits_parallel + + return mpu.gather_from_model_parallel_region(logits_parallel) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 62f316f3e..fdf384a4f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -340,6 +340,8 @@ def __init__( ): super().__init__() + ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -393,7 +395,7 @@ def __init__( if not self.gqa: # Strided linear layer. 
- self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -403,7 +405,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -412,6 +414,7 @@ def __init__( bias=neox_args.use_bias_in_attn_linear, ) + coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: @@ -857,19 +860,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): return query_layer, key_layer, value_layer def forward(self, hidden_states, attention_mask, layer_past=None): - + # hidden_states: [sq, b, h] # ===================== # Query, Key, and Value # ===================== - if not self.gqa: # QKV projection for MHA. # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -889,7 +890,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer, key_layer, value_layer = self.gqa_project( hidden_states, attention_mask, layer_past=layer_past ) - # QK Normalization https://arxiv.org/abs/2302.05442 if self.use_qk_layernorm: query_layer = self.qk_layernorm(query_layer) @@ -934,6 +934,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer = torch.cat((query_layer, query_pass), dim=-1) key_layer = torch.cat((key_layer, key_pass), dim=-1) + # ================================== # Cache key and value for inference # ================================== @@ -1027,7 +1028,7 @@ def __init__( ) # Self attention. - if neox_args.te_mha: + if neox_args.te_mha or neox_args.fp8_mha: from megatron.model.transformer_engine import TEMultiheadAttention self.attention = TEMultiheadAttention( neox_args=neox_args, @@ -1200,134 +1201,149 @@ def forward(self, x, attention_mask, layer_past=None): bias_dropout_fn = self._get_bias_dropout() moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) # x: [b, s, h] - if self.gpt_j_residual: - # pseudocode: - # x = x + attn(ln(x)) + mlp(ln(x)) - # this means we can avoid doing the allreduce in the attn / mlp outputs - # to save communication time (we can do a single allreduce after we add mlp / attn outputs). - # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but - # we preserve the functionality for backwards compatibility - - residual = x - # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: - x = self.input_layernorm(x) - x1, x2 = x, x - elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: - x2 = x - x = self.input_layernorm(x) - x1 = x - elif self.neox_args.te_layernorm_mlp: - x1, x2 = self.input_layernorm(x), x - else: - x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - - # attention operator - attention_output, attention_bias = self.attention( - x1, attention_mask, layer_past=layer_past + + + #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. 
+ if self.neox_args.fp8_mha: + from megatron.model.transformer_engine import TEDelayedScaling + + fp8_recipe = TEDelayedScaling( + neox_args=self.neox_args ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - - if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(attention_output), - residual=None, - prob=self.hidden_dropout, - ) - - # mlp operator - mlp_output, mlp_bias = self.mlp(x2) - if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(mlp_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - else: - output = mlp_output - - # output = (x + attn(ln(x)) + mlp(ln(x)) - output = residual + self.reduce(output) + fp8_context = fp8_recipe.get_context() else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) + from contextlib import nullcontext + fp8_context = nullcontext() + + with fp8_context: + if self.gpt_j_residual: + # pseudocode: + # x = x + attn(ln(x)) + mlp(ln(x)) + # this means we can avoid doing the allreduce in the attn / mlp outputs + # to save communication time (we can do a single allreduce after we add mlp / attn outputs). + # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but + # we preserve the functionality for backwards compatibility + + residual = x + # applies the correct normalization depending on if the norms are tied + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: + x = self.input_layernorm(x) + x1, x2 = x, x + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif self.neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x + else: + x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - residual = x + # attention operator + attention_output, attention_bias = self.attention( + x1, attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents - # x = x + attn(ln1(x)) - attention_output, attention_bias = self.attention( - self.input_layernorm(x), attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): if attention_bias is not None: - # Use special bias_dropout_fn if we have a bias term from the above attention layer - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(residual), - residual=residual, - prob=self.hidden_dropout, - ) - else: - # Otherwise just apply dropout + residual - attention_output = ( - torch.nn.functional.dropout( + with torch.enable_grad() if not self.eval else nullcontext(): + attention_output = bias_dropout_fn( attention_output, - p=self.hidden_dropout, - training=self.training, + bias=attention_bias.expand_as(attention_output), + residual=None, + prob=self.hidden_dropout, ) - + residual - ) - # output = x + mlp(ln2(x)) - if self.neox_args.te_layernorm_mlp: - layernorm_output = attention_output - else: - layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) + # mlp operator + mlp_output, mlp_bias = self.mlp(x2) + if mlp_bias is not 
None: + with torch.enable_grad() if not self.eval else nullcontext(): + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(mlp_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + else: + output = mlp_output - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) + # output = (x + attn(ln(x)) + mlp(ln(x)) + output = residual + self.reduce(output) else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + + residual = x + + # x = x + attn(ln1(x)) + attention_output, attention_bias = self.attention( + self.input_layernorm(x), attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + with torch.enable_grad() if not self.eval else nullcontext(): + if attention_bias is not None: + # Use special bias_dropout_fn if we have a bias term from the above attention layer + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(residual), + residual=residual, + prob=self.hidden_dropout, + ) + else: + # Otherwise just apply dropout + residual + attention_output = ( + torch.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=self.training, + ) + + residual + ) + + # output = x + mlp(ln2(x)) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output else: - raise KeyError(self.moe_type) - - with torch.enable_grad() if not self.eval else nullcontext(): - if ( - self.activation == "swiglu" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): - # No dropout either - assert mlp_bias is None - output = mlp_output + attention_output + layernorm_output = self.post_attention_layernorm(attention_output) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) else: - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(attention_output), - residual=attention_output, - prob=self.hidden_dropout, - ) + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = ( + None # deepspeed.moe.layer.MoE.forward ignores the bias term + ) + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) + + with torch.enable_grad() if not self.eval else nullcontext(): + if ( + self.activation == "swiglu" + or self.num_experts > 1 + and self.moe_type == "deepspeed" + ): + # No dropout either + assert mlp_bias is None + output = mlp_output + attention_output + else: + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(attention_output), + residual=attention_output, + prob=self.hidden_dropout, + ) - return output, moe_loss + return output, moe_loss class ParallelTransformerLayerPipe(ParallelTransformerLayer): diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 9a8c0a506..07559bcfb 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -16,10 +16,13 @@ from megatron.mpu.mappings import scatter_to_model_parallel_region from megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region from 
megatron.mpu.mappings import gather_from_sequence_parallel_region +from megatron.mpu.layers import _initialize_affine_weight_gpu, _initialize_affine_weight_cpu from megatron.mpu.random import get_cuda_rng_tracker from megatron.mpu.utils import divide from megatron.mpu.utils import VocabUtility from functools import partial +from megatron.model.positional_embeddings import RotaryEmbedding +from megatron import mpu try: import transformer_engine as te @@ -90,7 +93,6 @@ def __init__( mup_rescale_parameters=False, seq_dim=0, ): - # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -106,9 +108,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype - # print("##########################") - # print(self.return_bias) - super(TELinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, device=torch.cuda.current_device(), return_bias=self.skip_bias_add, params_dtype=self.params_dtype) @@ -145,7 +144,6 @@ def __init__( ): self.activation_func, self.is_gated = get_activation(neox_args) self.activation_type = neox_args.activation - self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of self.bias = bias self.init_method = init_method @@ -188,16 +186,17 @@ def __init__( if neox_args.norm in ['layernorm','te_layernorm']: self.eps=1.0e-5 self.normalization = 'LayerNorm' - elif neox_args.norm == ['rmsnorm','te_rmsnorm']: + elif neox_args.norm in ['rmsnorm','te_rmsnorm']: self.eps=1.0e-8 self.normalization = 'RMSNorm' - #TODO handle case if norm is not rmsnorm or layernorm - #TODO check if activation in list ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, - #‘swiglu’, ‘qgelu’, ‘srelu’ - #TODO handle MOE and mup + else: + raise ValueError("Only LayerNorm and RMSNorm are supported with TransformerEngine") + + if self.activation_type not in ["gelu", "geglu", "relu", "reglu", "squared_relu","swiglu", "qgelu", "srelu"]: + raise ValueError("Only gelu, geglu, relu, reglu, squared_relu, swiglu, qgelu, and srelu are supported with TransformerEngine") super(TELayerNormMLP, self).__init__(hidden_size=neox_args.hidden_size, ffn_hidden_size=ffn_dim, - eps=self.eps, bias=self.bias, normalization=self.normalization, activation=neox_args.activation, + eps=self.eps, bias=self.bias, normalization=self.normalization, activation=self.activation_type, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, @@ -265,8 +264,6 @@ def __init__( self.use_mup = neox_args.use_mup self.params_dtype=neox_args.params_dtype self.parallel_mode="column" - # print("##########################") - # print(self.return_bias) super(TEColumnParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, @@ -283,12 +280,6 @@ def width_mult(self): ) return self.weight.infshape.width_mult() - def set_parallel_output(self, value: bool): - assert isinstance(value, bool) - self.gather_output = ( - not value - ) # if gather_output is True, parallel output is False, so we set the opposite - # Copied from Mup def _rescale_parameters(self): """Rescale parameters to convert SP initialization to μP initialization. 
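Editor's note: the normalization and activation checks added above restrict the NeoX config to what Transformer-Engine's fused LayerNormMLP accepts. For reference, a standalone sketch of constructing that module with the same kind of options; the sizes, dtype, and device below are placeholders, not values from this patch.

import torch
import transformer_engine.pytorch as te_pytorch

hidden_size, ffn_dim = 1024, 4096  # placeholder sizes
mlp = te_pytorch.LayerNormMLP(
    hidden_size=hidden_size,
    ffn_hidden_size=ffn_dim,
    eps=1.0e-5,                  # the 'layernorm' / 'te_layernorm' branch above
    normalization="LayerNorm",   # or "RMSNorm"; anything else is rejected above
    activation="gelu",           # must be one of the activations allowed above
    bias=True,
    return_bias=True,            # gives the NeoX-style (output, bias) return
    params_dtype=torch.bfloat16,
    device="cuda",
)

x = torch.randn(128, 2, hidden_size, dtype=torch.bfloat16, device="cuda")  # [s, b, h]
out, bias = mlp(x)
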
@@ -308,7 +299,7 @@ def _rescale_parameters(self): self.bias.data *= self.width_mult() ** 0.5 self.weight.data *= self.width_mult() ** 0.5 self._has_rescaled_params = True - + def mup_reinitialize_weights(self, neox_args): if neox_args.use_cpu_initialization: self.master_weight = _initialize_affine_weight_cpu( @@ -316,26 +307,25 @@ def mup_reinitialize_weights(self, neox_args): self.weight, self.output_size, self.input_size, - self.output_size_per_partition, - 0, + self.input_size_per_partition, + 1, partial(self.init_method, use_mup=True), stride=self.stride, - return_master_weight=keep_master_weight_for_test, + return_master_weight=self.keep_master_weight_for_test, ) else: _initialize_affine_weight_gpu( self.weight, partial(self.init_method, use_mup=True), - partition_dim=0, + partition_dim=1, stride=self.stride, ) - + def forward(self, inp, **kwargs): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) - if self.skip_bias_add: return output else: @@ -455,21 +445,17 @@ def mup_reinitialize_weights(self, neox_args): stride=self.stride, ) - def set_parallel_output(self, parallel_output: bool): - assert isinstance(parallel_output, bool) - self.parallel_output = parallel_output - def forward(self, inp, **kwargs): - # if not self.input_is_parallel: - # inp = scatter_to_model_parallel_region(inp) + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() output = super(TERowParallelLinear, self).forward(inp, **kwargs) + if self.skip_bias_add: return output else: return output, None - class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also @@ -487,6 +473,7 @@ def __init__(self, use_cache=False, parallel_output=False): + self.neox_args = neox_args self.attention_mask_func = attention_mask_func self.init_method = init_method self.output_layer_init_method = output_layer_init_method @@ -524,12 +511,43 @@ def __init__(self, attention_dropout=neox_args.attention_dropout, layernorm_epsilon=self.eps, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, layer_number=self.layer_number, window_size=neox_args.sliding_window_width, num_gqa_groups=self.num_kv_heads, input_layernorm=False, - normalization=self.normalization, bias=True, device=torch.cuda.current_device(), + normalization=self.normalization, bias=True, device=torch.cuda.current_device(),get_rng_state_tracker=get_cuda_rng_tracker, set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, - tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True) + tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True, qkv_format="sbhd", fuse_qkv_params=True) + + + + if neox_args.pos_emb == "rotary": + self.hidden_size_per_attention_head = mpu.divide( + neox_args.hidden_size, neox_args.num_attention_heads) + + if neox_args.rotary_pct == 1: + self.rotary_ndims = None + else: + assert neox_args.rotary_pct < 1 + self.rotary_ndims = int( + self.hidden_size_per_attention_head * neox_args.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else self.hidden_size_per_attention_head + ) + self.rotary_embeddings = RotaryEmbedding( + dim, + base=neox_args.rotary_emb_base, + max_seq_len=neox_args.seq_length, + precision=neox_args.params_dtype, + save_inv_freqs=neox_args.rotary_save_freqs_buffer, + return_embeddings=True + ) + + def 
forward(self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs): + if self.neox_args.pos_emb == "rotary": + rope_emb=self.rotary_embeddings(hidden_states) + + output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=rope_emb, **kwargs) - def forward(self, hidden_states, attention_mask, layer_past=None, **kwargs): - output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, **kwargs) return output @@ -537,7 +555,36 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. """ + ##TODO Test with H100 + def __init__( + self, + neox_args): + + self.neox_args = neox_args + self.tp_group = get_tensor_model_parallel_group() + + if neox_args.fp8_format == "e4m3": + fp8_format = te.common.recipe.Format.E4M3 + elif neox_args.fp8_format == "hybrid": + fp8_format = te.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - def __init__(self): - # TODO - return + override_linear_precision = (False, False, not neox_args.fp8_wgrad) + + super().__init__( + margin=neox_args.fp8_margin, + fp8_format=fp8_format, + amax_compute_algo=neox_args.fp8_amax_compute_algo, + amax_history_len=neox_args.fp8_amax_history_len, + override_linear_precision=override_linear_precision, + fp8_mha=neox_args.fp8_mha, + ) + + def fp8_context(self): + fp8_group = None + if self.tp_group: + fp8_group = self.tp_group + fp8_context = te.pytorch.fp8_autocast(enabled=True, fp8_recipe=self, fp8_group=fp8_group) + + return get_context \ No newline at end of file diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index b74556bad..8c91d5168 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -522,6 +522,46 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for MultiheadAttention layer. """ + fp8_format: Literal["e4m3", "hybrid"] = "hybrid" + """ + Controls the FP8 data format used during forward and backward pass by TransformerEngine. + Hybrid uses E4M3 during forward pass, E5M2 during backward pass. + """ + + fp8_wgrad: bool = True + """ + When set to False, override FP8 config options and do the wgrad computation + in higher precision. + """ + + fp8_amax_history_len: int = 1 + """ + The length of the amax history window used for scaling factor computation. + """ + + fp8_amax_compute_algo: str = "most_recent" + """ + Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2 + predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent` + always chooses the most recently seen value. + """ + + fp8_wgrad: bool = True + """ + When set to False, override FP8 config options and do the wgrad computation + in higher precision. + """ + + fp8_margin: int = 0 + """ + Margin for the scaling factor computation. + """ + + fp8_mha: bool = False + """ + When set to True, use the FP8 implementation of Multi Head Attention. + """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): """ From 0b4bdc576ec2e9f495d712f239153465247f749d Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Thu, 26 Sep 2024 08:25:49 +0000 Subject: [PATCH 18/27] Fixed mixed files. 
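Editor's note on the FP8 recipe code above: the new fp8_* arguments map directly onto Transformer-Engine's DelayedScaling recipe, which is then applied through an fp8_autocast context. A minimal standalone sketch under assumed placeholder sizes; FP8 execution additionally requires supporting hardware such as H100.

import torch
import transformer_engine.pytorch as te_pytorch
from transformer_engine.common.recipe import DelayedScaling, Format

# "hybrid" means E4M3 in the forward pass and E5M2 in the backward pass.
recipe = DelayedScaling(
    margin=0,                         # fp8_margin
    fp8_format=Format.HYBRID,         # fp8_format
    amax_history_len=1,               # fp8_amax_history_len
    amax_compute_algo="most_recent",  # fp8_amax_compute_algo
)

layer = te_pytorch.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")

# All TE modules invoked inside this context run their GEMMs in FP8 using the recipe above.
with te_pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
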
--- megatron/model/positional_embeddings.py | 1577 +++-------------------- 1 file changed, 207 insertions(+), 1370 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fdf384a4f..38d7b12c6 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -1,7 +1,4 @@ -# Copyright (c) 2024 EleutherAI -# This file is based on code by the authors denoted below and has been modified from its original version. -# -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, EleutherAI # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,1406 +12,246 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Transformer.""" - -import math -from contextlib import nullcontext - import torch -import torch.nn.functional as F -import torch.nn as nn -from pkg_resources import packaging -from importlib.metadata import version - -from .norms import get_norm -from megatron import mpu -from megatron.model import megablocks_utils -from megatron.model.fused_softmax import FusedScaleMaskSoftmax -from megatron.model.activations import get_activation -from megatron.model.utils import exists, get_fusion_type -from megatron.model.positional_embeddings import ( - RotaryEmbedding, - apply_rotary_pos_emb_torch, - apply_rotary_pos_emb, - AliBi, -) -from megatron.model.fused_rope import ( - FusedRoPEFunc, - fused_apply_rotary_pos_emb_cached, -) -from megatron.model.fused_bias_dropout import ( - get_bias_dropout_add, - bias_dropout_add_fused_train, - bias_dropout_add_fused_inference, -) -from megatron.model.utils import configure_sparse_attention -from deepspeed.moe.layer import MoE - -try: - from flash_attn.ops.activations import swiglu -except ImportError: - swiglu = None - -from .utils import linear_implementation_router - -# flags required to enable jit fusion kernels -torch._C._jit_set_profiling_mode(False) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_override_can_fuse_on_cpu(True) -torch._C._jit_override_can_fuse_on_gpu(True) - -""" We use the following notation throughout this file: - h: hidden size - n: number of attention heads - kv: number of key or value heads - p: number of model parallel partitions - np: n/p - kvp: kv/p - hp: h/p - hn: h/n - b: batch size - s: sequence length - l: number of layers - Transformer takes input of size [s, b, h] and returns a - tensor of the same size. We use the following arguments: - hyperparameters: transformer hyperparameters - attention_mask_func: a function that takes `unmasked-attention-scores` - with size [b, np, s, s] and an `attention-mask` and will apply - the masking. The function should return a masked score of the - same size [b, np, s, s]. - masked-attention-scores = attention_mask_func( - unmasked-attention-scores, attention-mask) -""" - - -class ParallelMLP(nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. At the end, dropout is also - applied. 
- """ - - def __init__( - self, - neox_args, - init_method, - output_layer_init_method, - parallel_output=False, - multiple_of=256, - MOE=False, - MoE_mp_size=1, - ): - super().__init__() - assert ( - neox_args.intermediate_size == None or neox_args.expansion_factor == None - ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - - self.activation_func, self.is_gated = get_activation(neox_args) - self.activation_type = neox_args.activation - self.bias_gelu_fusion = neox_args.bias_gelu_fusion - self.multiple_of = multiple_of - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - if neox_args.intermediate_size: - ffn_dim = neox_args.intermediate_size - elif neox_args.expansion_factor: - ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) - else: - # 4h is default for ffn_dim - ffn_dim = 4 * neox_args.hidden_size - ffn_dim_in = ffn_dim - if self.is_gated: - # set activation function to be gated implementation - self.activation_func = Gated_Activation( - self.activation_func, - (swiglu is not None) - and (neox_args.activation == "swiglu") - and neox_args.use_flashattn_swiglu, - ) - # auto scale so gated activations has equal parameters - ffn_dim = int(ffn_dim * 2 / 3) - ffn_dim_in = ffn_dim // 2 - # set multiple - ffn_dim = int( - (2 * self.multiple_of) - * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) - ) - ffn_dim_in = int( - self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) - ) - self.linear1 = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ffn_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - bias=neox_args.use_bias_in_mlp, - ) - # Project back to h. 
- self.linear2 = RowParallelLinear( - neox_args=neox_args, - input_size=ffn_dim_in, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - parallel_output=parallel_output, - skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - bias=neox_args.use_bias_in_mlp, - ) - - def forward(self, hidden_states): - # [s, b, intermediate_size] - intermediate_parallel, bias_parallel = self.linear1(hidden_states) - if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): - intermediate_parallel = self.activation_func( - intermediate_parallel, bias_parallel - ) - else: - intermediate_parallel = self.activation_func( - intermediate_parallel + bias_parallel - ) - - # [s, b, h] - output, output_bias = self.linear2(intermediate_parallel) - return output, output_bias - - -class Gated_Activation(torch.nn.Module): - def __init__(self, activation_func, use_swiglu=False): - super().__init__() - self.activation_func = activation_func - self.use_swiglu = use_swiglu - - def forward(self, x, bias=None): - x, gate = x.chunk(2, dim=-1) - if bias is not None: - bias_1, bias_2 = bias.chunk(2, dim=-1) - x = x + bias_1 - gate = gate + bias_2 - if not self.use_swiglu: - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x - else: - return swiglu(gate, x) - +import math -class ParallelLinear(nn.Module): - """ - A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size - """ - def __init__( - self, - neox_args, - parallel_output=True, - init_method=nn.init.xavier_normal_, - is_last_layer=False, - ): +class SinusoidalPositionalEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half): super().__init__() - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - self.is_rm = neox_args.train_impl == "rm" - parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" - if parallelism == "column": - self.final_linear = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.padded_vocab_size, - bias=False, - init_method=init_method, - gather_output=not parallel_output, - skip_bias_add=False, - mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here - seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0 - ) - else: - if not self.is_rm: - print( - 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' 
- ) - exit() - # self.final_linear = mpu.RowParallelLinear( - # neox_args=neox_args, - # input_size=neox_args.hidden_size, - # output_size=neox_args.padded_vocab_size, - # bias=False, - # input_is_parallel=False, - # init_method=init_method, - # parallel_output=parallel_output, - # skip_bias_add=False, - # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here - # ) - else: # Not using cross entropy loss for RMs - self.rm_linear = RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=1, - bias=False, - input_is_parallel=False, - init_method=init_method, - parallel_output=False, - skip_bias_add=False, - mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here - ) - - def forward(self, hidden_states): - if not self.is_rm: - return self.final_linear(hidden_states) - else: - return self.rm_linear(hidden_states) - - -class _MegablocksAdapter(nn.Module): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.precision = precision + + def forward(self, x, seq_dim=1): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq) + if self.precision == torch.bfloat16: + sinusoid_inp = sinusoid_inp.float() + sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() + if self.precision == torch.bfloat16: + sin, cos = sin.bfloat16(), cos.bfloat16() + emb = torch.cat((sin, cos), dim=-1) + return emb[None, :, :] + + +class RotaryEmbedding(torch.nn.Module): def __init__( - self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group + self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False, return_embeddings=False ): super().__init__() - megablocks_utils.assert_megablocks_is_available() - args = megablocks_utils.as_megablocks_args(neox_args) - args.device = torch.cuda.current_device() - args.init_method = init_method - args.output_layer_init_method = output_layer_init_method - - # NOTE: Shard the MoE layers over the data parallel group. Expert - # parallel sharding and data parallel sharding could be decoupled - # by extending the optimizer to handle data parallel reductions for - # MoE and non-MoE parameters separately. 
- if args.moe_expert_model_parallelism: - args.expert_parallel_group = ep_group - - self.moe = layer_cls(args) - - def forward(self, x): - return self.moe.forward(x) - - -class MbMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.moe.MoE, - init_method, - output_layer_init_method, - ep_group, + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + self.max_seq_len = max_seq_len + self.base = base + self.dim = dim + self.return_embeddings = return_embeddings + + # precompute cos_cached, sin_cached in fp32 + cos_cached, sin_cached, inv_freq = self._prepare_cache( + max_seq_len, precision, base ) + self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) + self.cos_cached = cos_cached + self.sin_cached = sin_cached -class dMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.dmoe.dMoE, - init_method, - output_layer_init_method, - ep_group, - ) + def _prepare_cache(self, seq_len, precision, base): + # precompute cos_cached, sin_cached in fp32 + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + t = torch.arange(seq_len).type_as(inv_freq) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) -class ParallelSelfAttention(nn.Module): - """Parallel self-attention layer abstract class. + self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) - Self-attention layer takes input with size [b, s, h] - and returns output of the same size. - """ + cos_cached = emb.cos()[:, None, None, :] + sin_cached = emb.sin()[:, None, None, :] - def __init__( - self, - neox_args, - attention_mask_func, - init_method, - output_layer_init_method, - layer_number, - rpe=None, - rotary=False, - use_cache=False, - parallel_output=False, - ): - super().__init__() - - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) - - self.fp16 = neox_args.precision == "fp16" - self.bf16 = neox_args.precision == "bfloat16" - self.attention_mask_func = attention_mask_func - self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling - self.use_cache = use_cache - self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = layer_number - # Per attention head and per partition values. 
- world_size = mpu.get_model_parallel_world_size() - self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) - self.hidden_size_per_attention_head = mpu.divide( - neox_args.hidden_size, neox_args.num_attention_heads + return ( + cos_cached.to(precision), + sin_cached.to(precision), + inv_freq.to(precision), ) - self.num_attention_heads_per_partition = mpu.divide( - neox_args.num_attention_heads, world_size - ) - self.pos_emb = neox_args.pos_emb - self.use_qk_layernorm = neox_args.use_qk_layernorm - if self.use_qk_layernorm: - norm, eps = get_norm(neox_args) - self.qk_layernorm = norm( - [ - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ], - eps=eps, - ) + def forward(self, x, seq_dim=0, seq_len=None): + if self.return_embeddings: + return self.emb.to(self.precision).to(x.device) + if seq_len is None: + seq_len = x.shape[seq_dim] - self.sliding_window_width = neox_args.sliding_window_width + assert seq_len <= self.max_seq_len - if ( - not neox_args.num_kv_heads - or neox_args.num_kv_heads == neox_args.num_attention_heads - ): - self.gqa = False - else: - self.gqa = True - if self.gqa: - self.num_kv_heads_per_partition = mpu.divide( - neox_args.num_kv_heads, world_size - ) # we do not yet clone KV heads in MQA across TP ranks... - self.kv_hidden_size = ( - neox_args.num_kv_heads * self.hidden_size_per_attention_head - ) # how large the total hidden dim for each of K and V is - else: - self.num_kv_heads_per_partition = self.num_attention_heads_per_partition - self.kv_hidden_size = neox_args.hidden_size - - if not self.gqa: - # Strided linear layer. - self.query_key_value = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=3 * neox_args.hidden_size, - gather_output=False, - init_method=init_method, - bias=neox_args.use_bias_in_attn_linear, + if seq_len != self.max_seq_len: + # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base) + return ( + self.cos_cached[:seq_len, ...].to(x.device), + self.sin_cached[:seq_len, ...].to(x.device), ) else: - # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, - gather_output=False, - init_method=init_method, - bias=neox_args.use_bias_in_attn_linear, - ) - - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = max(1, self.layer_number) - self.norm_factor *= coeff - - if neox_args.use_mup: - self.norm_factor = self.hidden_size_per_attention_head - - self.rpe = rpe - - if self.pos_emb == "alibi": - self.alibi_embed = AliBi( - neox_args.num_attention_heads, - neox_args.model_parallel_size, - mpu.get_model_parallel_rank(), - ) + return self.cos_cached.to(x.device), self.sin_cached.to(x.device) - # TODO: this arg shouldn't need to be passed in - get from neox_args - if rotary: - if neox_args.rotary_pct == 1: - self.rotary_ndims = None - else: - assert neox_args.rotary_pct < 1 - self.rotary_ndims = int( - self.hidden_size_per_attention_head * neox_args.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else self.hidden_size_per_attention_head - ) - self.rotary_emb = RotaryEmbedding( - dim, - base=neox_args.rotary_emb_base, - max_seq_len=neox_args.seq_length, - precision=neox_args.params_dtype, - save_inv_freqs=neox_args.rotary_save_freqs_buffer, - ) - else: - self.rotary_emb = None 
- self.rope_fusion = neox_args.rope_fusion - self.attention_type = neox_args.attention_config[layer_number] - self.use_flash_attention = self.attention_type == "flash" - self.use_triton = ( - self.use_flash_attention - and self.pos_emb == "alibi" - and ( - not packaging.version.Version(version("flash-attn")) - >= packaging.version.Version("2.4.0.post1") - ) - ) - self.sparse = self.attention_type not in ("global", "flash") +# rotary pos emb helpers: - if self.gqa: - assert not self.sparse - if self.sparse: - self.sparse_attn = configure_sparse_attention( - neox_args, - self.attention_type, - self.num_attention_heads_per_partition, - mpu=mpu, - ) - else: - if self.use_flash_attention: - # we now use Flash Attention 2's provided interface. - # TODO: we no longer need to use flash_triton_fn since flash cuda supports alibi. - # consider adding OpenAI's more recent Flash-2 Triton kernel in future - # from https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py - from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_func, - ) - from flash_attn.flash_attn_triton import ( - flash_attn_func as flash_attn_unpadded_unpacked_func_triton, - ) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions - self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton - self.flash_qkv_fn = flash_attn_func - self.flash_varlen_qkv_fn = flash_attn_varlen_func - else: - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.fp16, - input_in_bf16=self.bf16, - fusion_type=get_fusion_type(neox_args), - mask_func=self.attention_mask_func, - softmax_in_fp32=self.attention_softmax_in_fp32, - scale=coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.dropout_p = neox_args.attention_dropout - self.attention_dropout = nn.Dropout(self.dropout_p) - - # Output. - self.dense = RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - bias=neox_args.use_bias_in_attn_linear, - ) - - def attention( - self, query_layer, key_layer, value_layer, layer_past, attention_mask - ): - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - output_size[2], output_size[0] * output_size[1], -1 - ) - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocating result tensor: [b * np, sq, sk] - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # ================================================== - # Update attention mask for inference. [b, np, sq, sk] - # ================================================== - - if self.use_cache: - with torch.no_grad(): - attention_mask = attention_mask[ - ..., : attention_scores.size(3), : attention_scores.size(3) - ] - # =========================== - # Attention probs and dropout - # =========================== +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - attention_scores += rpe # [1, np, sq, sk] - if self.pos_emb == "alibi": - attention_scores = self.alibi_embed(attention_scores) - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - with mpu.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - return context_layer - - def flash_attention(self, query_layer, key_layer, value_layer): - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - - if self.use_flash_attention and not self.use_triton: - - # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn] - key_layer = key_layer.transpose(0, 1).reshape( - output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 - ) - value_layer = value_layer.transpose(0, 1).reshape( - output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 - ) - - # [sq, b, np, hn] -> [b, sq, np, hn] - query_layer = query_layer.transpose(0, 1).reshape( - output_size[0], output_size[2], output_size[1], -1 - ) - - # only pass in window_size or alibi_slopes kwarg - # if we use Sliding Window Attention / AliBi. 
- # Flash attn defaults to (-1,-1), or - # does not have this kwarg prior to v2.3.0 - extra_kwargs = ( - {"window_size": (self.sliding_window_width, -1)} - if self.sliding_window_width is not None - else {} - ) - if self.pos_emb == "alibi": - extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( - query_layer.device - ).to(torch.float32) - - if not self.training: - batch_size = output_size[0] - max_seqlen_q = output_size[2] - max_seqlen_k = output_size[3] - - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * max_seqlen_q, - step=max_seqlen_q, - dtype=torch.int32, - device=query_layer.device, - ) - - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * max_seqlen_k, - step=max_seqlen_k, - dtype=torch.int32, - device=key_layer.device, - ) - - q_shape = query_layer.shape - k_shape = key_layer.shape - v_shape = value_layer.shape - is_causal = max_seqlen_q == max_seqlen_k - output = self.flash_varlen_qkv_fn( - query_layer.reshape( - (q_shape[0] * q_shape[1], q_shape[2], q_shape[3]) - ), - key_layer.reshape( - (k_shape[0] * k_shape[1], k_shape[2], k_shape[3]) - ), - value_layer.reshape( - (v_shape[0] * v_shape[1], v_shape[2], v_shape[3]) - ), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale=None, - causal=is_causal, - **extra_kwargs, - ) - output = output.reshape(q_shape) - else: - output = self.flash_qkv_fn( - query_layer, - key_layer, - value_layer, - self.dropout_p if self.training else 0.0, - softmax_scale=None, - causal=True, - **extra_kwargs, - ) - - matmul_result = output - # [b, sq, np, hn] -> [b, np, sq, hn] - matmul_result = matmul_result.transpose(1, 2) - - else: - # we still use Triton if using AliBi with flash-attn<2.4.0.post1. +def apply_rotary_pos_emb_torch( + q, k, cos, sin, offset: int = 0 +): # jitting fails with bf16 + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - # [sq, b, np, hn] -> [b, sq, np, hn] - sq = query_layer.size(0) - b = query_layer.size(1) - sk = key_layer.size(0) - - query_layer = query_layer.transpose(0, 1) - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype) - bias = bias.unsqueeze(0).tile((b, 1, 1, 1)) - - matmul_result = self.flash_triton_fn( - query_layer, key_layer, value_layer, bias=bias, causal=True - ) - matmul_result = matmul_result.transpose(1, 2) - - return matmul_result - - def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): - # TODO: sparse attn dropout? - # TODO: pad to block size - # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn] - query_layer, key_layer, value_layer = map( - lambda t: t.permute(1, 2, 0, 3).contiguous(), - (query_layer, key_layer, value_layer), - ) - # output shape [b, np(heads), sq, hn] - attn_mask = attention_mask.to(query_layer.dtype) * -10000 - if exists(self.rpe): - rpe = self.rpe(query_layer.size(0), key_layer.size(0)) - else: - rpe = None - return self.sparse_attn( - query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe - ) - - def gqa_project(self, hidden_states, attention_mask, layer_past=None): - # QKV projection and separation into separate Q/K/V layers for GQA, - # where KV projections may be smaller than Q projection. - # the logic for this is explained in comments of this function - # detailing the intermediate sizes of tensors at each reshape. 
- - # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # First: reshape so we have seqlen, batch, and num. query heads each as separate dims - # Final dim is not exactly head dim: the first (head dim) dims are query heads, - # The last (head dim * ratio of kv to q heads) each are the "k/v heads" - # (right now we treat like we have same num. heads, but smaller head dim) - - # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))] - new_qkv_shape = ( - mixed_x_layer.shape[0], - mixed_x_layer.shape[1], - self.num_attention_heads_per_partition, - int( - self.hidden_size_per_attention_head - * ( - 1 - + 2 - * ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - ) - ), - ) - mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) - - # Next: split our fake head dim. (last dim) so that the first (head dim) dimensions go to Q, - # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately - split_sizes = ( - self.hidden_size_per_attention_head, - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), - ) - - # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))] - (query_layer, key_layer, value_layer) = [ - x.contiguous() - for x in torch.split( - mixed_x_layer, - split_sizes, - dim=mixed_x_layer.dim() - 1, - ) - ] - - # reshape K/V to proper output shape (last dim = correct full "real" head size again) - # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn] - new_kv_shape = ( - key_layer.size(0), - key_layer.size(1), - self.num_kv_heads_per_partition, - self.hidden_size_per_attention_head, - ) - - key_layer = key_layer.view(*new_kv_shape) - - value_layer = value_layer.view(*new_kv_shape) - - # if not using Flash attention, we repeat K/V heads to match Q head counts - if not self.use_flash_attention: - key_layer = torch.repeat_interleave( - key_layer, - repeats=int( - self.num_attention_heads_per_partition - // self.num_kv_heads_per_partition - ), - dim=2, - ) - value_layer = torch.repeat_interleave( - value_layer, - repeats=int( - self.num_attention_heads_per_partition - // self.num_kv_heads_per_partition - ), - dim=2, - ) - - return query_layer, key_layer, value_layer - - def forward(self, hidden_states, attention_mask, layer_past=None): - - # hidden_states: [sq, b, h] - - # ===================== - # Query, Key, and Value - # ===================== - if not self.gqa: - # QKV projection for MHA. - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( - mixed_x_layer, 3 - ) - else: - # Grouped Query Attention (GQA) - specific logic for performing QKV proj - # and separating out Q, K, and V outputs. 
- - # output shapes: 1 x [sq, b, np, hn], 2 x [sq, b, kvp, hn] if using flash - query_layer, key_layer, value_layer = self.gqa_project( - hidden_states, attention_mask, layer_past=layer_past - ) - # QK Normalization https://arxiv.org/abs/2302.05442 - if self.use_qk_layernorm: - query_layer = self.qk_layernorm(query_layer) - key_layer = self.qk_layernorm(key_layer) - - if exists(self.rotary_emb): - if exists(self.rotary_ndims): - # partial rotary - query_rot, query_pass = ( - query_layer[..., : self.rotary_ndims], - query_layer[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_layer[..., : self.rotary_ndims], - key_layer[..., self.rotary_ndims :], - ) - else: - # full rotary - query_rot, key_rot = query_layer, key_layer - - seq_len = key_layer.shape[0] - offset = 0 - if exists(layer_past) and layer_past.numel() > 0: - offset = layer_past[0].shape[0] - seq_len += offset - cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) - if self.rope_fusion: - query_layer, key_layer = ( - fused_apply_rotary_pos_emb_cached(rot, cos, sin) - for rot in [query_rot, key_rot] - ) - else: - if self.bf16: - apply_rotary_fn = apply_rotary_pos_emb_torch - else: - apply_rotary_fn = apply_rotary_pos_emb - query_layer, key_layer = apply_rotary_fn( - query_rot, key_rot, cos, sin, offset=offset - ) - - if exists(self.rotary_ndims): - query_layer = torch.cat((query_layer, query_pass), dim=-1) - key_layer = torch.cat((key_layer, key_pass), dim=-1) - - - # ================================== - # Cache key and value for inference - # ================================== - - if exists(layer_past) and layer_past.numel() > 0: - past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) - value_layer = torch.cat( - (past_value.type_as(value_layer), value_layer), dim=0 - ) - - if self.use_cache: - present = torch.stack((key_layer, value_layer)) - - if self.use_flash_attention: - context_layer = self.flash_attention(query_layer, key_layer, value_layer) - elif not self.sparse: - context_layer = self.attention( - query_layer, key_layer, value_layer, layer_past, attention_mask - ) - else: - context_layer = self.sparse_attention( - query_layer, key_layer, value_layer, attention_mask - ) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + ( - self.hidden_size_per_partition, - ) - context_layer = context_layer.view(*new_context_layer_shape) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.dense(context_layer) - - if self.use_cache: - output = [output, present] - - return output, bias - - -class ParallelTransformerLayer(nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [b, s, h] and returns an - output of the same size. - """ - - def __init__( - self, - neox_args, - attention_mask_func, - init_method, - output_layer_init_method, - layer_number, - rpe=None, - rotary=False, - use_cache=False, - ): +class AliBi(torch.nn.Module): + def __init__(self, num_heads, mp_size=1, mp_rank=1): super().__init__() - self.layer_number = layer_number - self.neox_args = neox_args - - norm, eps = get_norm(neox_args) - - # Layernorm on the input data. 
- self.input_layernorm = norm(neox_args.hidden_size, eps=eps) - self.use_cache = use_cache - - self.hidden_dropout = neox_args.hidden_dropout - self.bias_dropout_fusion = neox_args.bias_dropout_fusion - self.gpt_j_residual = neox_args.gpt_j_residual - self.gpt_j_tied = neox_args.gpt_j_tied - self.moe_type = neox_args.moe_type - self.activation = neox_args.activation - - if self.gpt_j_residual: - # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. - # the reduction we use is a simple allreduce for pure Tensor Parallel, - # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) - self.reduce = ( - mpu.mappings.reduce_from_model_parallel_region - if not neox_args.sequence_parallel - else mpu.mappings.reduce_scatter_to_sequence_parallel_region - ) - - # Self attention. - if neox_args.te_mha or neox_args.fp8_mha: - from megatron.model.transformer_engine import TEMultiheadAttention - self.attention = TEMultiheadAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, - ) - + # megatron splits across heads, so we need to make sure each + # head receives the correct matrix + assert mp_size <= num_heads and mp_rank <= mp_size + self.mp_size = mp_size + self.mp_rank = mp_rank + self.num_heads = num_heads + self.slice_size = num_heads // mp_size + self.cached_matrix = None + self.cached_seq_len = None + slopes = torch.Tensor(self._get_slopes(num_heads))[ + mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size + ] + self.register_buffer("slopes", slopes) + + def _get_slopes(self, n): + """ + Get slopes for Alibi positional embedding + n : int = number of heads. + For best performance, restrict n to a power of 2. + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) else: - self.attention = ParallelSelfAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, - ) - - # Layernorm on the output of the attention layer. 
- # If GPT-J residuals are used, this is surpurfulous but leaving it in - # leads to cleaner code - self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) - - # MLP - def get_mlp(**kw): - return ParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - multiple_of=neox_args.mlp_multiple_of, - **kw, - ) - - def get_te_lnmlp(**kw): - from megatron.model.transformer_engine import TELayerNormMLP - return TELayerNormMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - multiple_of=neox_args.mlp_multiple_of, - **kw, + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] ) - self.num_experts = ( - neox_args.moe_num_experts - if layer_number % neox_args.expert_interval == 0 - else 1 - ) - args = neox_args - if self.num_experts <= 1: - if neox_args.te_layernorm_mlp: - self.mlp = get_te_lnmlp() - else: - self.mlp = get_mlp() - else: - from torch import distributed as dist - - if self.num_experts > dist.get_world_size(): - moe_mp_size = 1 - else: - moe_mp_size = dist.get_world_size() // self.num_experts - - if neox_args.moe_type == "deepspeed": - self.mlp = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - elif neox_args.moe_type == "megablocks": - - def integrate_megablocks_with_ds_expert_parallelism(): - # We make megablocks work with DS parallelism. - # - # We fool DS into accepting these MoE parameters as its own DS MoE params, - # which makes things work with the underlying expert parallelism, - # including TED parallelism. - # - # Effectively, we want to: - # - # - Make DS's data parallel gradient all-reduction skip these params. - # - But make these params participate in the expert parallel all-reduction! - # - # Further background: - # - # Normally, with the original megablocks demo codebase, it - # only supports 1 copy of any expert throughout - # the network, since it uses EP group = DP group. - # - # First, we trigger DS initialization of the MoE expert parallel groups and internal state. 
- throwaway = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - throwaway.set_deepspeed_parallelism() - - ep_group = throwaway.deepspeed_moe.ep_group - if args.moe_token_dropping: - self.mlp = MbMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - else: - self.mlp = dMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - - # Next, we trick DS into seeing these as its own MoE params. - for param in self.mlp.parameters(): - if getattr(param, "expert_model_parallel", None) is not None: - # is_moe_param looks for this attr. - param.allreduce = False - param.group_name = throwaway.expert_group_name - - integrate_megablocks_with_ds_expert_parallelism() - - else: - raise KeyError(neox_args.moe_type) - - self.layer_past = None # used to cache k/v pairs in inference - - def _get_bias_dropout(self): - if self.bias_dropout_fusion: - fn = ( - bias_dropout_add_fused_train - if self.training - else bias_dropout_add_fused_inference - ) + def bias(self, seq_len_q, seq_len_k, device, dtype): + # [b, np, sq, sk] + # seq_len_q = x.shape[-2] + # seq_len_k = x.shape[-1] + + # Initialize the AliBi matrix to match the first provided key length; grow it exponentially + # afterwards if longer inputs are provided. This is important for inference, where we will + # encounter progressively longer samples; it should have no effect at training time. + if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: + a = self.cached_matrix else: - fn = get_bias_dropout_add(self.training) - return fn + target_seq_len = ( + seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 + ) + a = -torch.tril( + torch.arange(target_seq_len) + .view(target_seq_len, 1) + .repeat(1, target_seq_len) + + torch.arange(0, -target_seq_len, -1) + ) + a = a.to(device).to(dtype) + slopes = self.slopes.to(a.device).to(a.dtype) + a = a * slopes.view(self.slopes.shape[0], 1, 1) + self.cached_seq_len = target_seq_len + self.cached_matrix = a + + # If the AliBi matrix is larger than the key length, clip it. + if self.cached_seq_len > seq_len_k: + a = self.cached_matrix[:, :seq_len_k, :seq_len_k] + + if seq_len_q != seq_len_k: + # In the train case x has dimensionality [b, np, sq, sk] with sq == sk + # The number of query tokens is equal to the number of key tokens + # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) + # In this case we use the appropriate token index of the cache matrix. + # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used + assert ( + seq_len_q == 1 + ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" + a = a[:, seq_len_k - 1, :].view( + a.shape[0], 1, a.shape[2] + ) # seq_len_k - 1 points to the last token index in the current inference batch. 
+ + return a - def forward(self, x, attention_mask, layer_past=None): - layer_past = layer_past if layer_past is not None else self.layer_past - bias_dropout_fn = self._get_bias_dropout() - moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) - # x: [b, s, h] - - - #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. - if self.neox_args.fp8_mha: - from megatron.model.transformer_engine import TEDelayedScaling - - fp8_recipe = TEDelayedScaling( - neox_args=self.neox_args - ) - fp8_context = fp8_recipe.get_context() + def forward(self, x): + # [b, np, sq, sk] + seq_len_q = x.shape[-2] + seq_len_k = x.shape[-1] + + # Initialize the AliBi matrix to match the first provided key length; grow it exponentially + # afterwards if longer inputs are provided. This is important for inference, where we will + # encounter progressively longer samples; it should have no effect at training time. + if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k: + a = self.cached_matrix else: - from contextlib import nullcontext - fp8_context = nullcontext() - - with fp8_context: - if self.gpt_j_residual: - # pseudocode: - # x = x + attn(ln(x)) + mlp(ln(x)) - # this means we can avoid doing the allreduce in the attn / mlp outputs - # to save communication time (we can do a single allreduce after we add mlp / attn outputs). - # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but - # we preserve the functionality for backwards compatibility - - residual = x - # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: - x = self.input_layernorm(x) - x1, x2 = x, x - elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: - x2 = x - x = self.input_layernorm(x) - x1 = x - elif self.neox_args.te_layernorm_mlp: - x1, x2 = self.input_layernorm(x), x - else: - x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - - # attention operator - attention_output, attention_bias = self.attention( - x1, attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - - if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(attention_output), - residual=None, - prob=self.hidden_dropout, - ) - - # mlp operator - mlp_output, mlp_bias = self.mlp(x2) - if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(mlp_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - else: - output = mlp_output - - # output = (x + attn(ln(x)) + mlp(ln(x)) - output = residual + self.reduce(output) - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - - residual = x - - # x = x + attn(ln1(x)) - attention_output, attention_bias = self.attention( - self.input_layernorm(x), attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): - if attention_bias is not None: - # Use special bias_dropout_fn if we have a bias term from the above attention layer - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(residual), - residual=residual, - prob=self.hidden_dropout, - ) - else: - # 
Otherwise just apply dropout + residual - attention_output = ( - torch.nn.functional.dropout( - attention_output, - p=self.hidden_dropout, - training=self.training, - ) - + residual - ) - - # output = x + mlp(ln2(x)) - if self.neox_args.te_layernorm_mlp: - layernorm_output = attention_output - else: - layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) - - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - raise KeyError(self.moe_type) - - with torch.enable_grad() if not self.eval else nullcontext(): - if ( - self.activation == "swiglu" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): - # No dropout either - assert mlp_bias is None - output = mlp_output + attention_output - else: - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(attention_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - - return output, moe_loss - - -class ParallelTransformerLayerPipe(ParallelTransformerLayer): - """Extends ParallelTransformerLayer to forward attention_mask through the pipeline.""" - - def forward(self, args): - assert ( - len(args) == 2 - ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" - hidden_states, attention_mask = args - # we are returning just [hidden_states, mask] - output, moe_loss = super().forward(hidden_states, attention_mask) - # auxiliary output - self.last_moe_loss = moe_loss - return output, attention_mask - - -class ParallelLinearPipe(ParallelLinear): - """Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" - - def forward(self, args): - assert isinstance( - args, torch.Tensor - ), "ParallelLinearPipe expects a single argument - hidden_states" - hidden_state = args - logits, bias = super().forward(hidden_state) - return logits - - -class NormPipe(nn.Module): - """Just a helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" - - def __init__(self, norm_class, hidden_size, eps): - super().__init__() - self.norm = norm_class(hidden_size, eps=eps) - - def forward(self, args): - assert not isinstance( - args, tuple - ), "NormPipe should only receive a single tensor as input" - return self.norm(args) - - -def parallel_lm_logits( - input_, - word_embeddings_weight, - parallel_output, - seq_parallel=False, - seq_dim=1, - bias=None, -): - """LM logits using word embedding weights.""" - # Parallel logits. - if seq_parallel: - # if using Sequence Parallelism, our logits are sharded along the sequence dimension. - # gather them here. (backward pass: reduce-scatter) - input_parallel = mpu.gather_from_sequence_parallel_region( - input_, seq_dim=seq_dim - ) - else: - # Set up backprop all-reduce. - input_parallel = mpu.copy_to_model_parallel_region(input_) - - # Matrix multiply. - if bias is None: - logits_parallel = F.linear(input_parallel, word_embeddings_weight) - else: - logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) - - # Gather if needed. 
- if parallel_output: - return logits_parallel - - return mpu.gather_from_model_parallel_region(logits_parallel) + target_seq_len = ( + seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4 + ) + a = -torch.tril( + torch.arange(target_seq_len) + .view(target_seq_len, 1) + .repeat(1, target_seq_len) + + torch.arange(0, -target_seq_len, -1) + ) + a = a.to(x.device).to(x.dtype) + slopes = self.slopes.to(a.device).to(a.dtype) + a = a * slopes.view(self.slopes.shape[0], 1, 1) + self.cached_seq_len = target_seq_len + self.cached_matrix = a + + # If the AliBi matrix is larger than the key length, clip it. + if self.cached_seq_len > seq_len_k: + a = self.cached_matrix[:, :seq_len_k, :seq_len_k] + + if seq_len_q != seq_len_k: + # In the train case x has dimensionality [b, np, sq, sk] with sq == sk + # The number of query tokens is equal to the number of key tokens + # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) + # In this case we use the appropriate token index of the cache matrix. + # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used + assert ( + seq_len_q == 1 + ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" + a = a[:, seq_len_k - 1, :].view( + a.shape[0], 1, a.shape[2] + ) # seq_len_k - 1 points to the last token index in the current inference batch. + + return x + a \ No newline at end of file From bb7651040f4a20002bf93f9d180fdcfb6232e780 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Fri, 27 Sep 2024 20:41:31 +0000 Subject: [PATCH 19/27] Changed get_linear name --- megatron/model/transformer.py | 9 +++++---- megatron/model/utils.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index fdf384a4f..80481e334 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -55,7 +55,7 @@ except ImportError: swiglu = None -from .utils import linear_implementation_router +from .utils import get_parallel_linear # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -116,7 +116,7 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size @@ -220,7 +220,7 @@ def __init__( ): super().__init__() - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" @@ -340,7 +340,7 @@ def __init__( ): super().__init__() - ColumnParallelLinear, RowParallelLinear = linear_implementation_router(neox_args) + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" @@ -1281,6 +1281,7 @@ def forward(self, x, attention_mask, layer_past=None): attention_output, attention_bias = self.attention( self.input_layernorm(x), attention_mask, layer_past=layer_past ) + if self.use_cache: attention_output, presents = attention_output self.layer_past = presents diff --git a/megatron/model/utils.py 
b/megatron/model/utils.py index d1ec2a347..d39da6194 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -404,7 +404,7 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): param.register_hook(reduce_weight_grads_from_model_parallel_region) -def linear_implementation_router(neox_args): +def get_parallel_linear(neox_args): if neox_args.te_columnparallel: from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear else: From 43cf4ee6697894a3cca426eafed96f3ee5635c8f Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Tue, 1 Oct 2024 02:40:26 +0000 Subject: [PATCH 20/27] Added rng tracker to lnmlp and placed rope in te_mha init instead of forward --- megatron/model/positional_embeddings.py | 9 ++++----- megatron/model/transformer_engine.py | 12 ++++-------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 38d7b12c6..3ff34b189 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -37,8 +37,7 @@ def forward(self, x, seq_dim=1): class RotaryEmbedding(torch.nn.Module): def __init__( - self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False, return_embeddings=False - ): + self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) @@ -49,7 +48,6 @@ def __init__( self.max_seq_len = max_seq_len self.base = base self.dim = dim - self.return_embeddings = return_embeddings # precompute cos_cached, sin_cached in fp32 cos_cached, sin_cached, inv_freq = self._prepare_cache( @@ -79,9 +77,10 @@ def _prepare_cache(self, seq_len, precision, base): inv_freq.to(precision), ) + def get_emb(self): + return self.emb.to(self.precision).cuda() + def forward(self, x, seq_dim=0, seq_len=None): - if self.return_embeddings: - return self.emb.to(self.precision).to(x.device) if seq_len is None: seq_len = x.shape[seq_dim] diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 07559bcfb..e4bbe4120 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -200,7 +200,7 @@ def __init__( init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, - return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, + return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, get_rng_state_tracker=get_cuda_rng_tracker, micro_batch_size=self.micro_batch_size) @@ -538,16 +538,12 @@ def __init__(self, base=neox_args.rotary_emb_base, max_seq_len=neox_args.seq_length, precision=neox_args.params_dtype, - save_inv_freqs=neox_args.rotary_save_freqs_buffer, - return_embeddings=True + save_inv_freqs=neox_args.rotary_save_freqs_buffer ) + self.rope_emb=self.rotary_embeddings.get_emb() def forward(self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs): - if self.neox_args.pos_emb == "rotary": - rope_emb=self.rotary_embeddings(hidden_states) - - output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=rope_emb, **kwargs) - + output = super(TEMultiheadAttention, 
self).forward(hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs) return output From 42716f2df5b96bd9396de24589f9a8c5092c4bb6 Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Tue, 1 Oct 2024 04:27:57 +0000 Subject: [PATCH 21/27] Updated fp8 arguments to te_fp8 --- megatron/model/transformer.py | 4 ++-- megatron/model/transformer_engine.py | 14 +++++++------- megatron/neox_arguments/neox_args.py | 18 ++++++------------ 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 80481e334..0266e3f49 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1028,7 +1028,7 @@ def __init__( ) # Self attention. - if neox_args.te_mha or neox_args.fp8_mha: + if neox_args.te_mha or neox_args.te_fp8_mha: from megatron.model.transformer_engine import TEMultiheadAttention self.attention = TEMultiheadAttention( neox_args=neox_args, @@ -1204,7 +1204,7 @@ def forward(self, x, attention_mask, layer_past=None): #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. - if self.neox_args.fp8_mha: + if self.neox_args.te_fp8_mha: from megatron.model.transformer_engine import TEDelayedScaling fp8_recipe = TEDelayedScaling( diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index e4bbe4120..59d68a7b7 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -559,22 +559,22 @@ def __init__( self.neox_args = neox_args self.tp_group = get_tensor_model_parallel_group() - if neox_args.fp8_format == "e4m3": + if neox_args.te_fp8_format == "e4m3": fp8_format = te.common.recipe.Format.E4M3 - elif neox_args.fp8_format == "hybrid": + elif neox_args.te_fp8_format == "hybrid": fp8_format = te.common.recipe.Format.HYBRID else: raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - override_linear_precision = (False, False, not neox_args.fp8_wgrad) + override_linear_precision = (False, False, not neox_args.te_fp8_wgrad) super().__init__( margin=neox_args.fp8_margin, - fp8_format=fp8_format, - amax_compute_algo=neox_args.fp8_amax_compute_algo, - amax_history_len=neox_args.fp8_amax_history_len, + fp8_format=te_fp8_format, + amax_compute_algo=neox_args.te_fp8_amax_compute_algo, + amax_history_len=neox_args.te_fp8_amax_history_len, override_linear_precision=override_linear_precision, - fp8_mha=neox_args.fp8_mha, + fp8_mha=neox_args.te_fp8_mha, ) def fp8_context(self): diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 8c91d5168..f1ec04f82 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -522,42 +522,36 @@ class NeoXArgsModel(NeoXArgsTemplate): Use TransformerEngine for MultiheadAttention layer. """ - fp8_format: Literal["e4m3", "hybrid"] = "hybrid" + te_fp8_format: Literal["e4m3", "hybrid"] = "hybrid" """ Controls the FP8 data format used during forward and backward pass by TransformerEngine. Hybrid uses E4M3 during forward pass, E5M2 during backward pass. """ - fp8_wgrad: bool = True + te_fp8_wgrad: bool = True """ When set to False, override FP8 config options and do the wgrad computation in higher precision. """ - fp8_amax_history_len: int = 1 + te_fp8_amax_history_len: int = 1 """ The length of the amax history window used for scaling factor computation. 
""" - fp8_amax_compute_algo: str = "most_recent" + te_fp8_amax_compute_algo: str = "most_recent" """ Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2 predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent` always chooses the most recently seen value. """ - fp8_wgrad: bool = True - """ - When set to False, override FP8 config options and do the wgrad computation - in higher precision. - """ - - fp8_margin: int = 0 + te_fp8_margin: int = 0 """ Margin for the scaling factor computation. """ - fp8_mha: bool = False + te_fp8_mha: bool = False """ When set to True, use the FP8 implementation of Multi Head Attention. """ From b3255e66f69f542cccd9d6488d4cbdfee9f053aa Mon Sep 17 00:00:00 2001 From: aurelion-source Date: Tue, 1 Oct 2024 05:28:48 +0000 Subject: [PATCH 22/27] Added EAI copyright --- megatron/model/transformer_engine.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 59d68a7b7..beba2c184 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import torch From 98f0388a27f19878338361f430172dec7ea92f49 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Tue, 8 Oct 2024 12:29:37 -0700 Subject: [PATCH 23/27] precommit --- megatron/model/positional_embeddings.py | 7 +- megatron/model/transformer.py | 38 ++-- megatron/model/transformer_engine.py | 249 ++++++++++++++++-------- megatron/model/utils.py | 10 +- megatron/neox_arguments/neox_args.py | 5 +- 5 files changed, 195 insertions(+), 114 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 3ff34b189..072aad8b4 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -37,7 +37,8 @@ def forward(self, x, seq_dim=1): class RotaryEmbedding(torch.nn.Module): def __init__( - self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False): + self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False + ): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs) @@ -79,7 +80,7 @@ def _prepare_cache(self, seq_len, precision, base): def get_emb(self): return self.emb.to(self.precision).cuda() - + def forward(self, x, seq_dim=0, seq_len=None): if seq_len is None: seq_len = x.shape[seq_dim] @@ -253,4 +254,4 @@ def forward(self, x): a.shape[0], 1, a.shape[2] ) # seq_len_k - 1 points to the last token index in the current inference batch. 
- return x + a \ No newline at end of file + return x + a diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 0266e3f49..a84748b5c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -414,7 +414,6 @@ def __init__( bias=neox_args.use_bias_in_attn_linear, ) - coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: @@ -860,7 +859,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): return query_layer, key_layer, value_layer def forward(self, hidden_states, attention_mask, layer_past=None): - + # hidden_states: [sq, b, h] # ===================== @@ -934,7 +933,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer = torch.cat((query_layer, query_pass), dim=-1) key_layer = torch.cat((key_layer, key_pass), dim=-1) - # ================================== # Cache key and value for inference # ================================== @@ -1030,16 +1028,17 @@ def __init__( # Self attention. if neox_args.te_mha or neox_args.te_fp8_mha: from megatron.model.transformer_engine import TEMultiheadAttention + self.attention = TEMultiheadAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, ) else: @@ -1073,6 +1072,7 @@ def get_mlp(**kw): def get_te_lnmlp(**kw): from megatron.model.transformer_engine import TELayerNormMLP + return TELayerNormMLP( neox_args=neox_args, init_method=init_method, @@ -1201,18 +1201,16 @@ def forward(self, x, attention_mask, layer_past=None): bias_dropout_fn = self._get_bias_dropout() moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) # x: [b, s, h] - - - #Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. + + # Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. 
if self.neox_args.te_fp8_mha: from megatron.model.transformer_engine import TEDelayedScaling - fp8_recipe = TEDelayedScaling( - neox_args=self.neox_args - ) + fp8_recipe = TEDelayedScaling(neox_args=self.neox_args) fp8_context = fp8_recipe.get_context() else: from contextlib import nullcontext + fp8_context = nullcontext() with fp8_context: @@ -1319,9 +1317,7 @@ def forward(self, x, attention_mask, layer_past=None): else: if self.moe_type == "deepspeed": mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) + mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term elif self.moe_type == "megablocks": mlp_output, mlp_bias = self.mlp(layernorm_output) else: diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index beba2c184..e67071f88 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -30,7 +30,10 @@ from megatron.mpu.mappings import scatter_to_model_parallel_region from megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region from megatron.mpu.mappings import gather_from_sequence_parallel_region -from megatron.mpu.layers import _initialize_affine_weight_gpu, _initialize_affine_weight_cpu +from megatron.mpu.layers import ( + _initialize_affine_weight_gpu, + _initialize_affine_weight_cpu, +) from megatron.mpu.random import get_cuda_rng_tracker from megatron.mpu.utils import divide from megatron.mpu.utils import VocabUtility @@ -95,6 +98,7 @@ class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ + def __init__( self, neox_args, @@ -105,7 +109,7 @@ def __init__( stride=1, skip_bias_add=False, mup_rescale_parameters=False, - seq_dim=0, + seq_dim=0, ): self.input_size = input_size self.output_size = output_size @@ -120,16 +124,23 @@ def __init__( self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters self.use_mup = neox_args.use_mup - self.params_dtype=neox_args.params_dtype + self.params_dtype = neox_args.params_dtype + + super(TELinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) - super(TELinear, self).__init__(in_features=self.input_size, out_features=self.output_size, - bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, - device=torch.cuda.current_device(), return_bias=self.skip_bias_add, params_dtype=self.params_dtype) - def forward(self, inp, **kwargs): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - + output = super(TELinear, self).forward(inp, **kwargs) if self.skip_bias_add: @@ -141,7 +152,7 @@ def forward(self, inp, **kwargs): class TELayerNormMLP(te.pytorch.LayerNormMLP): """ Wrapper for the Transformer-Engine's `LayerNormMLP` layer that combines - layernorm and followed by the MLP module, consisting of 2 successive + layernorm and followed by the MLP module, consisting of 2 successive linear transformations, separated by the GeLU activation. 
""" @@ -154,7 +165,7 @@ def __init__( multiple_of=256, MOE=False, MoE_mp_size=1, - bias=True + bias=True, ): self.activation_func, self.is_gated = get_activation(neox_args) self.activation_type = neox_args.activation @@ -169,10 +180,10 @@ def __init__( self.sequence_parallel = neox_args.sequence_parallel self.seq_len = neox_args.seq_length self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu - self.params_dtype=neox_args.params_dtype - self.set_parallel_mode=False + self.params_dtype = neox_args.params_dtype + self.set_parallel_mode = False if world_size > 1: - self.set_parallel_mode=True + self.set_parallel_mode = True if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size @@ -197,25 +208,51 @@ def __init__( self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - if neox_args.norm in ['layernorm','te_layernorm']: - self.eps=1.0e-5 - self.normalization = 'LayerNorm' - elif neox_args.norm in ['rmsnorm','te_rmsnorm']: - self.eps=1.0e-8 - self.normalization = 'RMSNorm' + if neox_args.norm in ["layernorm", "te_layernorm"]: + self.eps = 1.0e-5 + self.normalization = "LayerNorm" + elif neox_args.norm in ["rmsnorm", "te_rmsnorm"]: + self.eps = 1.0e-8 + self.normalization = "RMSNorm" else: - raise ValueError("Only LayerNorm and RMSNorm are supported with TransformerEngine") - - if self.activation_type not in ["gelu", "geglu", "relu", "reglu", "squared_relu","swiglu", "qgelu", "srelu"]: - raise ValueError("Only gelu, geglu, relu, reglu, squared_relu, swiglu, qgelu, and srelu are supported with TransformerEngine") + raise ValueError( + "Only LayerNorm and RMSNorm are supported with TransformerEngine" + ) - super(TELayerNormMLP, self).__init__(hidden_size=neox_args.hidden_size, ffn_hidden_size=ffn_dim, - eps=self.eps, bias=self.bias, normalization=self.normalization, activation=self.activation_type, - init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, - device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode, - sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size, - return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, get_rng_state_tracker=get_cuda_rng_tracker, - micro_batch_size=self.micro_batch_size) + if self.activation_type not in [ + "gelu", + "geglu", + "relu", + "reglu", + "squared_relu", + "swiglu", + "qgelu", + "srelu", + ]: + raise ValueError( + "Only gelu, geglu, relu, reglu, squared_relu, swiglu, qgelu, and srelu are supported with TransformerEngine" + ) + + super(TELayerNormMLP, self).__init__( + hidden_size=neox_args.hidden_size, + ffn_hidden_size=ffn_dim, + eps=self.eps, + bias=self.bias, + normalization=self.normalization, + activation=self.activation_type, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + device=torch.cuda.current_device(), + set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + return_bias=True, + params_dtype=self.params_dtype, + seq_length=self.seq_len, + get_rng_state_tracker=get_cuda_rng_tracker, + micro_batch_size=self.micro_batch_size, + ) class TEColumnParallelLinear(te.pytorch.Linear): @@ -255,7 +292,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, - seq_dim=0, + seq_dim=0, ): # Keep input parameters self.input_size = input_size @@ -276,14 +313,23 @@ def __init__( self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters self.use_mup = 
neox_args.use_mup - self.params_dtype=neox_args.params_dtype - self.parallel_mode="column" - - super(TEColumnParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, - bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, - device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, - tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, - params_dtype=self.params_dtype) + self.params_dtype = neox_args.params_dtype + self.parallel_mode = "column" + + super(TEColumnParallelLinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + parallel_mode=self.parallel_mode, + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) # Copied from Mup def width_mult(self): @@ -313,7 +359,7 @@ def _rescale_parameters(self): self.bias.data *= self.width_mult() ** 0.5 self.weight.data *= self.width_mult() ** 0.5 self._has_rescaled_params = True - + def mup_reinitialize_weights(self, neox_args): if neox_args.use_cpu_initialization: self.master_weight = _initialize_affine_weight_cpu( @@ -338,13 +384,14 @@ def mup_reinitialize_weights(self, neox_args): def forward(self, inp, **kwargs): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - + output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) if self.skip_bias_add: return output else: return output, None + class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar @@ -367,6 +414,7 @@ class TERowParallelLinear(te.pytorch.Linear): can be fused with other elementwise operations. we skip adding bias but instead return it. 
""" + def __init__( self, neox_args, @@ -400,14 +448,23 @@ def __init__( self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters self.use_mup = neox_args.use_mup - self.params_dtype=neox_args.params_dtype - self.parallel_mode="row" - - super(TERowParallelLinear, self).__init__(in_features=self.input_size, out_features=self.output_size, - bias= self.use_bias, init_method=self.init_method, get_rng_state_tracker=get_cuda_rng_tracker, - device=torch.cuda.current_device(), sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, - tp_size=self.world_size, parallel_mode=self.parallel_mode, return_bias=self.skip_bias_add, - params_dtype=self.params_dtype) + self.params_dtype = neox_args.params_dtype + self.parallel_mode = "row" + + super(TERowParallelLinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + parallel_mode=self.parallel_mode, + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) # Copied from Mup def width_mult(self): @@ -462,7 +519,7 @@ def mup_reinitialize_weights(self, neox_args): def forward(self, inp, **kwargs): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - + output = super(TERowParallelLinear, self).forward(inp, **kwargs) if self.skip_bias_add: @@ -470,13 +527,15 @@ def forward(self, inp, **kwargs): else: return output, None + class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also has "flash attention" enabled. """ - def __init__(self, + def __init__( + self, neox_args, attention_mask_func, init_method, @@ -485,31 +544,32 @@ def __init__(self, rpe=None, rotary=False, use_cache=False, - parallel_output=False): + parallel_output=False, + ): self.neox_args = neox_args self.attention_mask_func = attention_mask_func self.init_method = init_method self.output_layer_init_method = output_layer_init_method self.layer_number = layer_number + 1 - + world_size = get_model_parallel_world_size() self.world_size = world_size self.tp_group = get_tensor_model_parallel_group() self.sequence_parallel = neox_args.sequence_parallel self.seq_len = neox_args.seq_length self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu - self.params_dtype=neox_args.params_dtype - self.set_parallel_mode=False + self.params_dtype = neox_args.params_dtype + self.set_parallel_mode = False if world_size > 1: - self.set_parallel_mode=True + self.set_parallel_mode = True - if neox_args.norm in ['layernorm','te_layernorm']: - self.eps=1.0e-5 - self.normalization = 'LayerNorm' - elif neox_args.norm == ['rmsnorm','te_rmsnorm']: - self.eps=1.0e-8 - self.normalization = 'RMSNorm' + if neox_args.norm in ["layernorm", "te_layernorm"]: + self.eps = 1.0e-5 + self.normalization = "LayerNorm" + elif neox_args.norm == ["rmsnorm", "te_rmsnorm"]: + self.eps = 1.0e-8 + self.normalization = "RMSNorm" if ( not neox_args.num_kv_heads @@ -521,20 +581,36 @@ def __init__(self, self.gqa = True self.num_kv_heads = neox_args.num_kv_heads - super(TEMultiheadAttention, self).__init__(hidden_size=neox_args.hidden_size, num_attention_heads=neox_args.num_attention_heads, - attention_dropout=neox_args.attention_dropout, layernorm_epsilon=self.eps, init_method=self.init_method, - 
output_layer_init_method=self.output_layer_init_method, layer_number=self.layer_number, - window_size=neox_args.sliding_window_width, num_gqa_groups=self.num_kv_heads, input_layernorm=False, - normalization=self.normalization, bias=True, device=torch.cuda.current_device(),get_rng_state_tracker=get_cuda_rng_tracker, - set_parallel_mode=self.set_parallel_mode, sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, - tp_size=self.world_size, params_dtype=self.params_dtype, return_bias=True, qkv_format="sbhd", fuse_qkv_params=True) - - + super(TEMultiheadAttention, self).__init__( + hidden_size=neox_args.hidden_size, + num_attention_heads=neox_args.num_attention_heads, + attention_dropout=neox_args.attention_dropout, + layernorm_epsilon=self.eps, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + layer_number=self.layer_number, + window_size=neox_args.sliding_window_width, + num_gqa_groups=self.num_kv_heads, + input_layernorm=False, + normalization=self.normalization, + bias=True, + device=torch.cuda.current_device(), + get_rng_state_tracker=get_cuda_rng_tracker, + set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + params_dtype=self.params_dtype, + return_bias=True, + qkv_format="sbhd", + fuse_qkv_params=True, + ) if neox_args.pos_emb == "rotary": self.hidden_size_per_attention_head = mpu.divide( - neox_args.hidden_size, neox_args.num_attention_heads) - + neox_args.hidden_size, neox_args.num_attention_heads + ) + if neox_args.rotary_pct == 1: self.rotary_ndims = None else: @@ -552,12 +628,16 @@ def __init__(self, base=neox_args.rotary_emb_base, max_seq_len=neox_args.seq_length, precision=neox_args.params_dtype, - save_inv_freqs=neox_args.rotary_save_freqs_buffer + save_inv_freqs=neox_args.rotary_save_freqs_buffer, ) - self.rope_emb=self.rotary_embeddings.get_emb() + self.rope_emb = self.rotary_embeddings.get_emb() - def forward(self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs): - output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs) + def forward( + self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs + ): + output = super(TEMultiheadAttention, self).forward( + hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs + ) return output @@ -565,14 +645,13 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
""" + ##TODO Test with H100 - def __init__( - self, - neox_args): + def __init__(self, neox_args): self.neox_args = neox_args self.tp_group = get_tensor_model_parallel_group() - + if neox_args.te_fp8_format == "e4m3": fp8_format = te.common.recipe.Format.E4M3 elif neox_args.te_fp8_format == "hybrid": @@ -595,6 +674,8 @@ def fp8_context(self): fp8_group = None if self.tp_group: fp8_group = self.tp_group - fp8_context = te.pytorch.fp8_autocast(enabled=True, fp8_recipe=self, fp8_group=fp8_group) - - return get_context \ No newline at end of file + fp8_context = te.pytorch.fp8_autocast( + enabled=True, fp8_recipe=self, fp8_group=fp8_group + ) + + return get_context diff --git a/megatron/model/utils.py b/megatron/model/utils.py index d39da6194..5515c41f5 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -406,12 +406,16 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): def get_parallel_linear(neox_args): if neox_args.te_columnparallel: - from megatron.model.transformer_engine import TEColumnParallelLinear as ColumnParallelLinear + from megatron.model.transformer_engine import ( + TEColumnParallelLinear as ColumnParallelLinear, + ) else: from megatron.mpu import ColumnParallelLinear if neox_args.te_rowparallel: - from megatron.model.transformer_engine import TERowParallelLinear as RowParallelLinear + from megatron.model.transformer_engine import ( + TERowParallelLinear as RowParallelLinear, + ) else: from megatron.mpu import RowParallelLinear - return ColumnParallelLinear, RowParallelLinear \ No newline at end of file + return ColumnParallelLinear, RowParallelLinear diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 791d33a37..039f9e4da 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -501,7 +501,7 @@ class NeoXArgsModel(NeoXArgsTemplate): """ Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ - + te_columnparallel: bool = False """ Use TransformerEngine for RowParallelLinear layer. @@ -555,7 +555,7 @@ class NeoXArgsModel(NeoXArgsTemplate): """ When set to True, use the FP8 implementation of Multi Head Attention. """ - + dim_att: int = None """ Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size. @@ -1169,7 +1169,6 @@ class NeoXArgsTraining(NeoXArgsTemplate): Beta value for KTO """ - allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. 
From 7e7dbfbdd64b322578cc7930eb04760ef5359cc0 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Tue, 8 Oct 2024 12:36:56 -0700 Subject: [PATCH 24/27] add sample TE config --- configs/1-3B-transformer-engine.yml | 105 ++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 configs/1-3B-transformer-engine.yml diff --git a/configs/1-3B-transformer-engine.yml b/configs/1-3B-transformer-engine.yml new file mode 100644 index 000000000..079a5c31d --- /dev/null +++ b/configs/1-3B-transformer-engine.yml @@ -0,0 +1,105 @@ +# GPT-2 pretraining setup +{ + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 24, + "hidden_size": 2048, + "num_attention_heads": 16, + "seq_length": 2048, + "max_position_embeddings": 2048, + "norm": "layernorm", + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + # Transformer Engine settings + "te_columnparallel": false, + "te_rowparallel": false, + "te_layernorm_mlp": true, + "te_mha": true, + "te_fp8_format": "hybrid", + "te_fp8_wgrad": true, + "te_fp8_amax_history_len": 1, + "te_fp8_amax_compute_algo": "most_recent", + "te_fp8_margin": 0, + "te_fp8_mha": false, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00002, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, +} From 5757be678e66dccb709acfbd2b79198a787e31ea Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Tue, 8 Oct 2024 13:32:57 -0700 Subject: [PATCH 25/27] add te to readme --- README.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0d4e2939f..50619165c 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). ## News +**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration + **[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling **[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform @@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA * [Environment and Dependencies](#environment-and-dependencies) + [Host Setup](#host-setup) + [Flash Attention](#flash-attention) + + [Transformer Engine](#transformer-engine) + [Multi-Node Launching](#multi-node-launching) + [Containerized Setup](#containerized-setup) * [Usage](#usage) @@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD, ### Flash Attention -To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details. +To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details. 
+
+### Transformer Engine
+
+To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformerengine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). See [this config](https://github.com/EleutherAI/gpt-neox/configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups over the default kernels on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.
+
+
+TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:
+
+
+
+and H100:
+
+
 
 ### Multi-Node Launching
 
From 9ea3dcf11b549500ab054a1e164d4c161bf1ea11 Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Tue, 8 Oct 2024 13:33:34 -0700
Subject: [PATCH 26/27] remove pip install prefix from reqs file

---
 requirements/requirements-transformerengine.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-transformerengine.txt b/requirements/requirements-transformerengine.txt
index 2050d7566..eb8fad4e5 100644
--- a/requirements/requirements-transformerengine.txt
+++ b/requirements/requirements-transformerengine.txt
@@ -1 +1 @@
-pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+git+https://github.com/NVIDIA/TransformerEngine.git@stable

From f3e40e9ee6b8e0e8ad20d19bc34255752d9a00c4 Mon Sep 17 00:00:00 2001
From: aurelion-source
Date: Wed, 16 Oct 2024 20:54:41 +0000
Subject: [PATCH 27/27] Force TE pytorch in requirements file

---
 requirements/requirements-transformerengine.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-transformerengine.txt b/requirements/requirements-transformerengine.txt
index eb8fad4e5..10a1f3b82 100644
--- a/requirements/requirements-transformerengine.txt
+++ b/requirements/requirements-transformerengine.txt
@@ -1 +1 @@
-git+https://github.com/NVIDIA/TransformerEngine.git@stable
+transformer-engine[pytorch]
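
As a quick, illustrative sanity check of the `transformer-engine[pytorch]` pin above (not part of the patch series itself), one can instantiate the fused module that the sample config's `te_layernorm_mlp: true` ultimately routes to; this assumes a visible CUDA GPU and uses only the `te.pytorch.LayerNormMLP` API wrapped by `TELayerNormMLP` earlier in the series:

import torch
import transformer_engine.pytorch as te

# Fused LayerNorm + MLP block, the TE module subclassed by TELayerNormMLP.
mlp = te.LayerNormMLP(hidden_size=2048, ffn_hidden_size=8192, device="cuda")

# [seq, batch, hidden] input, matching the sbhd layout used in the patch.
x = torch.randn(128, 2, 2048, device="cuda")
print(mlp(x).shape)  # expected: torch.Size([128, 2, 2048])
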