Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions python/mlx/nn/layers/positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,47 +115,59 @@ def __call__(self, x):


class ALiBi(Module):
_alibi_mask_key = None
_alibi_mask = None

@classmethod
@staticmethod
def create_alibi_matrix(
cls,
q_sequence_length: int,
k_sequence_length: int,
num_heads: int,
offset: int,
dtype=mx.float32,
):
if (
q_sequence_length,
k_sequence_length,
num_heads,
offset,
dtype,
) != cls._alibi_mask_key:
x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
cls._alibi_mask_key = (
q_sequence_length,
k_sequence_length,
num_heads,
offset,
dtype,
)
cls._alibi_mask = alibi_mask
x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads, dtype=dtype)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
return alibi_mask

@staticmethod
def alibi_get_slopes(num_heads: int):
Comment thread
vovw marked this conversation as resolved.
Outdated
"""Get the slopes for different attention heads defined in ALiBi paper.

return cls._alibi_mask
This is a direct copy from ALiBi codebase.
Ref:
https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca/fairseq/models/transformer.py#L742-L752

Args:
num_heads: An integer for the number of attention heads.

Returns:
A tensor of slopes with shape of [num_heads]. Each value represents
a slope for one attention head.
"""

def get_slopes_power_of_2(n: int):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]

if math.log2(num_heads).is_integer():
return get_slopes_power_of_2(num_heads)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
return (
get_slopes_power_of_2(closest_power_of_2)
+ ALiBi.alibi_get_slopes(2 * closest_power_of_2)[0::2][
: num_heads - closest_power_of_2
]
)

@staticmethod
def create_alibi_slope(num_heads):
x = (2**8) ** (1 / num_heads)
out = mx.power(x, -mx.arange(1, num_heads + 1))
def create_alibi_slope(num_heads, dtype):
slopes = ALiBi.alibi_get_slopes(num_heads)
out = mx.array(slopes, dtype=dtype)
return mx.expand_dims(out, axis=(-1, -2))

def __call__(self, attention_scores, offset=0, mask=None):
Expand Down