
Bug of CuDNN RNN with variable sequence length #10453

Closed
sxjscience opened this issue Apr 7, 2018 · 9 comments · Fixed by #11041

Comments

@sxjscience
Member

sxjscience commented Apr 7, 2018

Description

Segfault will be triggered by the following code:

from mxnet.gluon.rnn import LSTM
import mxnet as mx
import numpy as np

ctx = mx.gpu()
lstm = LSTM(num_layers=1, hidden_size=200, dropout=0.5)
lstm.initialize(ctx=ctx)
batch_size = 32
for seq_len in range(500, 10, -1):  # decreasing sequence lengths across iterations
    for repeat in range(10):
        print(seq_len, repeat)
        inputs_nd = mx.nd.random.normal(0, 1, shape=(seq_len, batch_size, 200), ctx=ctx)
        out = lstm(inputs_nd)
        print(out[0].sum().asscalar())
        mx.nd.waitall()

I'm using a V100 + CUDA 9.0 + cuDNN 7.0.4 (P3 instance). GPU memory usage keeps increasing and eventually a segfault is raised.

Also, the same script and configuration have not triggered an error on an M60 (g3 instance).

@eric-haibin-lin @DickJC123 @szha @szhengac

backtrace:

#0  __GI___libc_free (mem=0x7f0000000000) at malloc.c:2951
#1  0x00007fff7252e6f9 in cudnnDestroyFilterDescriptor () from /usr/local/cuda/lib64/libcudnn.so.7
#2  0x00007fff99c3683c in mxnet::op::CuDNNRNNOp<float>::~CuDNNRNNOp() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#3  0x00007fff99c36bd9 in mxnet::op::CuDNNRNNOp<float>::~CuDNNRNNOp() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#4  0x00007fff97d9bde5 in dmlc::any::TypeOnHeap<mxnet::op::OperatorState>::destroy(dmlc::any::Data*) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#5  0x00007fff9579d8e1 in std::_Sp_counted_ptr_inplace<mxnet::OpStatePtr::OpState, std::allocator<mxnet::OpStatePtr::OpState>, (__gnu_cxx::_Lock_policy)2>::_M_dispose() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#6  0x00007fff9551bed7 in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#7  0x00007fff97def885 in mxnet::imperative::PushOperator(mxnet::OpStatePtr const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, mxnet::DispatchMode)::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#2}::~CallbackOnComplete() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#8  0x00007fff97deff08 in std::_Function_base::_Base_manager<mxnet::imperative::PushOperator(mxnet::OpStatePtr const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, mxnet::DispatchMode)::{lambda(mxnet::RunContext)#3}>::_M_manager(std::_Any_data&, std::_Function_base::_Base_manager<mxnet::imperative::PushOperator(mxnet::OpStatePtr const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, mxnet::DispatchMode)::{lambda(mxnet::RunContext)#3}> const&, std::_Manager_operation) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#9  0x00007fff982233a6 in std::_Function_base::_Base_manager<mxnet::engine::ThreadedEngine::PushSync(std::function<void (mxnet::RunContext)>, mxnet::Context, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, mxnet::FnProperty, int, char const*)::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#1}>::_M_manager(std::_Any_data&, std::_Function_base::_Base_manager<mxnet::engine::ThreadedEngine::PushSync(std::function<void (mxnet::RunContext)>, mxnet::Context, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, mxnet::FnProperty, int, char const*)::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#1}> const&, std::_Manager_operation) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#10 0x00007fff98229cc4 in mxnet::engine::ThreadedEngine::OnComplete(mxnet::engine::ThreadedOpr*) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#11 0x00007fff982245c6 in mxnet::engine::ThreadedEngine::OnCompleteStatic(mxnet::Engine*, void*) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#12 0x00007fff9821e868 in mxnet::engine::ThreadedEngine::ExecuteOprBlock(mxnet::RunContext, mxnet::engine::OprBlock*) ()
   from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#13 0x00007fff9823419b in void mxnet::engine::ThreadedEnginePerDevice::GPUWorker<(dmlc::ConcurrentQueueType)0>(mxnet::Context, bool, mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)0>*, std::shared_ptr<dmlc::ManualEvent> const&) () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#14 0x00007fff982343fe in std::_Function_handler<void (std::shared_ptr<dmlc::ManualEvent>), mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#3}::operator()() const::{lambda(std::shared_ptr<dmlc::ManualEvent>)#1}>::_M_invoke(std::_Any_data const&, std::shared_ptr<dmlc::ManualEvent>&&) ()
   from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#15 0x00007fff9822e2aa in std::thread::_Impl<std::_Bind_simple<std::function<void (std::shared_ptr<dmlc::ManualEvent>)> (std::shared_ptr<dmlc::ManualEvent>)> >::_M_run() () from /home/ubuntu/mxnet/python/mxnet/../../lib/libmxnet.so
#16 0x00007fffee4a2c80 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#17 0x00007ffff7bc16ba in start_thread (arg=0x7fff3ed07700) at pthread_create.c:333
#18 0x00007ffff78f741d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109
@sxjscience
Member Author

The following code, which always uses seq_len=500, does not trigger the segfault. This is a critical bug.

from mxnet.gluon.rnn import LSTM
import mxnet as mx
import numpy as np

ctx = mx.gpu()
lstm = LSTM(num_layers=1, hidden_size=200, dropout=0.0)
lstm.initialize(ctx=ctx)
batch_size = 32
for seq_len in range(500, 10, -1):
    for repeat in range(10):
        real_seq_len = 500  # always use a fixed sequence length of 500
        print(real_seq_len, repeat)
        inputs_nd = mx.nd.random.normal(0, 1, shape=(real_seq_len, batch_size, 200), ctx=ctx)
        out = lstm(inputs_nd)
        print(out[0].sum().asscalar())
        mx.nd.waitall()

@szhengac
Contributor

szhengac commented Apr 7, 2018

The bug occurs when we have variable sequence lengths. I think it may be related to how MXNet reuses memory.

@szha
Member

szha commented Apr 7, 2018

I was able to finish running the script by setting export MXNET_GPU_MEM_POOL_RESERVE=7

@eric-haibin-lin
Member

@szha how much memory consumption did you observe?
One thing to try is to run the code step by step and figure out which batch causes the error.

@szha
Member

szha commented Apr 9, 2018

What I observed is that it doesn't fail consistently on any specific batch. Another team observed the same issue before; it is likely caused by our backend memory pool holding too much memory, in which case cuRAND doesn't have enough memory to keep the random number generator states for each streaming multiprocessor.
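
For reference on the memory question above, one way to see how much free device memory the pool leaves behind is to check cudaMemGetInfo around the failing region. A small diagnostic sketch (a hypothetical helper, not code from MXNet or this thread):

#include <cuda_runtime.h>
#include <cstdio>

// Print how much device memory is still free. If this value approaches zero
// while the process keeps running, allocations made outside the MXNet memory
// pool (such as the cuRAND random number generator states) will start to fail.
void print_free_gpu_memory(const char *tag) {
  size_t free_bytes = 0, total_bytes = 0;
  if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) {
    std::printf("[%s] free: %.1f MiB of %.1f MiB\n", tag,
                free_bytes / (1024.0 * 1024.0),
                total_bytes / (1024.0 * 1024.0));
  }
}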

@Jerryzcn
Contributor

Jerryzcn commented Apr 9, 2018

I have a similar issue when training a speech model, even after setting
export MXNET_GPU_MEM_POOL_RESERVE=7
I will try a larger RESERVE value.

@sxjscience
Member Author

sxjscience commented Apr 17, 2018

It's related to pytorch/pytorch#953. cudnnSetDropoutDescriptor uses a large amount of GPU memory. One way to solve this problem is to create a DropoutDescriptor when we create a stream and always reuse it through cudnnGetDropoutDescriptor. This would also speed up the RNN layer in Gluon because we could avoid repeatedly calling Alloc and Free.
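
A minimal C++ sketch of the proposed reuse pattern (an illustration only, not MXNet's actual implementation), assuming one cuDNN handle per stream: the dropout RNG state buffer is allocated and seeded once with cudnnSetDropoutDescriptor, and later consumers query the shared descriptor through cudnnGetDropoutDescriptor instead of allocating and seeding a new state buffer every time an RNN operator is constructed.

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDNN(call)                                                 \
  do {                                                                    \
    cudnnStatus_t s_ = (call);                                            \
    if (s_ != CUDNN_STATUS_SUCCESS) {                                     \
      std::fprintf(stderr, "cuDNN error: %s\n", cudnnGetErrorString(s_)); \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // The dropout RNG state buffer is the expensive allocation performed by
  // cudnnSetDropoutDescriptor; allocate and seed it once per stream/handle.
  size_t state_size = 0;
  CHECK_CUDNN(cudnnDropoutGetStatesSize(handle, &state_size));
  void *states = nullptr;
  cudaMalloc(&states, state_size);

  cudnnDropoutDescriptor_t shared_desc;
  CHECK_CUDNN(cudnnCreateDropoutDescriptor(&shared_desc));
  CHECK_CUDNN(cudnnSetDropoutDescriptor(shared_desc, handle, 0.5f,
                                        states, state_size, /*seed=*/0ULL));

  // A newly constructed RNN operator would query the shared descriptor
  // instead of allocating and seeding its own state buffer.
  float dropout_p = 0.0f;
  void *state_ptr = nullptr;
  unsigned long long seed = 0;
  CHECK_CUDNN(cudnnGetDropoutDescriptor(shared_desc, handle, &dropout_p,
                                        &state_ptr, &seed));
  std::printf("shared dropout p = %.2f, state bytes = %zu\n",
              dropout_p, state_size);

  CHECK_CUDNN(cudnnDestroyDropoutDescriptor(shared_desc));
  cudaFree(states);
  CHECK_CUDNN(cudnnDestroy(handle));
  return 0;
}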

@leezu
Contributor

leezu commented May 20, 2018

#11004 "fixes" this issue. The filter descriptors that are freed in the destructor were not created if cudaMalloc would fail during Forward or Backward.

Now the following error will be returned in an OOM situation:


mxnet.base.MXNetError: [05:13:15] src/storage/./pooled_storage_manager.h:108: cudaMalloc failed: out of memory

Stack trace returned 10 entries:
[bt] (0) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(dmlc::StackTrace[abi:cxx11]()+0x5b) [0x7f358103783b]
[bt] (1) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f35810383a8]
[bt] (2) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(mxnet::storage::GPUPooledStorageManager::Alloc(mxnet::Storage::Handle*)+0x154) [0x7f358398b384]
[bt] (3) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(mxnet::StorageImpl::Alloc(mxnet::Storage::Handle*)+0x5d) [0x7f358398d80d]
[bt] (4) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(mxnet::op::CuDNNRNNOp<float>::Init(mshadow::Stream<mshadow::gpu>*, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x1de5) [0x7f3585606a55]
[bt] (5) /home/leonard/software/mxnet-master/python/mxnet/../../lib/libmxnet.so(mxnet::op::CuDNNRNNOp<float>::Forward(mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0xa5d) [0x7f358560e07d]
...

In particular, #11004 makes sure that the descriptors are always created during class initialization and not just somewhere down the line in Forward / Backward.
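
As an illustration of that pattern (a sketch only, not the actual change in #11004): creating every cuDNN descriptor in the constructor and destroying it in the destructor guarantees that the destructor never calls cudnnDestroyFilterDescriptor on a handle that was never created, even if a later cudaMalloc in Forward/Backward fails.

#include <cudnn.h>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Owns a fixed set of cuDNN filter descriptors. All descriptors are created
// at construction time, so the destructor only destroys descriptors that
// actually exist.
class FilterDescriptorSet {
 public:
  explicit FilterDescriptorSet(std::size_t n) : descs_(n, nullptr) {
    for (auto &d : descs_) {
      if (cudnnCreateFilterDescriptor(&d) != CUDNN_STATUS_SUCCESS) {
        throw std::runtime_error("cudnnCreateFilterDescriptor failed");
      }
    }
  }
  ~FilterDescriptorSet() {
    for (auto d : descs_) {
      if (d != nullptr) cudnnDestroyFilterDescriptor(d);
    }
  }
  cudnnFilterDescriptor_t operator[](std::size_t i) const { return descs_[i]; }

 private:
  std::vector<cudnnFilterDescriptor_t> descs_;
};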

6 participants