5 changes: 3 additions & 2 deletions paddlenlp/transformers/llama/modeling.py
@@ -96,7 +96,7 @@ def swiglu(x, y=None):
"LlamaForCausalLM",
"LlamaPretrainingCriterion",
]

global npu_is_casual
npu_is_casual = False

def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape

global npu_is_casual
if config.use_flash_attention and flash_attention:
# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
# Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,6 +1613,7 @@ def forward(
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
global npu_is_casual
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
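The three hunks above share one mechanism: a module-level boolean `npu_is_casual` is defined once near the top of `modeling.py`, rebound inside `LlamaModel.forward` via a `global` statement, and read inside `scaled_dot_product_attention` so the attention call can react to what the mask preparation decided. Below is a minimal, self-contained sketch of that pattern; the function names, arguments, and the NPU device check are assumptions for illustration only, not the actual PaddleNLP implementation.

```python
# Sketch of the shared module-level flag pattern from this diff.
# `prepare_masks` and `attention` are hypothetical stand-ins for
# LlamaModel.forward and scaled_dot_product_attention respectively.

npu_is_casual = False  # module-level default, mirrors the flag added in this PR


def prepare_masks(mask_is_causal: bool, device: str) -> None:
    """Stand-in for LlamaModel.forward: record whether the attention mask
    is causal so the attention call can pick a causal kernel later."""
    global npu_is_casual  # rebind the module-level name, not a local copy
    if device == "npu":
        npu_is_casual = mask_is_causal


def attention() -> bool:
    """Stand-in for scaled_dot_product_attention: read the flag to decide
    whether a causal fast path can be used."""
    global npu_is_casual  # only required for writes; kept to mirror the diff
    use_causal_kernel = npu_is_casual
    # ... dispatch to a causal flash-attention kernel when this is True ...
    return use_causal_kernel


prepare_masks(mask_is_causal=True, device="npu")
assert attention() is True
```

A module-level flag like this avoids threading an extra argument through every call site, at the cost of hidden shared state; the `global` statement in each writer is what makes the assignment update the module attribute instead of creating a function-local variable.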