Memory leaks when using TRT execution provider #18466

Closed
krishung5 opened this issue Nov 16, 2023 · 5 comments · Fixed by #18491

Labels: ep:CUDA (issues related to the CUDA execution provider) · ep:TensorRT (issues related to the TensorRT execution provider)

@krishung5

Describe the issue

Valgrind reports memory leaks when using the TRT execution provider with a user-provided CUDA stream. It also causes GPU memory growth.

To reproduce

Compile the script below and place any ONNX model in the same directory (the example uses resnet50-1.2.onnx).

main.cc

#include <onnxruntime_c_api.h>
#include <tensorrt_provider_factory.h>
#include <iostream>
#include <cuda_runtime_api.h>
#include <stdlib.h>

const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

void CheckStatus(OrtStatus* status)
{
  if (status != NULL) {
    std::cerr << ort_api->GetErrorMessage(status) << std::endl;
    ort_api->ReleaseStatus(status);
    exit(1);
  }
}

int main(int argc, char* argv[])
{
  for (int i = 0; i < 100; i++) {
    OrtEnv* env;
    CheckStatus(ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));

    OrtSessionOptions* session_options;
    CheckStatus(ort_api->CreateSessionOptions(&session_options));

    CheckStatus(ort_api->SetIntraOpNumThreads(session_options, 1));
    CheckStatus(ort_api->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC));

    OrtSession* session;
    const char* model_path = "resnet50-1.2.onnx";

    cudaStream_t stream_ = nullptr;

    cudaStreamCreate(&stream_);

    // Initialize TRT options; comments name the OrtTensorRTProviderOptions fields in declaration order
    OrtTensorRTProviderOptions trt_options{
        0,               // device_id
        1,               // has_user_compute_stream
        (void*)stream_,  // user_compute_stream
        1000,            // trt_max_partition_iterations
        1,               // trt_min_subgraph_size
        1 << 30,         // trt_max_workspace_size
        0,               // trt_fp16_enable
        0,               // trt_int8_enable
        nullptr,         // trt_int8_calibration_table_name
        0,               // trt_int8_use_native_calibration_table
        0,               // trt_dla_enable
        0,               // trt_dla_core
        0,               // trt_dump_subgraphs
        0,               // trt_engine_cache_enable
        nullptr,         // trt_engine_cache_path
        0,               // trt_engine_decryption_enable
        nullptr,         // trt_engine_decryption_lib_path
        0                // trt_force_sequential_engine_build
    };
    
    CheckStatus(ort_api->SessionOptionsAppendExecutionProvider_TensorRT(session_options, &trt_options));
    CheckStatus(ort_api->CreateSession(env, model_path, session_options, &session));

    ort_api->ReleaseSession(session);
    system("bash -c 'LOADED_GPU_USAGE_MiB=$(nvidia-smi -i 0 --query-gpu=memory.used --format=csv | grep \" MiB\") && echo $LOADED_GPU_USAGE_MiB >> gpu_mem.log'");
    ort_api->ReleaseSessionOptions(session_options);
    ort_api->ReleaseEnv(env);

    cudaStreamDestroy(stream_);
  }

  return 0;
}

Compilation example - CMakeLists.txt

cmake_minimum_required(VERSION 3.17)

project(test LANGUAGES C CXX)

set(CMAKE_BUILD_TYPE Debug)

include_directories( 
  /onnxruntime/include/onnxruntime/core/session/
  /onnxruntime/include/onnxruntime/core/providers/tensorrt
)

find_package(CUDAToolkit REQUIRED)
        
add_executable(test main.cc)
target_link_libraries(
  test
  onnxruntime
  CUDA::cudart
)
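
To build and run (assuming the ONNX Runtime and TensorRT EP shared libraries are installed and on the linker/loader search path; adjust paths as needed):

cmake -S . -B build
cmake --build build
./build/test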

The script loads and unloads an ONNX model for 100 iterations, appending the GPU memory usage to the file gpu_mem.log after each iteration. Steady GPU memory growth should be observed.

To see the leaks reported by Valgrind, run Valgrind on the executable:

/usr/bin/valgrind --leak-check=full --show-leak-kinds=definite --max-threads=3000 --num-callers=20 --keep-debuginfo=yes --log-file=./valgrind.log ./test

In the Valgrind output, memory leaks are reported:

==63613== 68,000 (67,136 direct, 864 indirect) bytes in 1 blocks are definitely lost in loss record 2,354 of 2,388
==63613==    at 0x4848899: malloc (in /usr/libexec/valgrind/vgpreload_memcheck-amd64-linux.so)
==63613==    by 0x40C6D9FC: ??? (in /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcublas.so.12.3.2.9)
==63613==    by 0x407F64C6: cublasCreate_v2 (in /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcublas.so.12.3.2.9)
==63613==    by 0x5B7E347: onnxruntime::TensorrtExecutionProvider::TensorrtExecutionProvider(onnxruntime::TensorrtExecutionProviderInfo const&) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime_providers_tensorrt.so)
==63613==    by 0x5BAD496: onnxruntime::TensorrtProviderFactory::CreateProvider() (in /opt/tritonserver/backends/onnxruntime/libonnxruntime_providers_tensorrt.so)
==63613==    by 0x4A838CC: (anonymous namespace)::InitializeSession(OrtSessionOptions const*, std::unique_ptr<onnxruntime::InferenceSession, std::default_delete<onnxruntime::InferenceSession> >&, OrtPrepackedWeightsContainer*) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime.so)
==63613==    by 0x4A8E8D4: OrtApis::CreateSession(OrtEnv const*, char const*, OrtSessionOptions const*, OrtSession**) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime.so)
==63613==    by 0x10942C: main (main.cc:58)
==63613== 
==63613== 101,848 (1,080 direct, 100,768 indirect) bytes in 1 blocks are definitely lost in loss record 2,363 of 2,388
==63613==    at 0x4849013: operator new(unsigned long) (in /usr/libexec/valgrind/vgpreload_memcheck-amd64-linux.so)
==63613==    by 0xB605FE71: cudnnCreate (in /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6)
==63613==    by 0x5B7E3EE: onnxruntime::TensorrtExecutionProvider::TensorrtExecutionProvider(onnxruntime::TensorrtExecutionProviderInfo const&) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime_providers_tensorrt.so)
==63613==    by 0x5BAD496: onnxruntime::TensorrtProviderFactory::CreateProvider() (in /opt/tritonserver/backends/onnxruntime/libonnxruntime_providers_tensorrt.so)
==63613==    by 0x4A838CC: (anonymous namespace)::InitializeSession(OrtSessionOptions const*, std::unique_ptr<onnxruntime::InferenceSession, std::default_delete<onnxruntime::InferenceSession> >&, OrtPrepackedWeightsContainer*) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime.so)
==63613==    by 0x4A8E8D4: OrtApis::CreateSession(OrtEnv const*, char const*, OrtSessionOptions const*, OrtSession**) (in /opt/tritonserver/backends/onnxruntime/libonnxruntime.so)
==63613==    by 0x10942C: main (main.cc:58)

I suspect that the leak was introduced in rel-1.16.2 because cublas_handle and cudnn_handle are not cleaned up properly when a user-provided CUDA stream is used. The handles are created in the constructor here:
https://github.com/microsoft/onnxruntime/blob/rel-1.16.2/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L795-L798
and there are no matching cublasDestroy and cudnnDestroy calls for them.
Without providing a CUDA stream to the TRT execution provider, no leak or GPU memory growth is observed.
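
For illustration only, here is a minimal self-contained sketch of the kind of cleanup that appears to be missing, using a hypothetical RAII owner (CudaHandles is not ORT code, and the actual fix belongs in the TRT EP destructor):

#include <cublas_v2.h>
#include <cudnn.h>
#include <cuda_runtime_api.h>

// Hypothetical RAII owner: creates the handles on a user-provided stream and
// guarantees the matching cublasDestroy/cudnnDestroy calls at destruction,
// which is what the rel-1.16.2 constructor linked above lacks.
struct CudaHandles {
  cublasHandle_t cublas = nullptr;
  cudnnHandle_t cudnn = nullptr;

  explicit CudaHandles(cudaStream_t stream) {
    cublasCreate(&cublas);
    cublasSetStream(cublas, stream);
    cudnnCreate(&cudnn);
    cudnnSetStream(cudnn, stream);
  }

  ~CudaHandles() {
    if (cudnn) cudnnDestroy(cudnn);
    if (cublas) cublasDestroy(cublas);
  }

  CudaHandles(const CudaHandles&) = delete;
  CudaHandles& operator=(const CudaHandles&) = delete;
};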

Urgency

High

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

rel-1.16.2

ONNX Runtime API

C++

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

TensorRT 8.6.1.6

github-actions bot added the ep:CUDA and ep:TensorRT labels on Nov 16, 2023
@jywu-msft (Member)

Thanks. This will be fixed ASAP.

jywu-msft pushed a commit that referenced this issue Nov 16, 2023
Free memory for cudnn/cublas instances at TRT EP destruction.
#18466
@Tabrizian

@jywu-msft Thanks for the quick fix. Would it be possible to cherry-pick this change into 1.16.2?

@tanmayv25 commented Nov 17, 2023

If not a cherry-pick, would it be possible to get a patch release of ORT with this fix soon?

@jywu-msft (Member)

Unfortunately, it's too late to include this in 1.16.2, which has already gone out.
1.17 is scheduled to be released soon, on Dec 12th.

jywu-msft pushed a commit that referenced this issue Nov 17, 2023
Free memory for cudnn/cublas instances at TRT EP destruction.
#18466
@snnn (Member) commented Nov 17, 2023

I am preparing a patch release for this issue.

snnn linked a pull request on Nov 17, 2023 that will close this issue
snnn closed this as completed on Nov 20, 2023
kleiti pushed a commit to kleiti/onnxruntime that referenced this issue Mar 22, 2024
Free memory for cudnn/cublas instances at TRT EP destruction.
microsoft#18466