Memory leaks when using TRT execution provider #18466
Labels: ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to the TensorRT execution provider)

github-actions bot added the ep:CUDA and ep:TensorRT labels — Nov 16, 2023
thanks. this will be fixed asap.
jywu-msft pushed a commit that referenced this issue on Nov 16, 2023: Free memory for cudnn/cublas instances at TRT EP destruction. #18466
@jywu-msft Thanks for the quick fix. Would it be possible to cherry-pick this change into 1.16.2?

If not a cherry-pick, then is it possible to get a patch release of ORT with this fix soon?

Unfortunately it's too late to include in 1.16.2, which already went out.
jywu-msft pushed a commit that referenced this issue on Nov 17, 2023: Free memory for cudnn/cublas instances at TRT EP destruction. #18466
I am preparing a patch release for this issue.
kleiti pushed a commit to kleiti/onnxruntime that referenced this issue on Mar 22, 2024: Free memory for cudnn/cublas instances at TRT EP destruction. microsoft#18466
Describe the issue
Valgrind reports a memory leak when using the TRT execution provider with a user-provided CUDA stream. GPU memory usage also grows over time.
To reproduce
Compile the script below and put any ONNX model in the same path (the example uses resnet50-1.2.onnx).
main.cc
Compilation example - CMakeLists.txt
The script loads and unloads an ONNX model for a number of iterations and prints GPU memory usage to the file gpu_mem.log; GPU memory growth should be observed. To see the leak, run Valgrind against the executable. In the Valgrind output, memory leaks are reported:
I suspect that the leak was introduced in rel-1.16.2 because cublas_handle and cudnn_handle are not cleaned up properly when a user-provided CUDA stream is used. The handles are created in the constructor here: https://github.com/microsoft/onnxruntime/blob/rel-1.16.2/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L795-L798 and there are no cublasDestroy and cudnnDestroy calls for them. Without providing a CUDA stream to the TRT execution provider, no leak or GPU memory growth is observed.
Urgency: High
Platform: Linux
OS Version: 22.04
ONNX Runtime Installation: Built from Source
ONNX Runtime Version or Commit ID: rel-1.16.2
ONNX Runtime API: C++
Architecture: X64
Execution Provider: TensorRT
Execution Provider Library Version: TensorRT 8.6.1.6