How do I input a four-dimensional tensor into Mamba2? #635

Open
NicoleDyson opened this issue Dec 4, 2024 · 1 comment

Comments

@NicoleDyson
For example, I have a four-dimensional tensor with shape [2, 56, 56, 96], where the four dimensions correspond to batch_size, height, width, and channels. If I directly set the parameters as follows:

d_model = width * height  # This is incorrect
d_state = 64,
d_conv = 4,
expand = 2,

and create the Mamba2 module, it results in an error:

File "/root/anaconda3/envs/test/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 157, in forward
batch, seqlen, dim = u.shape
ValueError: too many values to unpack (expected 3)
@AlwaysFHao

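Mamba2 expects input of shape [batch, seq_len, dim], so the spatial dimensions have to be flattened into a single sequence dimension first. You can follow the Vision Transformer approach: since your tensor is channels-last, permute it to [batch, c, h, w], apply a Conv2d with kernel_size = stride = patch_size to get [batch, dim, h / patch_size, w / patch_size], then flatten the two spatial dimensions and transpose to [batch, (h / patch_size) * (w / patch_size), dim]. For a square input this gives seq_len = (h / patch_size)^2, and d_model should be set to dim, not width * height.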
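To make the shapes concrete, here is a minimal sketch of the reshaping in plain PyTorch. The values `patch_size = 4` and `embed_dim = 256` are assumptions picked for illustration and do not come from this thread. The Mamba2 call is left commented out because it needs `mamba_ssm` and a CUDA device. The first option (one token per pixel) is a simpler alternative to the patch embedding, shown only for comparison.

```python
import torch
import torch.nn as nn

# Tensor from the question: [batch, height, width, channels] (channels-last).
x = torch.randn(2, 56, 56, 96)
batch, height, width, channels = x.shape

# Option 1: treat every pixel as one token, so d_model would be the channel count.
tokens = x.reshape(batch, height * width, channels)  # [2, 3136, 96]

# Option 2: ViT-style patch embedding with a strided Conv2d.
patch_size = 4   # hypothetical value, chosen for illustration
embed_dim = 256  # hypothetical value, chosen for illustration
patch_embed = nn.Conv2d(channels, embed_dim,
                        kernel_size=patch_size, stride=patch_size)

x_chw = x.permute(0, 3, 1, 2)                   # [2, 96, 56, 56], channels-first for Conv2d
feat = patch_embed(x_chw)                       # [2, 256, 14, 14]
patch_tokens = feat.flatten(2).transpose(1, 2)  # [2, 196, 256]

print(tokens.shape, patch_tokens.shape)

# With mamba_ssm installed and a GPU available, the sequence could then be
# passed to the layer, with d_model equal to the last dimension:
# from mamba_ssm import Mamba2
# layer = Mamba2(d_model=embed_dim, d_state=64, d_conv=4, expand=2).to("cuda")
# y = layer(patch_tokens.to("cuda"))
```

In both cases the last dimension of the tensor is what goes into `d_model`, and the sequence length is whatever the flattened spatial grid comes out to (here 56 * 56 = 3136 or 14 * 14 = 196).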
