diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index b78f58284e24..c194906178d0 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -82,6 +83,17 @@ def swiglu(x, y=None):
 ]
 
 
+def enable_fuse_ffn_qkv_pass():
+    if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [
+        "True",
+        "true",
+        "1",
+    ]:
+        return True
+    else:
+        return False
+
+
 def is_pp_enable():
     mesh = fleet.auto.get_mesh()
     return "pp" in mesh.dim_names
@@ -221,7 +233,7 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.ipp = ipp
         self.config = config
 
-        if config.fuse_attention_ffn:
+        if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
@@ -251,7 +263,7 @@ def __init__(self, config, ipp: Optional[int] = None):
             )
 
     def forward(self, x):
-        if self.fuse_attention_ffn:
+        if self.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -298,7 +310,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
                 )
                 self.use_fused_rope = False
 
-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
                 self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
@@ -412,7 +424,7 @@ def forward(
                 [dist.Shard(1), dist.Replicate()],
             )
 
-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
@@ -760,7 +772,7 @@ def get_tensor_parallel_split_mappings(num_layers):
             }
 
             # Column Linear
-            if config.fuse_attention_qkv:
+            if config.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
             else:
                 base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
@@ -769,7 +781,7 @@ def get_tensor_parallel_split_mappings(num_layers):
                     base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                     base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
 
-            if config.fuse_attention_ffn:
+            if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
                     fn, is_column=True, is_naive_2fuse=True
                 )
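
For context, a minimal sketch of the gating behavior the new helper introduces: when the FLAGS_enable_fused_ffn_qkv_pass environment variable is set to a truthy string, the hand-written fused qkv/ffn projections above are bypassed, presumably so a graph-level fusion pass can do the fusing instead. The helper name and flag values mirror the diff; the toggling shown after the function is purely illustrative and not part of the patch.

import os


def enable_fuse_ffn_qkv_pass():
    # Mirrors the helper added in the diff: only the strings "True", "true",
    # and "1" count as enabled; unset or any other value means disabled.
    return os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in ["True", "true", "1"]


# Illustrative toggling (hypothetical usage, not part of the patch):
os.environ["FLAGS_enable_fused_ffn_qkv_pass"] = "1"
assert enable_fuse_ffn_qkv_pass()      # manual fusion branches are skipped

os.environ["FLAGS_enable_fused_ffn_qkv_pass"] = "0"
assert not enable_fuse_ffn_qkv_pass()  # manual fusion branches stay active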