Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit e25bcb3

Browse files
author
Ubuntu
committed
add comment
1 parent a093709 commit e25bcb3

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/gluonnlp/adapters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,14 @@ def forward(self, query, key, value):
5858

5959
key = self.key_proj(key).transpose((0, 1, 3, 2))
6060
value = self.value_proj(value)
61+
62+
#previous implementaion
6163
# query = npx.reshape(self.query_proj(query), (-2, -2, 1, -1))
64+
# scores = np.squeeze(npx.batch_dot(query, key), axis=2)
65+
# with einsum
6266
query = self.query_proj(query)
63-
#scores = np.squeeze(npx.batch_dot(query, key), axis=2)
6467
scores = np.einsum('blu, blun -> bln', query, key)
68+
6569
attn_weights = npx.softmax(scores, axis=-1)
6670
#attn batch size lenght, num
6771
#value bs l, num, u

0 commit comments

Comments
 (0)