Merged
34 commits
a5ed9ed
update
Galaxy1458 May 9, 2024
8ebdcfa
Merge branch 'develop' of https://github.com/Galaxy1458/PaddleNLP int…
Galaxy1458 May 9, 2024
bd0aa87
add llama-npu-opt-script
Galaxy1458 May 9, 2024
ce921ab
Merge branch 'PaddlePaddle:develop' into develop
Galaxy1458 May 9, 2024
cc24132
Update dev_opt_lora.sh
Galaxy1458 May 9, 2024
036d03c
Update dev_opt_ppt.sh
Galaxy1458 May 9, 2024
8dd2d02
Update dev_opt_lora.sh
Galaxy1458 May 9, 2024
96e69aa
Update dev_opt_ppt.sh
Galaxy1458 May 9, 2024
a35ba59
Update dev_opt_sft.sh
Galaxy1458 May 9, 2024
68388a7
Rename dev_opt_lora.sh to llama_npu_opt_lora.sh
Galaxy1458 May 11, 2024
fee8f04
Update dev_opt_ppt.sh
Galaxy1458 May 11, 2024
783de3b
Rename dev_opt_ppt.sh to llama_npu_opt_ppt.sh
Galaxy1458 May 11, 2024
10f9415
Update llama_npu_opt_lora.sh
Galaxy1458 May 11, 2024
f3d96e5
Update and rename dev_opt_sft.sh to llama_npu_opt_sft.sh
Galaxy1458 May 11, 2024
e51cc9a
Merge branch 'PaddlePaddle:develop' into develop
Galaxy1458 May 13, 2024
6771aa9
add funsion ops
Galaxy1458 May 13, 2024
61dc79c
add funsion ops
Galaxy1458 May 13, 2024
558200f
add funsion ops
Galaxy1458 May 13, 2024
f387c30
add funsion ops
Galaxy1458 May 13, 2024
a12947b
add funsion ops
Galaxy1458 May 13, 2024
aff105e
add funsion ops
Galaxy1458 May 13, 2024
075c8de
add funsion ops
Galaxy1458 May 13, 2024
15f2fe3
add funsion ops
Galaxy1458 May 13, 2024
2741769
add funsion ops
Galaxy1458 May 13, 2024
12fc048
add funsion ops
Galaxy1458 May 13, 2024
f678361
add funsion ops
Galaxy1458 May 13, 2024
9b2ca6b
add funsion ops
Galaxy1458 May 13, 2024
cac0f8e
add funsion ops
Galaxy1458 May 13, 2024
73866a2
add funsion ops
Galaxy1458 May 13, 2024
d8f1950
add funsion ops
Galaxy1458 May 13, 2024
9a2f1c5
add funsion ops
Galaxy1458 May 13, 2024
df78b71
update
Galaxy1458 May 14, 2024
8c3cd0d
Update fusion_ops.py
Galaxy1458 May 14, 2024
0a6d6b8
update
Galaxy1458 May 14, 2024
189 changes: 189 additions & 0 deletions paddlenlp/transformers/fusion_ops.py
@@ -0,0 +1,189 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
Contributor
Functions as long as fusion_rope and fusion_flash_attention — I would not recommend extracting them this way.

Contributor Author
paddlenlp/transformers/fusion_ops.py has already been moved to paddlenlp/transformers/llama/fusion_ops.py.

    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    else:
        # Paddle > 2.6 (or develop) supports q and k/v with different num_heads in one call;
        # older releases need separate calls for q and k.
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states

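For reference, below is a minimal unfused sketch of what the fused rotary ops above compute, written in the rotate-half convention. It is an illustration only (not part of fusion_ops.py), and the exact pairing convention applied by fused_rotary_position_embedding with use_neox_rotary_style=False, or by the NPU fused_rope custom op, may differ.

import paddle

def rotate_half_reference(x):
    # Split the last dimension in half and rotate the pair: (x1, x2) -> (-x2, x1).
    x1, x2 = paddle.chunk(x, chunks=2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

def apply_rope_reference(q, k, cos, sin):
    # q, k: [batch, seq_len, num_heads, head_dim]; cos/sin broadcastable to that shape.
    q_rot = q * cos + rotate_half_reference(q) * sin
    k_rot = k * cos + rotate_half_reference(k) * sin
    return q_rot, k_rot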

def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
            )
    return rms_norm_fused(hidden_states, weight, variance_epsilon)

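For reference, every fused RMSNorm kernel dispatched above computes the same normalization, y = weight * x / sqrt(mean(x^2) + eps). A minimal unfused sketch (illustration only, dtype handling simplified):

import paddle

def rms_norm_reference(x, weight, eps):
    # x: [..., hidden_size], weight: [hidden_size]
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = x.astype("float32") * paddle.rsqrt(variance + eps)
    return weight * normed.astype(weight.dtype)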

def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
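
A hypothetical caller sketch (not from the PR) showing how fusion_flash_attention might be invoked in the plain, non-sequence-parallel case; SimpleNamespace stands in for the model config and is only consulted when reshard_layer is passed.

from types import SimpleNamespace

def attention_via_fusion_ops(q, k, v, attention_mask=None):
    # q, k, v: [batch, seq_len, num_heads, head_dim] (Paddle flash-attention layout).
    cfg = SimpleNamespace(sep_parallel_degree=1)  # only read when reshard_layer is not None
    # Returns attn_output reshaped to [batch, seq_len, num_heads * head_dim].
    return fusion_flash_attention(
        q,
        cfg,
        k,
        v,
        attention_mask,
        output_attentions=False,
    )
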
143 changes: 27 additions & 116 deletions paddlenlp/transformers/llama/modeling.py
@@ -55,7 +55,6 @@ def swiglu(x, y=None):
    )
except:
    pass
from paddle.utils import try_import

from paddlenlp.transformers.conversion_utils import (
    StateDictNameMapping,
@@ -81,14 +80,16 @@ def swiglu(x, y=None):

try:
if get_env_device() == "npu":
from paddle.base import core

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
Contributor
Check carefully whether any of this code is no longer needed, and make sure to delete it.

    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None
from .. import fusion_ops

rms_norm_fused = fusion_ops.rms_norm_fused

__all__ = [
"LlamaModel",
Expand Down Expand Up @@ -215,67 +216,22 @@ def scaled_dot_product_attention(
    _, kv_seq_len, _, _ = value_states.shape

    if config.use_flash_attention and flash_attention:
        return fusion_ops.fusion_flash_attention(
            query_states,
            config,
            key_states,
            value_states,
            attention_mask,
            output_attentions,
            alibi,
            sequence_parallel,
            reshard_layer,
            npu_is_casual,
        )

        # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
        # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

        version = paddle.version.full_version
        if version != "0.0.0" and version <= "2.5.2":
            if alibi is not None:
                raise ValueError("Flash Attention doesn't support alibi")
            attn_output, attn_weights = flash_attention(
                query_states,
                key_states,
                value_states,
                causal=True,
                return_softmax=output_attentions,
            )
        else:
            if alibi is not None:
                alibi = alibi.reshape([bsz, num_heads, 1, -1])
                attention_mask = attention_mask.cast(alibi.dtype) + alibi
            if get_env_device() == "npu":
                attn_output = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query_states,
                    key_states,
                    value_states,
                    None,
                    attention_mask,
                    0.0,
                    attention_mask is None,
                    True,
                    False,
                    npu_is_casual,
                )[0]
            else:
                attn_output = F.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    is_causal=attention_mask is None,
                )
            attn_weights = None

        if reshard_layer is not None:
            # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
            attn_output = reshard_layer(
                attn_output,
                split_axis=1,
                concat_axis=2,
            )
            # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
            assert (
                config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
            ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
            q_len = q_len // config.sep_parallel_degree
            num_heads = num_heads * config.sep_parallel_degree

        if sequence_parallel:
            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
        else:
            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
        return (attn_output, attn_weights) if output_attentions else attn_output
    else:
        # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
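
The hunk is truncated here; the else: branch above continues into the unfused attention path. A minimal sketch of what that path computes (illustration only, not the PR's exact code):

import math

import paddle
import paddle.nn.functional as F

def eager_attention_reference(query_states, key_states, value_states, attention_mask=None):
    # Inputs: [bs, seqlen, nhead, head_dim] -> transpose to [bs, nhead, seqlen, head_dim].
    q = paddle.transpose(query_states, [0, 2, 1, 3])
    k = paddle.transpose(key_states, [0, 2, 1, 3])
    v = paddle.transpose(value_states, [0, 2, 1, 3])
    scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(q.shape[-1])
    if attention_mask is not None:
        scores = scores + attention_mask.cast(scores.dtype)
    weights = F.softmax(scores, axis=-1)
    out = paddle.matmul(weights, v)
    # Back to [bs, seqlen, nhead * head_dim].
    out = paddle.transpose(out, [0, 2, 1, 3])
    return out.reshape([out.shape[0], out.shape[1], -1])
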
@@ -385,11 +341,6 @@ def _expand_2d_mask(mask, dtype, tgt_length):
    return expanded_mask


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


class LlamaRMSNorm(nn.Layer):
    def __init__(self, config):
        super().__init__()
@@ -407,18 +358,7 @@ def __init__(self, config):

    def forward(self, hidden_states):
        if self.config.use_fused_rms_norm:
            if get_env_device() == "npu":
                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
            elif get_env_device() == "xpu":
                try:
                    import paddle_xpu_nn  # noqa: F821

                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
                except ImportError:
                    raise NotImplementedError(
                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
                    )
            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)

        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
@@ -974,45 +914,16 @@ def forward(
                batch_size, seq_length, _, _ = query_states.shape
                position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
            if self.use_fused_rope:
                assert past_key_value is None, "fuse rotary not support cache kv for now"
                batch_size, seq_length, num_heads, head_dim = query_states.shape
                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                if get_env_device() == "npu":
                    query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
                    key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
                else:
                    # paddle version > 2.6 or develop support q and k/v with different num_heads
                    paddle_version = float(paddle.__version__[:3])
                    if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
                        query_states, _, _ = fused_rotary_position_embedding(
                            query_states,
                            None,
                            None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                        key_states, _, _ = fused_rotary_position_embedding(
                            key_states,
                            None,
                            None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                    else:
                        query_states, key_states, _ = fused_rotary_position_embedding(
                            query_states,
                            key_states,
                            v=None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                query_states, key_states = fusion_ops.fusion_rope(
                    query_states,
                    key_states,
                    value_states,
                    hidden_states,
                    position_ids,
                    past_key_value,
                    self.rotary_emb,
                )

            else:
                if self.config.use_long_sequence_strategies:
                    cos, sin = self.rotary_emb(seq_len=kv_seq_len)