diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index d0df32321e18..7196f52eea6d 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device


 def add_start_docstrings(*docstr):
@@ -483,6 +484,16 @@ def main():
             config.num_attention_heads % config.sep_parallel_degree == 0
         ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # paddle_xpu is not installed; run without the accumulate_steps optimization
+            pass
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model
diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py
new file mode 100644
index 000000000000..de1a0f886b79
--- /dev/null
+++ b/paddlenlp/transformers/linear_utils.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +This file is used for replacing Paddle's native Linear implementations with vendors' customized implementations +""" + +import paddle.distributed.fleet.meta_parallel as mpu +from paddle import nn +from paddle.distributed.fleet.utils import sequence_parallel_utils + +from paddlenlp.transformers.mc2_parallel_linear import ( + MC2ColumnSeqParallelLinear, + MC2RowSeqParallelLinear, +) +from paddlenlp.utils.tools import get_env_device + +Linear = nn.Linear +ColumnParallelLinear = mpu.ColumnParallelLinear +RowParallelLinear = mpu.RowParallelLinear +ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear +RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear + +if get_env_device() == "npu": + if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: + ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear + RowSequenceParallelLinear = MC2RowSeqParallelLinear +elif get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear + from paddle_xpu.layers.nn import Linear as XPULinear + from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear + from paddle_xpu.layers.nn.sequence_parallel import ( + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) + + Linear = XPULinear + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear + ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear + RowSequenceParallelLinear = XPURowSequenceParallelLinear + except ImportError: + # If paddle_xpu is not installed, just use Paddle's native Linear implementations + pass +else: + # By default, use Paddle's native Linear implementations + pass diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 3efbb9de89a1..aee1313ba8df 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -62,10 +62,6 @@ def swiglu(x, y=None): init_name_mappings, ) from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies -from paddlenlp.transformers.mc2_parallel_linear import ( - MC2ColumnSeqParallelLinear, - MC2RowSeqParallelLinear, -) from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -74,6 +70,8 @@ def swiglu(x, y=None): from paddlenlp.utils.log import logger from paddlenlp.utils.tools import get_env_device +from .. import linear_utils +from ..linear_utils import Linear from ..segment_parallel_utils import ReshardLayer from .configuration import ( LLAMA_PRETRAINED_INIT_CONFIGURATION, @@ -410,6 +408,15 @@ def forward(self, hidden_states): if self.config.use_fused_rms_norm: if get_env_device() == "npu": return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
         return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

         if paddle.in_dynamic_mode():
@@ -571,15 +578,11 @@ def __init__(self, config):
         self.fuse_attention_ffn = config.fuse_attention_ffn

         if config.sequence_parallel:
-            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
-                ColumnParallelLinear = MC2ColumnSeqParallelLinear
-                RowParallelLinear = MC2RowSeqParallelLinear
-            else:
-                ColumnParallelLinear = ColumnSequenceParallelLinear
-                RowParallelLinear = RowSequenceParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear

         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
@@ -611,15 +614,29 @@
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
+            # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                try:
+                    import paddle_xpu_nn  # noqa: F821
+
+                    out = self.gate_up_fused_proj(x)
+                    out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                    out = self.down_proj(out)
+                    return out
+                except ImportError:
+                    gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
+                    out = self.down_proj(F.silu(gate_out) * up_out)
+                    return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -680,7 +697,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() != "npu":
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
" @@ -689,15 +706,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): self.use_fused_rope = False if config.sequence_parallel: - if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: - ColumnParallelLinear = MC2ColumnSeqParallelLinear - RowParallelLinear = MC2RowSeqParallelLinear - else: - ColumnParallelLinear = ColumnSequenceParallelLinear - RowParallelLinear = RowSequenceParallelLinear + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear else: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear if config.tensor_parallel_degree > 1: if self.fuse_attention_qkv: @@ -728,12 +741,12 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): gather_output=False, ) else: - self.k_proj = nn.Linear( + self.k_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - self.v_proj = nn.Linear( + self.v_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, @@ -741,23 +754,23 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): else: if self.fuse_attention_qkv: - self.qkv_proj = nn.Linear( + self.qkv_proj = Linear( self.hidden_size, self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) else: - self.q_proj = nn.Linear( + self.q_proj = Linear( self.hidden_size, self.hidden_size, bias_attr=False, ) - self.k_proj = nn.Linear( + self.k_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - self.v_proj = nn.Linear( + self.v_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, @@ -771,7 +784,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): input_is_parallel=True, ) else: - self.o_proj = nn.Linear( + self.o_proj = Linear( self.hidden_size, self.hidden_size, bias_attr=False, @@ -1419,6 +1432,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16") expanded_attn_mask = expanded_attn_mask.astype("float16") expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) + elif get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype=dtype) + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) + expanded_attn_mask = expanded_attn_mask.astype(dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) else: expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) return expanded_attn_mask @@ -1698,6 +1716,15 @@ def __init__(self, config: LlamaConfig): self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False if self.weight.is_distributed: self.weight.split_axis = 1 + if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) + + self.xpu_parallel_matmul = xpu_parallel_matmul() + except ImportError: + self.xpu_parallel_matmul = None def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sequence_parallel: @@ -1711,7 +1738,12 @@ def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output 
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
+            logits = self.xpu_parallel_matmul(
+                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+            )
+        else:
+            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
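
Review note: a quick way to sanity-check the dispatch in the new linear_utils.py on a machine without paddle_xpu (hypothetical snippet, not part of the patch); outside the npu/xpu branches every alias should still resolve to Paddle's native layer:

    import paddle.distributed.fleet.meta_parallel as mpu
    from paddle import nn

    from paddlenlp.transformers import linear_utils
    from paddlenlp.utils.tools import get_env_device

    if get_env_device() not in ("npu", "xpu"):
        # Default path: the aliases are exactly the native implementations
        assert linear_utils.Linear is nn.Linear
        assert linear_utils.ColumnParallelLinear is mpu.ColumnParallelLinear
        assert linear_utils.RowParallelLinear is mpu.RowParallelLinear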
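
Review note: the except-ImportError branch of LlamaMLP.forward doubles as the reference semantics assumed for paddle_xpu's xpu_swiglu(out, axis=-1, turn=True): split the fused gate/up projection in half and apply SiLU gating. As a standalone sketch:

    import paddle
    import paddle.nn.functional as F

    def swiglu_reference(fused_out):
        # fused_out: output of gate_up_fused_proj, shape [..., 2 * intermediate_size]
        gate_out, up_out = paddle.chunk(fused_out, chunks=2, axis=-1)
        return F.silu(gate_out) * up_out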
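
Review note: the new xpu branch of _prepare_decoder_attention_mask builds the additive mask from tensor operands rather than Python scalars (presumably to sidestep a scalar paddle.where limitation on XPU); the mapping matches the default branch, True -> 0.0 (attend) and False -> finfo(dtype).min (masked):

    import paddle

    dtype = paddle.float32
    mask = paddle.to_tensor([[True, False]])
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    print(paddle.where(mask, x, y))  # [[0., -3.40282347e+38]]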