diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 430e24f2c8b7..28dc897f40b2 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1567,7 +1567,7 @@ def forward(
         if is_casual and alibi is None:
             attention_mask = None
         else:
-            attention_mask = attention_mask.astype("bool")
+            attention_mask = None if attention_mask is None else attention_mask.astype("bool")
         hidden_states = inputs_embeds
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
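
The change guards the dtype cast: when the model is called without an explicit `attention_mask` (e.g. full-sequence causal attention with no padding), the variable is `None`, and the old unconditional `attention_mask.astype("bool")` would raise an `AttributeError`. Below is a minimal sketch, not part of the PR, that isolates the pattern; the helper name `normalize_mask` is hypothetical and only illustrates the before/after behavior, assuming a Paddle environment.

```python
# Minimal sketch (hypothetical helper, not from the PR) showing why the
# None guard is needed before casting the mask to bool.
import paddle


def normalize_mask(attention_mask):
    # Old code did `attention_mask.astype("bool")` unconditionally,
    # which raises AttributeError when attention_mask is None.
    # The fixed pattern preserves None and only casts real tensors.
    return None if attention_mask is None else attention_mask.astype("bool")


print(normalize_mask(None))                       # -> None, no crash
print(normalize_mask(paddle.ones([1, 1, 4, 4])))  # -> bool Tensor
```

Using a conditional expression keeps the fix to a single line and avoids restructuring the surrounding `if is_casual and alibi is None` branch, which already sets the mask to `None` on the causal fast path.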