Commit f2d7eaa

jiminharegisss authored and committed

Fix OOM error for code llama (#1437)

1 parent 720ccc2 commit f2d7eaa

1 file changed: 1 addition & 1 deletion

optimum/habana/transformers/models/llama/modeling_llama.py
@@ -119,7 +119,7 @@ def __init__(
         self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         # Truncate the cached max sequence length to 8k to limit cached register buffer size
-        if config.max_position_embeddings >= 8192:
+        if config.max_position_embeddings > 8192 and self.rope_type == "llama3":
             self.max_seq_len_cached = 8192
         self.original_max_seq_len = config.max_position_embeddings
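To illustrate the effect of the changed condition: before the patch, the cached sequence length was truncated to 8192 for any model whose context reaches 8192; after it, truncation applies only when the rope scaling type is "llama3". The sketch below is a simplified, standalone stand-in for the logic in `LlamaRotaryEmbedding.__init__`, not the actual class; `Config` and `cached_max_seq_len` are hypothetical names for illustration.

```python
# Simplified stand-in for the truncation logic patched in this commit.
# Not the real optimum-habana class; names here are illustrative only.

class Config:
    def __init__(self, max_position_embeddings):
        self.max_position_embeddings = max_position_embeddings


def cached_max_seq_len(config, rope_type="default"):
    # Start from the model's full context length.
    max_seq_len_cached = config.max_position_embeddings
    # Truncate the cached max sequence length to 8k to limit cached
    # register buffer size -- after the patch, only for "llama3" rope scaling.
    if config.max_position_embeddings > 8192 and rope_type == "llama3":
        max_seq_len_cached = 8192
    return max_seq_len_cached


# A 16k-context model with the default rope type keeps its full cache,
# whereas the old ">= 8192" check would have truncated it unconditionally.
print(cached_max_seq_len(Config(16384), rope_type="default"))  # 16384
print(cached_max_seq_len(Config(131072), rope_type="llama3"))  # 8192
```

With the old condition, `config.max_position_embeddings >= 8192` truncated the cache for every long-context model regardless of rope type; the new guard restricts that to the "llama3" scaling case.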
