Thanks for your work!
I am using Mamba2 to build a U-Net for a vision task. The U-Net consists of 14 layers: 5 encoder layers, 5 decoder layers, and 4 additional layers after the decoder, where each layer is a Mamba2 block. Each encoder's output is passed to the corresponding decoder layer as a residual (skip) connection via .clone(). To debug, I printed the version counters of the two tensors conv1d_weight and conv1d_bias at the end of the forward and backward passes in mamba/mamba_ssm/ops/triton/ssd_combined.MambaSplitConv1dScanCombinedFn.
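For context, the skip connections are wired roughly like this (a minimal sketch with placeholder modules and hypothetical names, not my actual model code; plain Linear layers stand in for the Mamba2 blocks):

```python
import torch
import torch.nn as nn

class TinyMambaUNet(nn.Module):
    """Minimal sketch of the encoder/decoder wiring; not the real model."""
    def __init__(self, dim=256, depth=5, extra=4):
        super().__init__()
        self.encoders = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.decoders = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.extras = nn.ModuleList(nn.Linear(dim, dim) for _ in range(extra))

    def forward(self, x):
        skips = []
        for enc in self.encoders:
            x = enc(x)
            skips.append(x.clone())      # encoder output cloned for the skip connection
        for dec, skip in zip(self.decoders, reversed(skips)):
            x = dec(x) + skip            # skip added back in the matching decoder layer
        for extra in self.extras:
            x = extra(x)                 # 4 additional layers after the decoder
        return x

out = TinyMambaUNet()(torch.randn(1, 2048, 256))
print(out.shape)                         # torch.Size([1, 2048, 256])
```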
In the log below, each occurrence of "x after patch embed torch.Size([1, 256, 8, 16, 16])" marks the start of a new training batch.
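The repeated "end of forward/backward ... version" lines in the log come from debug prints along these lines (a sketch of what I added; the exact insertion points inside MambaSplitConv1dScanCombinedFn are approximate, and _version is PyTorch's internal autograd version counter):

```python
import torch

# Stand-ins for the tensors inspected inside MambaSplitConv1dScanCombinedFn;
# the real conv1d_weight / conv1d_bias come from the Mamba2 module.
conv1d_weight = torch.nn.Parameter(torch.randn(768, 4))
conv1d_bias = torch.nn.Parameter(torch.randn(768))

# Prints added at the end of forward() (and analogously at the end of backward()).
print(f"end of forward conv1d_weight version: {conv1d_weight._version}")
print(f"end of forward conv1d_bias version: {conv1d_bias._version}")
```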
However, when I train the model, I hit the error below during the backward pass of the second batch. Judging from the tensor shape reported in the error ([768, 4]), the problematic tensor is conv1d_weight or conv1d_bias. The output and error messages are as follows:
(py310) root@autodl-container-f1ea4781b3-ca3a2739:/autodl-fs/data/20241122v1# python -m videomamba2.trainvm2
/root/miniconda3/envs/py310/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
/root/miniconda3/envs/py310/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
torch.Size([4, 64, 128, 128])
original shape: torch.Size([696, 256, 256])
Use checkpoint: False
Checkpoint number: 0
-----Epoch 1 training started-----
x after patch embed torch.Size([1, 256, 8, 16, 16])
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
end of forward conv1d_weight version: 1
end of forward conv1d_bias version: 1
up sample x shape: torch.Size([1, 256, 8, 256, 256])
x before out proj torch.Size([1, 524289, 256])
x in out proj before reshape torch.Size([1, 524288, 256])
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
end of backward conv1d_weight version: 1
end of backward conv1d_bias version: 1
x after patch embed torch.Size([1, 256, 8, 16, 16])
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
end of forward conv1d_weight version: 2
end of forward conv1d_bias version: 2
up sample x shape: torch.Size([1, 256, 8, 256, 256])
x before out proj torch.Size([1, 524289, 256])
x in out proj before reshape torch.Size([1, 524288, 256])
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
end of backward conv1d_weight version: 2
end of backward conv1d_bias version: 2
/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in MambaSplitConv1dScanCombinedFnBackward. Traceback of forward call that caused the error:
File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/autodl-fs/data/20241122v1/videomamba2/trainvm2.py", line 113, in <module>
outputs = myModule(imgs.to(torch.float32))
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/autodl-fs/data/20241122v1/videomamba2/models/videomamba_pretrain.py", line 742, in forward
hidden_states, residual = layer(
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/autodl-fs/data/20241122v1/videomamba2/models/videomamba_pretrain.py", line 92, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/modules/mamba2.py", line 243, in forward
out = mamba_split_conv1d_scan_combined(
File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 959, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/autodl-fs/data/20241122v1/videomamba2/trainvm2.py", line 119, in <module>
loss.backward(retain_graph=True)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
return bwd(*args, **kwargs)
File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 839, in backward
zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 4]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
The error occurs at line 837 of mamba/mamba_ssm/ops/triton/ssd_combined.py.
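For reference, the version-counter mechanism behind this message can be reproduced with a tiny self-contained example (unrelated to my model): a tensor that autograd saved for the backward pass gets modified in place afterwards, so its version counter no longer matches the saved one. The log above shows conv1d_weight/conv1d_bias at version 2 during the second batch, which lines up with the "is at version 2; expected version 1" in the error.

```python
import torch

# Tiny repro of the same class of error (not my model code).
w = torch.randn(768, 4, requires_grad=True)   # stand-in for conv1d_weight

loss = (w ** 2).sum()     # autograd saves w to compute the gradient of pow
with torch.no_grad():
    w.add_(1.0)           # in-place update bumps w's version counter (0 -> 1)

loss.backward()           # RuntimeError: one of the variables needed for gradient
                          # computation has been modified by an inplace operation
```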
Does anyone know what to do?
config:
mamba 2.2.0
causal_conv1d 1.4.0
torch 2.1.1
cuda 11.8
python 3.10