
Commit 144efb4

Author: Kye (committed)
forward pass integration with flash head
1 parent 8b7850a commit 144efb4

1 file changed

Lines changed: 6 additions & 7 deletions


LongNet/attention.py

@@ -67,7 +67,6 @@ def get_mask(self, i, j):
         return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 2)
 
 
-    # Forward function
     def forward(self, x):
         # Get batch size, sequence length and model dimension
         batch_size, seq_len, _ = x.shape
@@ -86,17 +85,17 @@ def forward(self, x):
 
             # Apply offset and segment for this head
             x_ = x[:, offset::self.dilation_rate, :]
-            x_ = x_.contiguous().view(batch_size, -1, self.segment_size, self.d_model)
-
+            x_ = x_.contiguous().view(batch_size, 1, -1, self.segment_size, self.d_model)  # Add an extra dimension for the number of heads
 
-
+            # Process each segment separately
             elements_attns = []
-            for idx in range(x_.shape[1]):
-                element = x_[:, idx, :, :].to(dtype)
+            for idx in range(x_.shape[2]):
+                element = x_[:, :, idx, :, :].to(dtype)
                 element_attn = attention(element, element, element)
                 elements_attns.append(element_attn)
 
-            attn_output = torch.cat(elements_attns, dim=1)
+            attn_output = torch.cat(elements_attns, dim=2)
+
 
 
             #option2
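For context on what the new shape handling does: the reshaped tensor now carries an explicit head dimension, segments sit on dim 2, attention is applied within each segment, and the per-segment outputs are concatenated back along that same axis. Below is a minimal, self-contained sketch of this pattern, with torch.nn.functional.scaled_dot_product_attention standing in for the module's flash attention head, and with placeholder values for dilation_rate, segment_size, offset, and d_model (none of these values come from the repository).

import torch
import torch.nn.functional as F

# Placeholder hyperparameters standing in for the module's attributes.
dilation_rate = 2
segment_size = 4
d_model = 64
offset = 0
dtype = torch.float32  # the real module casts to the flash head's dtype (e.g. fp16 on GPU)

batch_size, seq_len = 2, 64
x = torch.randn(batch_size, seq_len, d_model)

# Apply the dilation offset for this head, then add an explicit head dimension:
# shape becomes (batch, heads=1, n_segments, segment_size, d_model).
x_ = x[:, offset::dilation_rate, :]
x_ = x_.contiguous().view(batch_size, 1, -1, segment_size, d_model)

# Attend within each segment separately; segments sit on dim 2.
elements_attns = []
for idx in range(x_.shape[2]):
    element = x_[:, :, idx, :, :].to(dtype)  # (batch, 1, segment_size, d_model)
    element_attn = F.scaled_dot_product_attention(element, element, element)
    elements_attns.append(element_attn)

# Integer indexing drops the segment axis, so concatenating on dim 2
# stitches the segments back together along the sequence dimension.
attn_output = torch.cat(elements_attns, dim=2)
print(attn_output.shape)  # torch.Size([2, 1, 32, 64])

The commit's switch of both the loop axis and the cat axis from 1 to 2 keeps the segment bookkeeping consistent with the newly added head dimension.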
