24 changes: 18 additions & 6 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -16,6 +16,7 @@
 from __future__ import annotations

 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -82,6 +83,17 @@ def swiglu(x, y=None):
 ]


+def enable_fuse_ffn_qkv_pass():
+    if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [
+        "True",
+        "true",
+        "1",
+    ]:
+        return True
+    else:
+        return False


 def is_pp_enable():
     mesh = fleet.auto.get_mesh()
     return "pp" in mesh.dim_names
@@ -221,7 +233,7 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.ipp = ipp
         self.config = config

-        if config.fuse_attention_ffn:
+        if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
@@ -251,7 +263,7 @@
             )

     def forward(self, x):
-        if self.fuse_attention_ffn:
+        if self.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
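
Both branches compute the same SwiGLU FFN; the fused path simply evaluates the gate and up projections as a single GEMM and lets `swiglu` split the result in half. A small numerical sketch of that equivalence, assuming the fused weight is the column-wise concatenation [gate | up] (variable names here are illustrative, not the module's own):

    import paddle
    import paddle.nn.functional as F

    hidden, inter = 8, 16
    x = paddle.randn([2, hidden])
    w_gate = paddle.randn([hidden, inter])
    w_up = paddle.randn([hidden, inter])
    w_fused = paddle.concat([w_gate, w_up], axis=-1)  # one [hidden, 2*inter] GEMM

    g, u = paddle.matmul(x, w_fused).chunk(2, axis=-1)
    y_fused = F.silu(g) * u  # mirrors swiglu(...) on a single fused tensor
    y_split = F.silu(paddle.matmul(x, w_gate)) * paddle.matmul(x, w_up)

    assert paddle.allclose(y_fused, y_split, atol=1e-6).item()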
@@ -298,7 +310,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
                 )
                 self.use_fused_rope = False

-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
                 self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
@@ -412,7 +424,7 @@ def forward(
                 [dist.Shard(1), dist.Replicate()],
             )

-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
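
The target shape groups each KV head with its `num_key_value_groups` query heads plus one K and one V slice, which is why the fused projection's output width is `hidden_size + 2 * num_key_value_heads * head_dim`. Checking the arithmetic with illustrative GQA numbers (not taken from the diff):

    hidden_size = 4096
    num_heads, num_key_value_heads, head_dim = 32, 8, 128
    num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

    qkv_width = hidden_size + 2 * num_key_value_heads * head_dim  # 4096 + 2048 = 6144
    per_kv_head = (num_key_value_groups + 2) * head_dim           # (4 + 2) * 128 = 768
    assert num_key_value_heads * per_kv_head == qkv_width         # 8 * 768 == 6144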
@@ -760,7 +772,7 @@ def get_tensor_parallel_split_mappings(num_layers):
             }

             # Column Linear
-            if config.fuse_attention_qkv:
+            if config.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
             else:
                 base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
@@ -769,7 +781,7 @@
                 base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)

-            if config.fuse_attention_ffn:
+            if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
                     fn, is_column=True, is_naive_2fuse=True
                 )
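
When the compiler pass owns the fusion, checkpoints keep the unfused q/k/v and gate/up weights, so the tensor-parallel split mappings must fall through to the per-projection entries. For the manually fused case, `is_naive_2fuse` indicates that the weight is a plain column-wise concatenation of two matrices, so each rank must take matching column shards from both halves. A hedged reimplementation of that idea (the real logic lives in PaddleNLP's split utilities; this sketch is only illustrative):

    import numpy as np

    def split_naive_2fuse(w, tp_degree, rank):
        # w: [in_features, 2 * inter]; gate and up halves concatenated column-wise.
        gate, up = np.split(w, 2, axis=-1)
        # Shard each half column-wise, then rejoin this rank's shards so the
        # rank holds the gate and up columns that belong together.
        return np.concatenate(
            [np.split(gate, tp_degree, axis=-1)[rank],
             np.split(up, tp_degree, axis=-1)[rank]],
            axis=-1,
        )

    w = np.arange(4 * 8, dtype="float32").reshape(4, 8)
    shard0 = split_naive_2fuse(w, tp_degree=2, rank=0)  # gate cols 0:2 + up cols 4:6
    assert shard0.shape == (4, 4)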