Mamba2 constantly gives error regarding causal_conv1d: strides (x.stride(0) && x.stride(2)) to be multiples of 8 #643

Open
roman8ivanov opened this issue Dec 9, 2024 · 1 comment

@roman8ivanov
Hi,

I am experimenting with Mamba2. I have tried several values for the input dimension, but I keep getting the error below.

import torch
from mamba_ssm import Mamba2

batch, length, dim = 100, 64, 64
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=128,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=dim,  # Head dimension (set equal to d_model here)
).to("cuda")

y = model(x)

File C:\ProgramData\anaconda3\Lib\site-packages\mamba_ssm\ops\triton\ssd_combined.py:930 in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File C:\ProgramData\anaconda3\Lib\site-packages\torch\autograd\function.py:575 in apply
return super().apply(*args, **kwargs) # type: ignore[misc]

File C:\ProgramData\anaconda3\Lib\site-packages\torch\amp\autocast_mode.py:465 in decorate_fwd
return fwd(*args, **kwargs)

File C:\ProgramData\anaconda3\Lib\site-packages\mamba_ssm\ops\triton\ssd_combined.py:779 in forward
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),

RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) && x.stride(2)) to be multiples of 8
Exception raised from causal_conv1d_fwd at C:\Users\roman\AppData\Local\Temp\pip-req-build-ph0joy5p\csrc\causal_conv1d.cpp:162 (most recent call first):
00007FFDDA6E83C900007FFDDA6E8320 c10.dll!c10::Error::Error [ @ ]
00007FFDDA6E6BEA00007FFDDA6E6B90 c10.dll!c10::detail::torchCheckFail [ @ ]
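One plausible reading of this error, not confirmed anywhere in this thread: assuming the layout in mamba_ssm's Mamba2 module, the in_proj output packs z, xBC and dt side by side with width 2 * d_inner + 2 * ngroups * d_state + nheads, and xBC is handed to causal_conv1d as a strided view of that tensor. Its stride along the sequence axis is then the full projection width, which the channel-last kernel requires to be a multiple of 8. The helper below is hypothetical (pure Python, no GPU needed) and simply evaluates that assumed formula for the parameters used above.

```python
# Hypothetical helper. ASSUMPTION: mirrors the in_proj width formula of
# mamba_ssm's Mamba2 module (d_inner = expand * d_model, nheads = d_inner // headdim).
def mamba2_in_proj_width(d_model, expand=2, headdim=64, d_state=128, ngroups=1):
    d_inner = expand * d_model
    if d_inner % headdim != 0:
        raise ValueError("expand * d_model must be divisible by headdim")
    nheads = d_inner // headdim
    # z, x, B, C and dt are assumed to be packed in one projection output
    return 2 * d_inner + 2 * ngroups * d_state + nheads


def conv_stride_ok(d_model, **kwargs):
    # xBC is assumed to be a slice of a (batch, seqlen, width) tensor, so its
    # stride along the sequence axis equals the full width; the kernel check
    # in the traceback wants that stride to be a multiple of 8.
    return mamba2_in_proj_width(d_model, **kwargs) % 8 == 0


if __name__ == "__main__":
    # Parameters from the report: d_model=64, expand=2, headdim=64, d_state=128
    print(mamba2_in_proj_width(64, headdim=64))  # 514 (2 heads)
    print(conv_stride_ok(64, headdim=64))        # False
    print(conv_stride_ok(64, headdim=16))        # True (8 heads, width 520)
```

If this reading is right, picking headdim so that expand * d_model / headdim is a multiple of 8 (with d_model and d_state already multiples of 8) would satisfy this particular stride check; whether it resolves the issue end to end is untested here.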

yunqingliu1996 commented Dec 12, 2024

This has haunted me for two days; I am hitting the same error.