Skip to content

Commit 9216159

Browse files
authored
Add bias to gptj (#1363)
1 parent fc399fa commit 9216159

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

optimum/habana/transformers/models/gptj/modeling_gptj.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ def __init__(self, config: GPTJConfig, layer_idx=None):
7373
super().__init__(config)
7474
self.config = config
7575

76+
max_positions = config.max_position_embeddings
77+
self.register_buffer(
78+
"bias",
79+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
80+
1, 1, max_positions, max_positions
81+
),
82+
persistent=False,
83+
)
7684
self.matmul_qk = Matmul()
7785
self.matmul_av = Matmul()
7886
self.k_cache = KVCache()

0 commit comments

Comments
 (0)