diff --git a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
index c7e9762fb801..74d5fecf6123 100644
--- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
+++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -185,7 +185,7 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon)
+            self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
 
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py
index bbfb6e52f481..ff920f83c6b1 100644
--- a/paddlenlp/transformers/chatglm_v2/modeling.py
+++ b/paddlenlp/transformers/chatglm_v2/modeling.py
@@ -12,19 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import math
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import map_structure
 
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
 
 from ...utils.converter import StateDictNameMapping, init_name_mappings
-from .. import PretrainedModel, register_base_model
+from .. import PretrainedModel, linear_utils, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithPast,
@@ -32,6 +37,20 @@
 )
 from .configuration import CHATGLM_V2_PRETRAINED_RESOURCE_FILES_MAP, ChatGLMv2Config
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        GatherOp,
+        ScatterOp,
+        mark_as_sequence_parallel_parameter,
+    )
+except ImportError:
+    pass
+
+try:
+    from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
+except ImportError:
+    FusedDropoutAdd = None
+
 __all__ = [
     "ChatGLMv2Model",
     "ChatGLMv2PretrainedModel",
@@ -39,6 +58,41 @@
 ]
+
+
+def seed_guard_context(name=None):
+    if (
+        not isinstance(paddle.base.framework._current_expected_place(), paddle.core.CPUPlace)
+        and name in get_rng_state_tracker().states_
+    ):
+        # TODO: fix it
+        # ValueError: Length of gpu state list should be equal to the gpu device count
+        # /usr/local/lib/python3.10/dist-packages/paddle/incubate/framework/random.py:119: ValueError
+        # return contextlib.nullcontext()
+        return get_rng_state_tracker().rng_state(name)
+    else:
+        return contextlib.nullcontext()
+
+
+def parallel_matmul(lm_output, logit_weights, parallel_output):
+    hcg = fleet.get_hybrid_communicate_group()
+    model_parallel_group = hcg.get_model_parallel_group()
+    world_size = hcg.get_model_parallel_world_size()
+
+    if world_size > 1:
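+        # Note: under tensor parallelism the logit weights are assumed to be
+        # vocab-sharded, i.e. each rank holds a [hidden_size, vocab // mp_degree]
+        # slice, so the matmul below yields only this rank's shard of the logits.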
+        # The backward of _c_identity is an all-reduce, so gradients from all
+        # model-parallel ranks are summed into lm_output.
+        input_parallel = paddle.distributed.collective._c_identity(lm_output, group=model_parallel_group)
+
+        logits = paddle.matmul(input_parallel, logit_weights, transpose_y=False)
+
+        if parallel_output:
+            return logits
+
+        # _c_concat has no backward defined, so no gradient flows through it.
+        return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
+    else:
+        logits = paddle.matmul(lm_output, logit_weights, transpose_y=False)
+        return logits
+
+
 class RotaryEmbedding(nn.Layer):
     def __init__(self, dim, original_impl=False):
         super().__init__()
@@ -97,7 +151,7 @@ def apply_rotary_pos_emb(x: paddle.Tensor, rope_cache: paddle.Tensor) -> paddle.
 
 
 class RMSNorm(nn.Layer):
-    def __init__(self, hidden_size, epsilon=None):
+    def __init__(self, hidden_size, config: ChatGLMv2Config, epsilon=None):
         super().__init__()
         self.hidden_size = hidden_size
         self.weight = paddle.create_parameter(
@@ -107,6 +161,9 @@ def __init__(self, hidden_size, epsilon=None):
         )
         self.epsilon = 1e-5 if epsilon is None else epsilon
 
+        if config.sequence_parallel:
+            mark_as_sequence_parallel_parameter(self.weight)
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
@@ -131,14 +188,19 @@ def __init__(self, config: ChatGLMv2Config, layer_number):
         self.num_attention_heads_per_partition = config.num_attention_heads
         self.hidden_size_per_partition = config.kv_channels * self.num_attention_heads_per_partition
         self.hidden_size_per_attention_head = self.hidden_size_per_partition // self.num_attention_heads_per_partition
-
+        self.tensor_parallel_degree = config.tensor_parallel_degree
+        if self.tensor_parallel_degree > 1:
+            assert (
+                self.hidden_size_per_partition % self.tensor_parallel_degree == 0
+            ), "hidden_size_per_partition must be divisible by tensor_parallel_degree."
+            self.hidden_size_per_partition = self.hidden_size_per_partition // self.tensor_parallel_degree
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:
             coeff = self.layer_number
             self.norm_factor *= coeff
         self.coeff = coeff
-
+        self.config = config
         self.attention_dropout = nn.Dropout(config.attention_dropout)
 
     def forward(self, query_layer, key_layer, value_layer, attention_mask):
@@ -176,7 +238,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.attention_dropout(attention_probs)
+        with seed_guard_context("local_seed"):
+            attention_probs = self.attention_dropout(attention_probs)
 
         # [batch_size, num_heads, query_length, key_length]
         # value_layer -> context layer.
@@ -198,6 +261,10 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
         new_context_shape = context_layer.shape[:-2] + [self.hidden_size_per_partition]
         context_layer = context_layer.reshape(new_context_shape)
 
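+        # Under sequence parallelism, fold [seq, batch, hidden] into
+        # [seq * batch, hidden] so the downstream row-parallel output
+        # projection can reduce-scatter along the flattened sequence dimension.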
+        if self.config.sequence_parallel:
+            sq, b, hp = context_layer.shape
+            context_layer = context_layer.reshape([sq * b, hp])
+
         return context_layer
 
 
@@ -221,33 +288,83 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None):
         self.num_multi_query_groups_per_partition = config.multi_query_group_num
         self.multi_query_group_num = config.multi_query_group_num
         self.num_attention_heads_per_partition = config.num_attention_heads
+        self.config = config
+        self.seq_length = config.seq_length
+        # Recompute defaults to False and is controlled by Trainer
+        self.enable_recompute = False
+        self.tensor_parallel_degree = config.tensor_parallel_degree
+        self.sequence_parallel = config.sequence_parallel
 
-        self.query_key_value = nn.Linear(
-            config.hidden_size,
-            config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num,
-            bias_attr=config.add_bias_linear or config.add_qkv_bias,
-        )
-        # Output.
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear)
+        if config.sequence_parallel:
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
+        else:
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear
+
+        if config.tensor_parallel_degree > 1:
+            self.query_key_value = ColumnParallelLinear(
+                config.hidden_size,
+                config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num,
+                has_bias=config.add_bias_linear or config.add_qkv_bias,
+                gather_output=False,
+            )
+            self.dense = RowParallelLinear(
+                config.hidden_size, config.hidden_size, input_is_parallel=True, has_bias=config.add_bias_linear
+            )
+            self.num_attention_heads_per_partition = config.num_attention_heads // config.tensor_parallel_degree
+            assert (
+                self.num_multi_query_groups_per_partition % self.tensor_parallel_degree == 0
+            ), "`multi_query_group_num` must be divisible by `tensor_parallel_degree`."
+            self.num_multi_query_groups_per_partition = (
+                self.num_multi_query_groups_per_partition // self.tensor_parallel_degree
+            )
+        else:
+            self.query_key_value = nn.Linear(
+                config.hidden_size,
+                config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num,
+                bias_attr=config.add_bias_linear or config.add_qkv_bias,
+            )
+            # Output.
+            self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear)
 
-    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
-        seq_length, batch_size, hidden_size = hidden_states.shape
-        mixed_x_layer = self.query_key_value(hidden_states)
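+    # Thin wrapper so the recompute path in forward() can checkpoint exactly
+    # the core attention step when recompute_granularity == "core_attn".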
+    def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False):
+        outputs = self.core_attention(q, k, v, attention_mask)
+        return outputs
 
+    def forward(
+        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, output_attentions=False
+    ):
+        # seq_length, batch_size = self.config.seq_length, hidden_states.shape[0]//self.config.seq_length
+        mixed_x_layer = self.query_key_value(hidden_states)
         (query_layer, key_layer, value_layer) = mixed_x_layer.split(
             [
                 self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
-                self.hidden_size_per_attention_head * self.multi_query_group_num,
-                self.hidden_size_per_attention_head * self.multi_query_group_num,
+                self.hidden_size_per_attention_head * self.num_multi_query_groups_per_partition,
+                self.hidden_size_per_attention_head * self.num_multi_query_groups_per_partition,
             ],
             axis=-1,
         )
-
-        query_layer = query_layer.reshape(
-            [seq_length, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head]
-        )
-        key_layer = key_layer.reshape([seq_length, batch_size, -1, self.hidden_size_per_attention_head])
-        value_layer = value_layer.reshape([seq_length, batch_size, -1, self.hidden_size_per_attention_head])
+        if self.sequence_parallel:
+            query_layer = query_layer.reshape(
+                [self.seq_length, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head]
+            )
+            key_layer = key_layer.reshape(
+                [self.seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head]
+            )
+            value_layer = value_layer.reshape(
+                [self.seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head]
+            )
+        else:
+            query_layer = query_layer.reshape(
+                [0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head]
+            )
+            key_layer = key_layer.reshape(
+                [0, 0, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head]
+            )
+            value_layer = value_layer.reshape(
+                [0, 0, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head]
+            )
 
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
@@ -278,13 +395,28 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None,
         # ==================================
         # core attention computation
         # ==================================
+        attention_func = self._core_attention
 
-        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
-
+        has_gradient = (
+            (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient)
+        )
+        if self.enable_recompute and self.config.recompute_granularity == "core_attn" and has_gradient:
+            context_layer = recompute(
+                attention_func,
+                query_layer,
+                key_layer,
+                value_layer,
+                attention_mask,
+                output_attentions,
+                use_reentrant=False,
+            )
+        else:
+            context_layer = attention_func(
+                query_layer, key_layer, value_layer, attention_mask=attention_mask, output_attentions=output_attentions
+            )
         # =================
         # Output. [seq_length, b, h]
         # =================
-
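+        # Roughly: self.dense is row-parallel under tensor parallelism, so each
+        # rank holds a [hidden // mp_degree, hidden] weight shard and partial
+        # results are summed across ranks (reduce-scattered under sequence parallel).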
         output = self.dense(context_layer)
 
         return output, kv_cache
@@ -302,14 +434,28 @@ def __init__(self, config: ChatGLMv2Config):
 
         self.add_bias = config.add_bias_linear
 
-        # Project to 4h due to swiglu doubling the output width, see https://arxiv.org/pdf/2002.05202.pdf
-        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, bias_attr=self.add_bias)
-        # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size,
-            config.hidden_size,
-            bias_attr=self.add_bias,
-        )
+        if config.sequence_parallel:
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
+        else:
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear
+        if config.tensor_parallel_degree > 1:
+            self.dense_h_to_4h = ColumnParallelLinear(
+                config.hidden_size, config.ffn_hidden_size * 2, has_bias=self.add_bias, gather_output=False
+            )
+            self.dense_4h_to_h = RowParallelLinear(
+                config.ffn_hidden_size, config.hidden_size, input_is_parallel=True, has_bias=self.add_bias
+            )
+        else:
+            # Project to 4h due to swiglu doubling the output width, see https://arxiv.org/pdf/2002.05202.pdf
+            self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, bias_attr=self.add_bias)
+            # Project back to h.
+            self.dense_4h_to_h = nn.Linear(
+                config.ffn_hidden_size,
+                config.hidden_size,
+                bias_attr=self.add_bias,
+            )
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -336,19 +482,23 @@ def __init__(self, config: ChatGLMv2Config, layer_number):
         super(GLMBlock, self).__init__()
         self.layer_number = layer_number
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
-
+        # Recompute defaults to False and is controlled by Trainer
+        self.enable_recompute = False
+        self.config = config
         self.fp32_residual_connection = config.fp32_residual_connection
 
         LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon)
+        self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon)
+        self.post_attention_layernorm = LayerNormFunc(
+            config.hidden_size, epsilon=config.layernorm_epsilon, config=config
+        )
 
         # MLP
         self.mlp = MLP(config)
@@ -366,10 +516,21 @@ def forward(
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
+        has_gradient = not layernorm_output.stop_gradient
 
         # Self attention.
-        attention_output, kv_cache = self.self_attention(
-            layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
-        )
+        if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full_attn":
+            attention_output, kv_cache = recompute(
+                self.self_attention,
+                layernorm_output,
+                attention_mask,
+                rotary_pos_emb,
+                kv_cache=kv_cache,
+                use_cache=use_cache,
+            )
+        else:
+            attention_output, kv_cache = self.self_attention(
+                layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
+            )
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -377,7 +538,10 @@ def forward(
         else:
             residual = hidden_states
 
-        layernorm_input = F.dropout(attention_output, p=self.hidden_dropout, training=self.training)
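+        # Dropout must use a rank-local seed when activations are sharded along
+        # the sequence axis (sequence parallel) and a globally shared seed when
+        # they are replicated, so the dropout masks stay consistent across ranks.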
+        current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
+
+        with seed_guard_context(current_seed):
+            layernorm_input = F.dropout(attention_output, p=self.hidden_dropout, training=self.training)
         layernorm_input = residual + layernorm_input
 
         # Layer norm post the self attention.
@@ -392,7 +556,8 @@ def forward(
         else:
             residual = layernorm_input
 
-        output = F.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+        with seed_guard_context(current_seed):
+            output = F.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
         output = residual + output
 
         return output, kv_cache
@@ -403,7 +568,10 @@ class GLMTransformer(nn.Layer):
     def __init__(self, config: ChatGLMv2Config):
         super(GLMTransformer, self).__init__()
         self.config = config
+        # Recompute defaults to False and is controlled by Trainer
         self.enable_recompute = False
+        self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
+        self.recompute_granularity = config.recompute_granularity
         self.fp32_residual_connection = config.fp32_residual_connection
         self.post_layer_norm = config.post_layer_norm
@@ -419,7 +587,7 @@ def build_layer(layer_number):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon)
+            self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
 
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
@@ -476,7 +644,12 @@ def forward(
 
             layer = self._get_layer(index)
 
-            if self.enable_recompute and not hidden_states.stop_gradient:
+            if (
+                self.enable_recompute
+                and not hidden_states.stop_gradient
+                and index not in self.no_recompute_layers
+                and self.recompute_granularity == "full"
+            ):
                 hidden_states, kv_cache = self.recompute_training(
                     layer,
                     hidden_states,
@@ -546,6 +719,10 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None):
 
         return casual_mask
 
+    def init_weights(self, layer):
+        """Initialization hook"""
+        return None
+
     def get_position_ids(self, input_ids):
         batch_size, seq_length = input_ids.shape
         position_ids = paddle.arange(seq_length, dtype="int64").unsqueeze(0).tile([batch_size, 1])
@@ -610,23 +787,158 @@ def _get_name_mappings(cls, config: ChatGLMv2Config) -> List[StateDictNameMappin
             ]
         )
-
         for mapping in mappings:
             mapping[0] = "transformer." + mapping[0]
             if len(mapping) > 1 and mapping[1] is not None:
                 mapping[1] = "chatglm_v2." + mapping[1]
 
         init_name_mappings(mappings)
         return [StateDictNameMapping(*mapping) for mapping in mappings]
 
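+    # _get_tensor_parallel_mappings returns, per parameter name, the action
+    # used to split (on load) or merge (on save) that weight across tensor
+    # parallel ranks. The QKV projection and the fused SwiGLU MLP weight need
+    # custom handling because several logical matrices are fused along the
+    # last axis.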
+    @classmethod
+    def _get_tensor_parallel_mappings(cls, config, is_split=True):
+
+        from paddlenlp.transformers.conversion_utils import split_or_merge_func
+
+        def split_or_merge_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, is_split, tensor):
+            if is_split:
+                return split_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, tensor)
+            else:
+                assert (
+                    len(tensor) == tensor_parallel_degree
+                ), "The length of tensor_list must match tensor_parallel_degree"
+                return merge_qkv_weights(tensor_parallel_degree, hidden_size, tensor)
+
+        def split_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, tensor):
+            split_query_size = hidden_size // tensor_parallel_degree
+            split_kv_size = (tensor.shape[-1] - hidden_size) // (2 * tensor_parallel_degree)
+
+            query = tensor[..., :hidden_size]
+            key = tensor[..., hidden_size : hidden_size + split_kv_size * tensor_parallel_degree]
+            value = tensor[..., tensor.shape[-1] - split_kv_size * tensor_parallel_degree :]
+
+            key_part = key[..., tensor_parallel_rank * split_kv_size : (tensor_parallel_rank + 1) * split_kv_size]
+            value_part = value[..., tensor_parallel_rank * split_kv_size : (tensor_parallel_rank + 1) * split_kv_size]
+            query_part = query[
+                ..., tensor_parallel_rank * split_query_size : (tensor_parallel_rank + 1) * split_query_size
+            ]
+
+            return paddle.concat([query_part, key_part, value_part], axis=-1)
+
+        def merge_qkv_weights(tensor_parallel_degree, hidden_size, tensor):
+            split_query_size = hidden_size // tensor_parallel_degree
+            split_kv_size = (tensor[0].shape[-1] - split_query_size) // 2
+            merge_q = tensor[0][..., :split_query_size]
+            merge_k = tensor[0][..., split_query_size : split_query_size + split_kv_size]
+            merge_v = tensor[0][..., split_query_size + split_kv_size :]
+            is_ndarray = isinstance(tensor[0], np.ndarray)
+            for i in range(1, tensor_parallel_degree):
+                if is_ndarray:
+                    merge_q = np.concatenate([merge_q, tensor[i][..., :split_query_size]], axis=-1)
+                    merge_k = np.concatenate(
+                        [merge_k, tensor[i][..., split_query_size : split_query_size + split_kv_size]], axis=-1
+                    )
+                    merge_v = np.concatenate([merge_v, tensor[i][..., split_query_size + split_kv_size :]], axis=-1)
+                else:
+                    merge_q = paddle.concat([merge_q, tensor[i][..., :split_query_size]], axis=-1)
+                    merge_k = paddle.concat(
+                        [merge_k, tensor[i][..., split_query_size : split_query_size + split_kv_size]], axis=-1
+                    )
+                    merge_v = paddle.concat([merge_v, tensor[i][..., split_query_size + split_kv_size :]], axis=-1)
+            if is_ndarray:
+                return np.concatenate([merge_q, merge_k, merge_v], axis=-1)
+            else:
+                return paddle.concat([merge_q, merge_k, merge_v], axis=-1)
+
+        fn = split_or_merge_func(
+            is_split=is_split,
+            tensor_parallel_degree=config.tensor_parallel_degree,
+            tensor_parallel_rank=config.tensor_parallel_rank,
+            num_attention_heads=config.num_attention_heads,
+        )
+
+        def split_or_merge_mlp_weights(tensor_parallel_degree, tensor_parallel_rank, is_split, tensor):
+            if is_split:
+                return split_mlp_weights(tensor_parallel_degree, tensor_parallel_rank, tensor)
+            else:
+                assert (
+                    len(tensor) == tensor_parallel_degree
+                ), "The length of tensor_list must match tensor_parallel_degree"
+                return merge_mlp_weights(tensor_parallel_degree, tensor)
+
+        def split_mlp_weights(tensor_parallel_degree, tensor_parallel_rank, tensor):
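+            # dense_h_to_4h fuses the [ffn_fc | gate] halves along the last
+            # axis (SwiGLU), so each half is sliced per rank and re-fused
+            # instead of applying a plain column split to the whole matrix.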
+            split_size = tensor.shape[-1] // tensor_parallel_degree // 2
+            ffn_fc = tensor[..., : tensor.shape[-1] // 2]
+            gate = tensor[..., tensor.shape[-1] // 2 :]
+            ffn_fc_part = ffn_fc[..., tensor_parallel_rank * split_size : (tensor_parallel_rank + 1) * split_size]
+            gate_part = gate[..., tensor_parallel_rank * split_size : (tensor_parallel_rank + 1) * split_size]
+            return paddle.concat([ffn_fc_part, gate_part], axis=-1)
+
+        def merge_mlp_weights(tensor_parallel_degree, tensor):
+            split_size = tensor[0].shape[-1] // 2
+            merge_ffn_fc = tensor[0][..., :split_size]
+            merge_gate = tensor[0][..., split_size:]
+            is_ndarray = isinstance(tensor[0], np.ndarray)
+            for i in range(1, tensor_parallel_degree):
+                if is_ndarray:
+                    merge_ffn_fc = np.concatenate([merge_ffn_fc, tensor[i][..., :split_size]], axis=-1)
+                    merge_gate = np.concatenate([merge_gate, tensor[i][..., split_size:]], axis=-1)
+                else:
+                    merge_ffn_fc = paddle.concat([merge_ffn_fc, tensor[i][..., :split_size]], axis=-1)
+                    merge_gate = paddle.concat([merge_gate, tensor[i][..., split_size:]], axis=-1)
+            if is_ndarray:
+                return np.concatenate([merge_ffn_fc, merge_gate], axis=-1)
+            else:
+                return paddle.concat([merge_ffn_fc, merge_gate], axis=-1)
+
+        def get_tensor_parallel_split_mappings(num_hidden_layers):
+            final_actions = {}
+            base_actions = {
+                # Column Linear
+                "output_layer.weight": partial(fn, is_column=True),
+                "encoder.layers.0.mlp.dense_h_to_4h.weight": partial(
+                    split_or_merge_mlp_weights, config.tensor_parallel_degree, config.tensor_parallel_rank, is_split
+                ),
+                "encoder.layers.0.self_attention.query_key_value.bias": partial(
+                    split_or_merge_qkv_weights,
+                    config.tensor_parallel_degree,
+                    config.tensor_parallel_rank,
+                    config.hidden_size,
+                    is_split,
+                ),
+                "encoder.layers.0.self_attention.query_key_value.weight": partial(
+                    split_or_merge_qkv_weights,
+                    config.tensor_parallel_degree,
+                    config.tensor_parallel_rank,
+                    config.hidden_size,
+                    is_split,
+                ),
+                # Row Linear
+                "embedding.word_embeddings.weight": partial(fn, is_column=False),
+                "encoder.layers.0.self_attention.dense.weight": partial(fn, is_column=False),
+                "encoder.layers.0.mlp.dense_4h_to_h.weight": partial(fn, is_column=False),
+            }
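+            # The layers.0 entries act as templates: the same action is
+            # replicated below for every transformer layer index.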
+            for key, action in base_actions.items():
+                if "layers.0." in key:
+                    for i in range(num_hidden_layers):
+                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
+                final_actions[key] = action
+
+            return final_actions
+
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+
+        return mappings
+
 
 class Embedding(nn.Layer):
     """Language model embeddings."""
 
     def __init__(self, config: ChatGLMv2Config):
         super(Embedding, self).__init__()
-
+        self.config = config
         self.hidden_size = config.hidden_size
-        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size)
+        if config.tensor_parallel_degree > 1:
+            self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
+                config.padded_vocab_size, self.hidden_size
+            )
+        else:
+            self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -646,6 +958,7 @@ class ChatGLMv2Model(ChatGLMv2PretrainedModel):
     def __init__(self, config: ChatGLMv2Config, empty_init=True):
         super().__init__(config)
         self.embedding = Embedding(config)
+        self.config = config
 
         # Rotary positional embeddings
         self.max_sequence_length = config.max_sequence_length
@@ -662,7 +975,13 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
         else:
             self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
         self.encoder = GLMTransformer(config)
-        self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False)
+        if config.tensor_parallel_degree > 1:
+            self.output_layer = nn.Linear(
+                config.hidden_size, config.padded_vocab_size // config.tensor_parallel_degree, bias_attr=False
+            )
+        else:
+            self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False)
+        self.apply(self.init_weights)
 
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
@@ -692,6 +1011,12 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
 
+        if self.config.sequence_parallel:
+            seq_length, batch_size, hidden_size = inputs_embeds.shape
+            inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size])
+            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
+            inputs_embeds = ScatterOp.apply(inputs_embeds)
+
         full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
@@ -730,11 +1055,88 @@ def forward(
         )
 
 
+class ChatGLMv2PretrainingCriterion(nn.Layer):
+    """
+    Criterion for ChatGLMv2. It calculates the final loss.
+    """
+
+    def __init__(self, config):
+        super(ChatGLMv2PretrainingCriterion, self).__init__()
+        self.config = config
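+        # ParallelCrossEntropy consumes vocab-sharded logits directly, so the
+        # full [*, vocab_size] logits never have to be gathered on one rank.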
+        if config.tensor_parallel_degree > 1 and config.tensor_parallel_output:
+            self.loss_func = fleet.meta_parallel.ParallelCrossEntropy()
+        else:
+            self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
+
+    def forward(self, prediction_scores, masked_lm_labels):
+        """
+        Args:
+            prediction_scores(Tensor):
+                The logits of masked token prediction. Its data type should be float32 and
+                its shape is [batch_size, sequence_length, vocab_size].
+            masked_lm_labels(Tensor):
+                The labels of the masked language modeling. Its data type should be int64 and
+                its shape is [batch_size, sequence_length]; positions to ignore are set to -100.
+
+        Returns:
+            Tensor: The pretraining loss. Its data type should be float32 and its shape is [1].
+
+        """
+        with paddle.amp.auto_cast(False):
+            loss_mask = (masked_lm_labels != -100).astype("float32")
+            reshaped_logits = prediction_scores.reshape([-1, prediction_scores.shape[-1]]).astype("float32")
+            reshaped_labels = masked_lm_labels.reshape([-1])
+            loss = self.loss_func(reshaped_logits, reshaped_labels)
+            loss = paddle.sum(loss.reshape([-1]).cast(paddle.float32) * loss_mask.reshape([-1]).cast(paddle.float32))
+            loss = loss / loss_mask.sum()
+        return loss
+
+
+class Chatglmv2LMHead(nn.Layer):
+    def __init__(self, config: ChatGLMv2Config, embedding_weights=None):
+        super(Chatglmv2LMHead, self).__init__()
+        if embedding_weights is not None:
+            self.decoder_weight = embedding_weights
+        else:
+            if config.tensor_parallel_degree > 1:
+                vocab_size = config.vocab_size // config.tensor_parallel_degree
+            else:
+                vocab_size = config.vocab_size
+
+            if vocab_size != config.vocab_size:
+                with get_rng_state_tracker().rng_state():
+                    self.decoder_weight = self.create_parameter(
+                        shape=[config.hidden_size, vocab_size],
+                        dtype=paddle.get_default_dtype(),
+                    )
+            else:
+                self.decoder_weight = self.create_parameter(
+                    shape=[config.hidden_size, vocab_size], dtype=paddle.get_default_dtype()
+                )
+        self.config = config
+
+    def forward(self, hidden_states, return_last_logit=False):
+        if return_last_logit:
+            hidden_states = hidden_states[-1:]
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = paddle.reshape_(hidden_states, [self.config.seq_length, -1, self.config.hidden_size])
+        logits = parallel_matmul(hidden_states, self.decoder_weight, self.config.tensor_parallel_output)
+        return logits.transpose([1, 0, 2])
+
+
 class ChatGLMv2ForCausalLM(ChatGLMv2PretrainedModel):
     def __init__(self, config: ChatGLMv2Config):
         super().__init__(config)
         self.max_sequence_length = config.max_sequence_length
         self.chatglm_v2 = ChatGLMv2Model(config)
+        self.criterion = ChatGLMv2PretrainingCriterion(config)
+        self.config = config
 
     def reorder_cache(self, cache: paddle.Tensor, beam_idx):
         cache = map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)
@@ -826,23 +1228,23 @@ def forward(
 
         hidden_states = transformer_outputs[0]
 
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            seq_length = self.config.seq_length
+            hidden_states = hidden_states.reshape([seq_length, -1, self.config.hidden_size])
         if return_last_logit:
             hidden_states = hidden_states[-1:]
-        lm_logits = self.chatglm_v2.output_layer(hidden_states)
+        if self.config.tensor_parallel_degree > 1:
+            lm_logits = parallel_matmul(
+                hidden_states, self.chatglm_v2.output_layer.weight, self.config.tensor_parallel_output
+            )
+        else:
+            lm_logits = self.chatglm_v2.output_layer(hidden_states)
         lm_logits = lm_logits.transpose([1, 0, 2])
-
+        # shape = [batch_size, seq_length, vocab_size]
         loss = None
         if labels is not None:
-            reshaped_logits = lm_logits.reshape([-1, lm_logits.shape[-1]]).astype("float32")
-            reshaped_labels = labels.reshape([-1])
-
-            loss_fn = nn.CrossEntropyLoss(reduction="none")
-
-            loss_mask = (labels != -100).astype("float32")
-            loss = loss_fn(reshaped_logits, reshaped_labels)
-            loss = paddle.sum(loss.reshape([-1]).cast(paddle.float32) * loss_mask.reshape([-1]).cast(paddle.float32))
-            loss = loss / loss_mask.sum()
-
+            loss = self.criterion(lm_logits, labels)
         lm_logits = lm_logits.astype(hidden_states.dtype)
         loss = loss.astype(hidden_states.dtype)