Implement bi-directionality #52

Open
yair-schiff wants to merge 6 commits into state-spaces:main from yair-schiff:bidirectional

Conversation

@yair-schiff
Contributor

@yair-schiff yair-schiff commented Dec 13, 2023

Edit:

  • Implement bi-directionality by applying the Mamba module twice: (1) to the forward sequence and (2) to the backward (flipped) sequence.
  • Implement 3 strategies for combining the forward / backward Mamba hidden states:
    1. add: Add the states.
    2. concat: Concatenate the states. This doubles the hidden dimension, d_model, which also prevents weight tying between the embedding and lm_head weights.
    3. ew_multiply: Perform element-wise multiplication between the states.
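The three strategies can be sketched roughly as follows. This is a minimal stand-in, not the PR's actual code: `fwd` and `bwd` are placeholders for the two directional blocks (here plain `nn.Linear` layers instead of Mamba blocks, and whether the two directions share weights is an assumption left open):

```python
import torch
import torch.nn as nn

class BiDirectional(nn.Module):
    """Run a sequence module forward and on the reversed sequence, then combine."""

    def __init__(self, fwd: nn.Module, bwd: nn.Module, strategy: str = "add"):
        super().__init__()
        assert strategy in ("add", "concat", "ew_multiply")
        self.fwd, self.bwd, self.strategy = fwd, bwd, strategy

    def forward(self, x):                     # x: (batch, length, d_model)
        h_f = self.fwd(x)
        h_b = self.bwd(x.flip(1)).flip(1)     # flip along the sequence dimension
        if self.strategy == "add":
            return h_f + h_b                  # same d_model
        if self.strategy == "ew_multiply":
            return h_f * h_b                  # same d_model
        return torch.cat([h_f, h_b], dim=-1)  # concat doubles d_model

dim = 16
block = BiDirectional(nn.Linear(dim, dim), nn.Linear(dim, dim), strategy="concat")
y = block(torch.randn(2, 8, dim))
print(y.shape)  # torch.Size([2, 8, 32])
```

With concat, downstream layers must account for the doubled d_model, which is why weight tying between the embedding and lm_head breaks for that strategy.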


@Skylion007 Skylion007 left a comment


Left some nits

@sentialx

What if the sequences have padding? E.g., the input is
[1 2 3 0 0 0]
so the flipped input would be
[0 0 0 3 2 1].
Shouldn't it be
[3 2 1 0 0 0]?
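The two behaviours can be compared directly. The length-aware flip below is an illustration of the fix being asked for (the index construction is my own, not code from the PR):

```python
import torch

x = torch.tensor([[1, 2, 3, 0, 0, 0]])
lengths = torch.tensor([3])           # number of valid (non-padding) tokens per row

naive = x.flip(1)                     # [[0, 0, 0, 3, 2, 1]] -- padding moves to the front

# Length-aware flip: reverse only the first `lengths` tokens, keep padding trailing.
idx = torch.arange(x.size(1))
flip_idx = torch.where(idx < lengths[:, None], lengths[:, None] - 1 - idx, idx)
aware = torch.gather(x, 1, flip_idx)  # [[3, 2, 1, 0, 0, 0]]
```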

@yair-schiff
Contributor Author

@sentialx , agreed. That's a good catch.

@jimmieliu

How does the speed compare to uni-directional?

@yair-schiff
Contributor Author

yair-schiff commented Jan 3, 2024

How does the speed compare to uni-directional?

@jimmieliu, it's about 2x the uni-directional runtime, since the Mamba block runs once per direction.

@albertfgu albertfgu mentioned this pull request Jan 11, 2024
@pengzhangzhi

@yair-schiff I am just curious, was this padding issue ever solved?

What if the sequences have paddings? E.g. Input is [1 2 3 0 0 0] So flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?

@pengzhangzhi

I came up with a solution to the padding issue. Say a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,1,2,3], pass it to the network, and flip the output back. Because we apply two flips, the result lines up with the original token order:

given: x
out = x + f(x.flip(1)).flip(1)  # flip along the sequence dimension
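A tiny runnable version of this double-flip trick, using `cumsum` as a stand-in for the causal sequence module `f` (an assumption for illustration; in the PR, `f` would be the backward Mamba block):

```python
import torch

x = torch.tensor([[1., 2., 3., 0., 0.]])  # trailing zeros are padding

f = lambda t: t.cumsum(dim=1)             # stand-in causal op (assumption)

out = x + f(x.flip(1)).flip(1)
# f(x.flip(1)) scans [0, 0, 3, 2, 1] -> [0, 0, 3, 5, 6];
# flipping back gives [6, 5, 3, 0, 0], so out = [7., 7., 6., 0., 0.]
```

The leading padding contributes nothing to the scan here because the accumulator starts at zero.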

@xuanwuji

I came up with a solution to the padding issue. Say a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,1,2,3], pass it to the network and flip it back. Therefore, the flipped tensor information matches the original tensor order as we apply double flips.

given: x
out = x + f(x.flip()).flip()

Hi, your approach is clever! But I have a question: if you flip the input to [0,0,1,2,3], does the leading padding affect how the sequence's hidden features are learned? I.e., does it produce a different result (a worse representation of the sequence) than the input [3,2,1,0,0]?
I don't know enough about this; could you possibly give me some guidance? It would help me a lot. Thank you very much!

@Museum7432

Museum7432 commented Jul 14, 2024

@xuanwuji well, you can remove the leading paddings by shifting each row of x before flipping x. As for their effect: since the hidden state is initialized with 0, it should still be filled with 0 after scanning through the paddings, so those paddings shouldn't have any effect on the result. However, you can use the following function just to be sure.

import torch

def flip_padded_hidden_states(hidden_states, seq_lens):
    # hidden_states: (batch, seq_len, hidden_dim); seq_lens: (batch,) valid lengths.
    # Reverses each row's valid prefix while keeping the padding trailing;
    # output entries at padding positions are unspecified.
    batch_size, seq_len, hidden_dim = hidden_states.shape

    # Flat index of every position in the (batch, seq_len) grid.
    indices = torch.arange(batch_size * seq_len, device=hidden_states.device).reshape(
        batch_size, seq_len
    )

    # Shift each row left by its padding count (modulo the flat size), then flip,
    # so the reversed valid prefix lands at the front of the row.
    indices_offset = seq_len - seq_lens

    indices = (indices - indices_offset.unsqueeze(1)) % (seq_len * batch_size)

    indices = indices.flip(1)

    return hidden_states.reshape(batch_size * seq_len, hidden_dim)[indices]

To check the effect of paddings:

import torch
from torch.nn import functional as F

from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

x = torch.randn(batch, length, dim).to("cuda")
padded_x = F.pad(x, (0, 0, 4, 0))  # prepend 4 zero (padding) steps along the length dim

y = model(x)
padded_y = model(padded_x)

unpadded_y = padded_y[:, 4:]  # drop the outputs at the padding positions

print(f'Output max diff: {(unpadded_y - y).abs().max().item()}')
print(f'Output mean diff: {(unpadded_y - y).abs().mean().item()}')

However, these errors do stack after multiple layers, so you should use the flip_padded_hidden_states function just to be certain.
