error in backward of inplace operation and tensor version #642

Open

kanyu369 opened this issue Dec 9, 2024 · 0 comments

kanyu369 commented Dec 9, 2024

Thanks for your work!
I am using Mamba2 to build a U-Net for a vision task. My U-Net consists of 14 layers: 5 encoder layers, 5 decoder layers, and 4 additional layers after the decoder. Each layer is a Mamba2 block, and each encoder's output is passed to the corresponding decoder layer as a residual connection using the .clone() method. To debug, I printed the version numbers of the two tensors conv1d_weight and conv1d_bias at the end of the forward and backward passes in mamba/mamba_ssm/ops/triton/ssd_combined.MambaSplitConv1dScanCombinedFn.
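(Roughly what those debug prints look like; this is a paraphrased sketch of my local edit, and tensor._version is PyTorch's autograd version counter, the same counter the saved-tensor check compares during backward:)

# added at the end of MambaSplitConv1dScanCombinedFn.forward
# (with the analogous "end of backward" lines at the end of .backward)
print("end of forward conv1d_weight version:", conv1d_weight._version)
print("end of forward conv1d_bias version:  ", conv1d_bias._version)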
Each occurrence of "x after patch embed torch.Size([1, 256, 8, 16, 16])" in the log below marks the start of a new batch.
But when I try to train the model, I encounter the error below during backward propagation; it occurs on the second batch. Judging from the tensor shape reported in the message, the problematic tensor is conv1d_weight or conv1d_bias. The output and error messages are as follows:

(py310) root@autodl-container-f1ea4781b3-ca3a2739:/autodl-fs/data/20241122v1# python -m videomamba2.trainvm2
/root/miniconda3/envs/py310/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
/root/miniconda3/envs/py310/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
torch.Size([4, 64, 128, 128])
original shape: torch.Size([696, 256, 256])
Use checkpoint: False
Checkpoint number: 0
-----Training round 1 begins-----
x after patch embed torch.Size([1, 256, 8, 16, 16])
end of forward conv1d_weight version: 1
end of forward conv1d_bias version:   1
[this pair of lines is printed 14 times in total, once per Mamba2 layer]

up sample x shape: torch.Size([1, 256, 8, 256, 256])
x before out proj torch.Size([1, 524289, 256])
x in out proj before reshape torch.Size([1, 524288, 256])
end of backward conv1d_weight version: 1
end of backward conv1d_bias version:   1
[this pair of lines is printed 14 times in total, once per Mamba2 layer]

x after patch embed torch.Size([1, 256, 8, 16, 16])
end of forward conv1d_weight version: 2
end of forward conv1d_bias version:   2
[this pair of lines is printed 14 times in total, once per Mamba2 layer]

up sample x shape: torch.Size([1, 256, 8, 256, 256])
x before out proj torch.Size([1, 524289, 256])
x in out proj before reshape torch.Size([1, 524288, 256])
end of backward conv1d_weight version: 2
end of backward conv1d_bias version:   2
[this pair of lines is printed 14 times in total, once per Mamba2 layer]
/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in MambaSplitConv1dScanCombinedFnBackward. Traceback of forward call that caused the error:
  File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/autodl-fs/data/20241122v1/videomamba2/trainvm2.py", line 113, in <module>
    outputs = myModule(imgs.to(torch.float32))
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/autodl-fs/data/20241122v1/videomamba2/models/videomamba_pretrain.py", line 742, in forward
    hidden_states, residual = layer(
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/autodl-fs/data/20241122v1/videomamba2/models/videomamba_pretrain.py", line 92, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/modules/mamba2.py", line 243, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 959, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/autodl-fs/data/20241122v1/videomamba2/trainvm2.py", line 119, in <module>
    loss.backward(retain_graph=True)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/root/miniconda3/envs/py310/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/autodl-fs/data/20241122v1/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 839, in backward
    zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 4]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The error is raised at line 839 of mamba/mamba_ssm/ops/triton/ssd_combined.py (per the traceback above), where backward unpacks the saved tensors:

zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
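
For reference, the version-counter check behind this error can be reproduced in isolation. This is a minimal standalone sketch (not code from my model), where w stands in for conv1d_weight:

import torch

w = torch.randn(768, 4, requires_grad=True)
loss = (w ** 2).sum()            # PowBackward0 saves `w` at version 0

loss.backward(retain_graph=True) # first backward succeeds: versions match

with torch.no_grad():
    w += 1.0                     # in-place update (as optimizer.step() does)
                                 # bumps w._version from 0 to 1

loss.backward()                  # second backward over the retained graph fails:
                                 # RuntimeError: ... modified by an inplace
                                 # operation ... is at version 1; expected version 0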

Does anyone know what to do?

config:
mamba 2.2.0
causal-conv1d 1.4.0
torch 2.1.1
cuda 11.8
python 3.10

kanyu369 changed the title from "error in backward of inplace operation" to "error in backward of inplace operation and tensor version" on Dec 9, 2024