Question about ncclCommAbort stuck issue #1013

Open
acphile opened this issue Sep 27, 2023 · 89 comments

@acphile

acphile commented Sep 27, 2023

Hi, I found an issue where ncclCommAbort hangs when there are multiple ProcessGroups.

Here is a simple example on 1 node with 4 ranks:

import datetime
import time

import torch
import torch.distributed as dist
import torch.distributed as c10d  # the script refers to torch.distributed under both aliases

def run_worker(rank, world_size):
    ndev = torch.cuda.device_count()
    device = torch.device(f'cuda:{rank % ndev}')
    torch.cuda.set_device(device)
    dist.init_process_group('nccl', rank=rank, world_size=world_size,
                            timeout=datetime.timedelta(seconds=1800)
                           )
    print(torch.distributed.get_world_size())
    pg = c10d.distributed_c10d._get_default_group()
    pg1 = torch.distributed.new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend="nccl")

    device_id = f"cuda:{rank % torch.cuda.device_count()}"
    tensor0 = torch.ones([1]).cuda(rank) * rank
    tensor = torch.ones([1]).cuda(rank) * rank
    tensor1 = torch.ones([2,3]).cuda(rank) * rank
    # for initialization purpose
    pg.allreduce(tensor0)
    pg1.allreduce(tensor0)
    time.sleep(10)
    assert torch.cuda.current_stream().query() == True
    # suppose before the following operations, one rank is down.
    if rank==2:
        exit()
    print("One rank is down, other ranks continue")
    pg1.allreduce(tensor1)    
    pg.broadcast(tensor, torch.distributed.get_world_size() - 1)
    print("Suppose until now we find one rank is down, and we want to abort two collective operations")
    pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print("abort")
    pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print("end")

In this case, the process gets stuck at the first _abort(). With gdb, we can see that it hangs inside ncclCommAbort. However, if we swap the order of the two _abort() calls (abort pg1 first, then pg), the process exits successfully.

Even if both collective operations happen on pg1, the process still gets stuck when we try to abort pg first.

So, is there any bug related to ncclCommAbort?

NCCL version=2.18.3

@KaimingOuyang
Collaborator

That's expected if you are using the same stream, or different streams where CUDA happens to schedule the work in sequence.
You first call pg1.allreduce(tensor1), which is enqueued but cannot complete, and then you enqueue pg.broadcast(tensor, torch.distributed.get_world_size() - 1).

At this point, if you abort pg first, it has to wait for pg1.allreduce to complete (which will never happen since one rank is dead), so you hang forever. However, if you abort pg1 first, the allreduce can exit, and then the broadcast can exit as well.

To me this is not an NCCL bug but an application bug.
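For illustration, here is a minimal sketch of the working order described here, reusing the pg/pg1 handles and the _abort() calls from the script above (abort the group whose collective was enqueued first):

device = torch.device(torch.cuda.current_device())
pg1._get_backend(device)._abort()  # unblocks the pending allreduce on pg1
pg._get_backend(device)._abort()   # the broadcast on pg can now be torn down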

@acphile
Author

acphile commented Sep 27, 2023

Thanks for your reply. I am wondering: when we abort pg first, why does it have to wait for the collective operation on pg1 to complete? Does it mean that when there are multiple ongoing collective operations and an error happens on one rank, we must abort these operations in the same order in which they were enqueued?
Actually, when we replace these two aborts with a raise Exception line, the process cannot exit as expected and hangs at ncclCommAbort according to gdb.

And when the allreduce and broadcast are both executed on pg1, like

print("One rank is down, other ranks continue")
pg1.allreduce(tensor1)    
pg1.broadcast(tensor, torch.distributed.get_world_size() - 1)

Even though there is no ongoing operation on pg, we still cannot abort pg first.

@KaimingOuyang
Collaborator

when we abort pg first, why does it have to wait for the collective operation on pg1 to complete?

Because you issue the pg1 allreduce first and it cannot complete, it can block the pg broadcast. abort waits for the collectives it has issued to complete before aborting everything.

You need to be very careful when using multiple communicators; it is a complex topic. If you enqueue collectives into different streams and each stream gets the resources to issue its workload, then the abort order does not matter. But as soon as one collective blocks another, you can easily end up with a hang.
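As an illustration only (not part of the original script): issuing the same two collectives with async_op=True returns Work handles, which make it visible on the surviving ranks that the first enqueued collective never completes, and it is that pending work which ends up blocking anything issued after it.

w1 = torch.distributed.all_reduce(tensor1, group=pg1, async_op=True)
w2 = torch.distributed.broadcast(tensor, src=torch.distributed.get_world_size() - 1,
                                 group=pg, async_op=True)
print(w1.is_completed(), w2.is_completed())  # stays False, False once a peer rank has exited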

For the case

pg1.allreduce(tensor1)    
pg1.broadcast(tensor, torch.distributed.get_world_size() - 1)

If you abort pg1 first, it does not hang? If so, can you provide me the gdb backtrace when it hangs?

@acphile
Author

acphile commented Sep 27, 2023

For the second case, yes, if we abort pg1 first it does not hang. When aborting pg first, the backtrace of rank 0 looks like:

#0  __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=397683, futex_word=0x7f2188c90910) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=397683, futex_word=0x7f2188c90910) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f2188c90910, expected=397683, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3  0x00007f22220966a4 in __pthread_clockjoin_ex (threadid=139782005524032, thread_return=0x0, clockid=0, abstime=0x0, block=<optimized out>) at ./nptl/pthread_join_common.c:105
#4  0x00007f21a5251def in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#5  0x00007f21a525c668 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#6  0x00007f21a525d242 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#7  0x00007f21a525d862 in pncclCommAbort () from /lib/x86_64-linux-gnu/libnccl.so.2
#8  0x00007f21c7a5fc7a in c10d::NCCLComm::ncclCommAbort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#9  0x00007f21c7a2ba38 in c10d::abortCommsFromMap(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > > > >&, int, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) ()
   from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#10 0x00007f21c7a2bbf0 in c10d::ProcessGroupNCCL::abort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#11 0x00007f21dcc36dd2 in pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}, void, c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}&&, void (*)(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#12 0x00007f21dc3d61ff in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#13 0x00000000005072d7 in cfunction_call (func=0x7f21c679b310, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543
#14 0x00000000004f06ac in _PyObject_MakeTpCall (tstate=0xbfde20, callable=0x7f21c679b310, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:191
#15 0x00000000005051f0 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x5b0c8a0, callable=0x7f21c679b310, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116
#16 _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x5b0c8a0, callable=0x7f21c679b310, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103
#17 method_vectorcall (method=<optimized out>, args=0x5b0c8a8, nargsf=<optimized out>, kwnames=0x0) at /usr/local/src/conda/python-3.9.17/Objects/classobject.c:53
#18 0x00000000004ec6d4 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x5b0c8a8, callable=0x7f2194e071c0, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#19 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5b0c8a8, callable=0x7f2194e071c0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#20 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#21 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0x5b0c6e0, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489
#22 0x00000000004f8053 in _PyEval_EvalFrame (throwflag=0, f=0x5b0c6e0, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#23 function_code_fastcall (tstate=0xbfde20, co=<optimized out>, args=<optimized out>, nargs=<optimized out>, globals=0x7f2221f6efc0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:330
#24 0x00000000004e7d59 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0xc5af10, callable=0x7f2221fb30d0, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#25 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0xc5af10, callable=0x7f2221fb30d0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#26 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#27 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0xc5ada0, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520
#28 0x00000000004e6a8a in _PyEval_EvalFrame (throwflag=0, f=0xc5ada0, tstate=0xbfde20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#29 _PyEval_EvalCode (tstate=<optimized out>, _co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=<optimized out>, 
    kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329
#30 0x00000000004e6717 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x0, kwcount=0, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, 
    closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361
#31 0x00000000004e66c9 in PyEval_EvalCodeEx (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kws=<optimized out>, kwcount=0, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0)
    at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377
#32 0x000000000059398b in PyEval_EvalCode (co=co@entry=0x7f2221bcec90, globals=globals@entry=0x7f2221f6efc0, locals=locals@entry=0x7f2221f6efc0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:828
#33 0x00000000005c1217 in run_eval_code_obj (tstate=0xbfde20, co=0x7f2221bcec90, globals=0x7f2221f6efc0, locals=0x7f2221f6efc0) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221
#34 0x00000000005bd220 in run_mod (mod=<optimized out>, filename=<optimized out>, globals=0x7f2221f6efc0, locals=0x7f2221f6efc0, flags=<optimized out>, arena=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242
#35 0x0000000000456537 in pyrun_file (fp=0xbfb340, filename=0x7f2221ba48a0, start=<optimized out>, globals=0x7f2221f6efc0, locals=0x7f2221f6efc0, closeit=1, flags=0x7ffc9f633b98) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1140
#36 0x00000000005b6f02 in pyrun_simple_file (flags=0x7ffc9f633b98, closeit=1, filename=0x7f2221ba48a0, fp=0xbfb340) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:450
#37 PyRun_SimpleFileExFlags (fp=0xbfb340, filename=<optimized out>, closeit=1, flags=0x7ffc9f633b98) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:483
#38 0x00000000005b447e in pymain_run_file (cf=0x7ffc9f633b98, config=0xbfeab0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:379
#39 pymain_run_python (exitcode=0x7ffc9f633b90) at /usr/local/src/conda/python-3.9.17/Modules/main.c:604
#40 Py_RunMain () at /usr/local/src/conda/python-3.9.17/Modules/main.c:683
#41 0x0000000000587a39 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /usr/local/src/conda/python-3.9.17/Modules/main.c:1129
#42 0x00007f2222029d90 in __libc_start_call_main (main=main@entry=0x5879f0 <main>, argc=argc@entry=3, argv=argv@entry=0x7ffc9f633dc8) at ../sysdeps/nptl/libc_start_call_main.h:58
#43 0x00007f2222029e40 in __libc_start_main_impl (main=0x5879f0 <main>, argc=3, argv=0x7ffc9f633dc8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7ffc9f633db8) at ../csu/libc-start.c:392
#44 0x00000000005878ee in _start ()

@acphile
Author

acphile commented Sep 27, 2023

Because you first issue pg1 allreduce, it cannot complete, which can block pg broadcast. abort will wait for its issued collective to complete before aborting everything.

Could you elaborate on the situations where one collective operation can block another? In my first case, the broadcast and allreduce operate on different process groups and different tensors, so why would pg1.allreduce block pg.broadcast?

@KaimingOuyang
Collaborator

I found the root cause.

During abort, NCCL calls cudaFree to release CUDA resources. However, cudaFree can cause a kernel sync.
Since the allreduce kernel is never going to complete, we get a deadlock. If you abort pg1 first, this problem does not occur. Indeed, we probably need to change to cuMem*-based allocation and free to avoid this type of deadlock.

Could you elaborate on the situations where one collective operation can block another?

As I explained above, there are two main reasons why the pg1 allreduce can block the pg broadcast. One is that you issue them on the same stream; the other is that the GPU does not have enough resources, so the CUDA runtime schedules them in sequence even though they are on different streams.
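As a rough analogy in the script's own terms (this is not NCCL internals, just the same mechanism): any call that synchronizes with the device will hang once a collective kernel can never finish, which is exactly why the cudaFree issued inside ncclCommAbort blocks.

torch.distributed.all_reduce(tensor1, group=pg1)  # peer rank is gone, the kernel never completes
torch.cuda.synchronize()                          # blocks forever, like the implicit sync in cudaFree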

@acphile
Author

acphile commented Sep 28, 2023

Indeed, we probably need to change to cuMem*-based allocation and free to avoid this type of deadlock.

Thanks for your explanation. Do you have a plan and timeline to solve this issue?

@KaimingOuyang
Collaborator

Need to discuss with the team. Will let you know when we have a plan.

@AddyLaddy
Collaborator

Can you try NCCL_CUMEM_ENABLE=1 to see if it resolves the deadlock in NCCL 2.18.x?
NCCL 2.19.x will have CUMEM support enabled by default.

@acphile
Author

acphile commented Sep 29, 2023

No, it does not help.
Meanwhile, I would like to report another related case. Consider the case where the collective operations after one rank goes down are:

print("One rank is down, other ranks continue")
pg.allreduce(tensor)
pg1.allreduce(tensor1)    
print("Suppose until now we find one rank is down, and we want to abort two collective operations")
pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
print("abort")
pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
print("end")

Even if we abort pg and pg1 in the same order as the collective operations, it still gets stuck.

@acphile
Author

acphile commented Sep 29, 2023

Another thing is that if we wait until the underlying NCCL watchdog thread catches the timeout and performs the abort in the watchdog thread, the abort succeeds. But if we try to abort in the main process, it fails.

@KaimingOuyang
Collaborator

I remember the watchdog thread is per communicator, so the watchdog threads won't block each other the way the main thread does, since the main thread aborts the communicators one by one. Just to confirm one more point: when you enable NCCL_CUMEM_ENABLE=1 and it gets stuck, can you show me the backtrace of all threads?
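A hedged sketch that mirrors the per-communicator watchdog behaviour described here (an illustration only, not a recommended workaround): issue each backend abort from its own Python thread so one stalled ncclCommAbort does not serialize the other.

import threading

dev = torch.device(torch.cuda.current_device())
backends = [pg._get_backend(dev), pg1._get_backend(dev)]
threads = [threading.Thread(target=b._abort) for b in backends]
for t in threads:
    t.start()
for t in threads:
    t.join()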

@acphile
Author

acphile commented Sep 29, 2023

Thread 7 (Thread 0x7f272e99a640 (LWP 455407) "python3"):
#0 __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=0, futex_word=0x7f272ed9b158) at ./nptl/futex-internal.c:57
#1 __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=0, futex_word=0x7f272ed9b158) at ./nptl/futex-internal.c:87
#2 __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f272ed9b158, expected=expected@entry=0, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3 0x00007f281f693ac1 in __pthread_cond_wait_common (abstime=0x0, clockid=0, mutex=0x7f272ed9b108, cond=0x7f272ed9b130) at ./nptl/pthread_cond_wait.c:503
#4 ___pthread_cond_wait (cond=0x7f272ed9b130, mutex=0x7f272ed9b108) at ./nptl/pthread_cond_wait.c:627
#5 0x00007f27a267d880 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#6 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#7 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 6 (Thread 0x7f273eecc640 (LWP 455403) "python3"):
#0 0x00007f281f718d7f in __GI___poll (fds=0x7f273eeca000, nfds=65, timeout=500) at ../sysdeps/unix/sysv/linux/poll.c:29
#1 0x00007f27a267dee7 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#2 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#3 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 5 (Thread 0x7f27816b3640 (LWP 455379) "python3"):
#0 __futex_abstimed_wait_common64 (private=-857358696, cancel=true, abstime=0x7f27816b2a00, op=137, expected=0, futex_word=0x6fa3e40) at ./nptl/futex-internal.c:57
#1 __futex_abstimed_wait_common (cancel=true, private=-857358696, abstime=0x7f27816b2a00, clockid=0, expected=0, futex_word=0x6fa3e40) at ./nptl/futex-internal.c:87
#2 __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x6fa3e40, expected=expected@entry=0, clockid=clockid@entry=1, abstime=abstime@entry=0x7f27816b2a00, private=private@entry=0) at ./nptl/futex-internal.c:139
#3 0x00007f281f69435d in __pthread_cond_wait_common (abstime=0x7f27816b2a00, clockid=1, mutex=0x6fa3df0, cond=0x6fa3e18) at ./nptl/pthread_cond_wait.c:503
#4 ___pthread_cond_clockwait64 (abstime=0x7f27816b2a00, clockid=1, mutex=0x6fa3df0, cond=0x6fa3e18) at ./nptl/pthread_cond_wait.c:691
#5 ___pthread_cond_clockwait64 (cond=0x6fa3e18, mutex=0x6fa3df0, clockid=1, abstime=0x7f27816b2a00) at ./nptl/pthread_cond_wait.c:679
#6 0x00007f27c4e31feb in c10d::ProcessGroupNCCL::workCleanupLoop() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#7 0x00007f27c4e324e7 in c10d::ProcessGroupNCCL::ncclCommWatchdog() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8 0x00007f27dc4dc253 in ?? () from /home/ubuntu/anaconda3/envs/dev_pytorch/bin/../lib/libstdc++.so.6
#9 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#10 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 4 (Thread 0x7f2781eb4640 (LWP 455378) "python3"):
#0 futex_wait (private=0, expected=2, futex_word=0x87656a0) at ../sysdeps/nptl/futex-internal.h:146
#1 __GI___lll_lock_wait (futex=futex@entry=0x87656a0, private=0) at ./nptl/lowlevellock.c:49
#2 0x00007f281f698082 in lll_mutex_lock_optimized (mutex=0x87656a0) at ./nptl/pthread_mutex_lock.c:48
#3 ___pthread_mutex_lock (mutex=0x87656a0) at ./nptl/pthread_mutex_lock.c:93
#4 0x00007f27c4e2bdc6 in c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::vector<std::shared_ptrc10d::NCCLComm, std::allocator<std::shared_ptrc10d::NCCLComm > > const&) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#5 0x00007f27c4e2c373 in c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#6 0x00007f27c4e320a8 in c10d::ProcessGroupNCCL::workCleanupLoop() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#7 0x00007f27c4e324e7 in c10d::ProcessGroupNCCL::ncclCommWatchdog() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8 0x00007f27dc4dc253 in ?? () from /home/ubuntu/anaconda3/envs/dev_pytorch/bin/../lib/libstdc++.so.6
#9 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#10 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 3 (Thread 0x7f2788d7a640 (LWP 455372) "cuda-EvtHandlr"):
#0 0x00007f281f718d7f in __GI___poll (fds=0x7f2768000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1 0x00007f27da8b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2 0x00007f27da97a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3 0x00007f27da8b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#5 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 2 (Thread 0x7f278957b640 (LWP 455369) "cuda-EvtHandlr"):
#0 0x00007f281f718d7f in __GI___poll (fds=0x72557f0, nfds=2, timeout=-1) at ../sysdeps/unix/sysv/linux/poll.c:29
#1 0x00007f27da8b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2 0x00007f27da97a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3 0x00007f27da8b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4 0x00007f281f694b43 in start_thread (arg=) at ./nptl/pthread_create.c:442
#5 0x00007f281f726a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 1 (Thread 0x7f281f893740 (LWP 455355) "python3"):
#0 0x00007f27daa1f4ce in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#1 0x00007f27da773f59 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2 0x00007f27dab00d2f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3 0x00007f27dab0191e in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4 0x00007f27da778bbd in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#5 0x00007f27daaff888 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#6 0x00007f27da733604 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#7 0x00007f27da8d7c88 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#8 0x00007f27a2713110 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#9 0x00007f27a2755078 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#10 0x00007f27a268e371 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#11 0x00007f27a26948d0 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#12 0x00007f27a2668e70 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#13 0x00007f27a2651c37 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#14 0x00007f27a265c668 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#15 0x00007f27a265d242 in ?? () from /lib/x86_64-linux-gnu/libnccl.so.2
#16 0x00007f27a265d862 in pncclCommAbort () from /lib/x86_64-linux-gnu/libnccl.so.2
#17 0x00007f27c4e5fc7a in c10d::NCCLComm::ncclCommAbort(c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#18 0x00007f27c4e2ba38 in c10d::abortCommsFromMap(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, std::vector<std::shared_ptrc10d::NCCLComm, std::allocator<std::shared_ptrc10d::NCCLComm > >, std::hash<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits, std::allocator > const, std::vector<std::shared_ptrc10d::NCCLComm, std::allocator<std::shared_ptrc10d::NCCLComm > > > > >&, int, c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#19 0x00007f27c4e2bbf0 in c10d::ProcessGroupNCCL::abort(c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#20 0x00007f27da036dd2 in pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_typec10d::ProcessGroupNCCL > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > > const&)#61}, void, c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_typec10d::ProcessGroupNCCL > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > > const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg_v, pybind11::call_guardpybind11::gil_scoped_release >(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_typec10d::ProcessGroupNCCL > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > > const&)#61}&&, void ()(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_typec10d::ProcessGroupNCCL > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits, std::allocator > > const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::call_guardpybind11::gil_scoped_release const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#21 0x00007f27d97d61ff in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#22 0x00000000005072d7 in cfunction_call (func=0x7f27c3fdd130, args=, kwargs=) at /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543
#23 0x00000000004f06ac in _PyObject_MakeTpCall (tstate=0x2037e20, callable=0x7f27c3fdd130, args=, nargs=, keywords=0x0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:191
#24 0x00000000005051f0 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=, args=0x6f465b0, callable=0x7f27c3fdd130, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116
#25 _PyObject_VectorcallTstate (kwnames=0x0, nargsf=, args=0x6f465b0, callable=0x7f27c3fdd130, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103
#26 method_vectorcall (method=, args=0x6f465b8, nargsf=, kwnames=0x0) at /usr/local/src/conda/python-3.9.17/Objects/classobject.c:53
#27 0x00000000004ec6d4 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=, args=0x6f465b8, callable=0x7f27b33f68c0, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#28 PyObject_Vectorcall (kwnames=0x0, nargsf=, args=0x6f465b8, callable=0x7f27b33f68c0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#29 call_function (kwnames=0x0, oparg=, pp_stack=, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#30 _PyEval_EvalFrameDefault (tstate=, f=0x6f463f0, throwflag=) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489
#31 0x00000000004f8053 in _PyEval_EvalFrame (throwflag=0, f=0x6f463f0, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#32 function_code_fastcall (tstate=0x2037e20, co=, args=, nargs=, globals=0x7f281f52efc0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:330
#33 0x00000000004e7d59 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=, args=0x2095080, callable=0x7f281f5730d0, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#34 PyObject_Vectorcall (kwnames=0x0, nargsf=, args=0x2095080, callable=0x7f281f5730d0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#35 call_function (kwnames=0x0, oparg=, pp_stack=, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#36 _PyEval_EvalFrameDefault (tstate=, f=0x2094f10, throwflag=) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520
#37 0x00000000004e6a8a in _PyEval_EvalFrame (throwflag=0, f=0x2094f10, tstate=0x2037e20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#38 _PyEval_EvalCode (tstate=, _co=, globals=, locals=, args=, argcount=, kwnames=0x0, kwargs=0x0, kwcount=, kwstep=2, defs=0x0, defcount=, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329
#39 0x00000000004e6717 in _PyEval_EvalCodeWithName (_co=, globals=, locals=, args=, argcount=, kwnames=, kwargs=0x0, kwcount=0, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361
#40 0x00000000004e66c9 in PyEval_EvalCodeEx (_co=, globals=, locals=, args=, argcount=, kws=, kwcount=0, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377
#41 0x000000000059398b in PyEval_EvalCode (co=co@entry=0x7f281f045500, globals=globals@entry=0x7f281f52efc0, locals=locals@entry=0x7f281f52efc0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:828
#42 0x00000000005c1217 in run_eval_code_obj (tstate=0x2037e20, co=0x7f281f045500, globals=0x7f281f52efc0, locals=0x7f281f52efc0) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221
#43 0x00000000005bd220 in run_mod (mod=, filename=, globals=0x7f281f52efc0, locals=0x7f281f52efc0, flags=, arena=) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242
#44 0x0000000000456537 in pyrun_file (fp=0x2035340, filename=0x7f281f1648a0, start=, globals=0x7f281f52efc0, locals=0x7f281f52efc0, closeit=1, flags=0x7fff07544eb8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1140
#45 0x00000000005b6f02 in pyrun_simple_file (flags=0x7fff07544eb8, closeit=1, filename=0x7f281f1648a0, fp=0x2035340) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:450
#46 PyRun_SimpleFileExFlags (fp=0x2035340, filename=, closeit=1, flags=0x7fff07544eb8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:483
#47 0x00000000005b447e in pymain_run_file (cf=0x7fff07544eb8, config=0x2038ab0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:379
#48 pymain_run_python (exitcode=0x7fff07544eb0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:604
#49 Py_RunMain () at /usr/local/src/conda/python-3.9.17/Modules/main.c:683
#50 0x0000000000587a39 in Py_BytesMain (argc=, argv=) at /usr/local/src/conda/python-3.9.17/Modules/main.c:1129
#51 0x00007f281f629d90 in __libc_start_call_main (main=main@entry=0x5879f0 <main>, argc=argc@entry=3, argv=argv@entry=0x7fff075450e8) at ../sysdeps/nptl/libc_start_call_main.h:58
#52 0x00007f281f629e40 in __libc_start_main_impl (main=0x5879f0 <main>, argc=3, argv=0x7fff075450e8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fff075450d8) at ../csu/libc-start.c:392
#53 0x00000000005878ee in _start ()

@KaimingOuyang
Collaborator

Hi,
Could you please try https://github.com/NVIDIA/nccl/tree/github-abort-meta, and let me know the results?
Note you must use CUDA version >= 12.2, which supports the new CUMEM host mem feature.
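A quick sanity check (sketch) for the CUDA >= 12.2 requirement and for which NCCL version PyTorch actually picked up:

import torch
print(torch.version.cuda)         # CUDA toolkit version PyTorch was built against
print(torch.cuda.nccl.version())  # NCCL version PyTorch is using, e.g. (2, 19, 3)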

@acphile
Author

acphile commented Nov 7, 2023

Hi Kaiming, thanks for your update, but it does not work for the aforementioned script (e.g. NCCL_CUMEM_ENABLE=1 torchrun xxx). To make sure NCCL is built correctly, I used the following command after cloning your branch:

sudo make -j src.build CUDA_HOME=/usr/local/cuda \
    NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_70,code=sm_70" && \
    sudo make install 

I have checked that the version of /usr/local/cuda is 12.2:

> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jul_11_02:20:44_PDT_2023
Cuda compilation tools, release 12.2, V12.2.128
Build cuda_12.2.r12.2/compiler.33053471_0

@acphile
Author

acphile commented Nov 7, 2023

It looks like your branch is based on NCCL 2.19.3. But after building your branch, python -c "import torch;print(torch.cuda.nccl.version())" still shows 2.18.3. Is there any way to fix that?

@KaimingOuyang
Collaborator

It seems you are not linking against the installed NCCL. Can you try setting LD_LIBRARY_PATH to the path where you installed NCCL?

@acphile
Author

acphile commented Nov 7, 2023

I have exported LD_LIBRARY_PATH as /home/ubuntu/nccl/build/lib but it still does not work. I have checked that /usr/local/lib/ also contains libnccl.so.2.19.3. But still, the output of locate nccl | grep "libnccl.so" | tail -n1 | sed -r 's/^.*\.so\.//' is 2.18.3. Do you have any idea why?

@KaimingOuyang
Collaborator

PyTorch might link NCCL internally. Maybe this post helps: https://discuss.pytorch.org/t/ncc-version-and-pytorch-nccl-version-mismatch/87771
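A diagnostic sketch, assuming a Linux /proc filesystem: list the libnccl shared objects the process has actually mapped and compare them with the version PyTorch reports. An empty set suggests NCCL is bundled inside libtorch rather than picked up from LD_LIBRARY_PATH.

import torch
print(torch.cuda.nccl.version())
with open("/proc/self/maps") as maps:
    print({line.split()[-1] for line in maps if "libnccl" in line})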

@chauhang

chauhang commented Nov 8, 2023

@acphile Please set USE_SYSTEM_NCCL=1 and compile PyTorch from source as well, so that it picks up the latest compiled version of NCCL. See the steps below as an example; you will need to include the build flags appropriate for your setup (detailed build steps are here):

export NCCL_ROOT_DIR=/path/to/nccl
export NCCL_INCLUDE_DIR=$NCCL_ROOT_DIR/include
export NCCL_LIB_DIR=$NCCL_ROOT_DIR/lib
export LIBRARY_PATH=/usr/lib/x86_64-linux-gnu
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$NCCL_ROOT_DIR/lib:$LD_LIBRARY_PATH
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}

USE_SYSTEM_NCCL=1 python setup.py install

@kwen2501
Contributor

kwen2501 commented Nov 8, 2023

Thanks @KaimingOuyang and team for providing a fix! Wondering if there is a planned release for this fix?

And is NCCL_CUMEM_ENABLE=1 needed to enable this fix? Trying to understand if there is any tradeoff in making it the default. Thanks!

@AddyLaddy
Collaborator

Great news!
We're targeting NCCL 2.20.x for this fix. NCCL_CUMEM_ENABLE=1 will be the default in NCCL 2.20.x like it is in NCCL 2.19.3. However, we have found some issues with A2A connection setup in NCCL 2.19.3, and the changes from Kaiming's patch will also be necessary to complete the CUMEM support in NCCL 2.20.x.

There is no performance impact from using NCCL_CUMEM_ENABLE=1

@acphile
Author

acphile commented Nov 8, 2023

Hi, I tried recompiling PyTorch and now python -c "import torch;print(torch.cuda.nccl.version())" shows 2.19.3. But the following script still hangs:

def run_worker(rank, world_size):
    ndev = torch.cuda.device_count()
    device = torch.device(f'cuda:{rank % ndev}')
    torch.cuda.set_device(device)
    dist.init_process_group('nccl', rank=rank, world_size=world_size,
                            timeout=datetime.timedelta(seconds=1800)
                           )
    print(torch.distributed.get_world_size())
    pg = c10d.distributed_c10d._get_default_group()
    pg1 = torch.distributed.new_group(ranks=[2,3], timeout=datetime.timedelta(seconds=1800), backend="nccl")

    device_id = f"cuda:{rank % torch.cuda.device_count()}"
    tensor0 = torch.ones([1]).cuda(rank) * rank
    tensor = torch.ones([1]).cuda(rank) * rank
    tensor1 = torch.ones([2,3]).cuda(rank) * rank
    # for initialization purpose
    torch.distributed.all_reduce(tensor0, group=pg)
    if not isinstance(pg1, int):
        torch.distributed.all_reduce(tensor0, group=pg1)
    time.sleep(10)
    print(tensor0.item())
    # suppose before the following operations, one rank is down.
    if rank==2:
        exit()
    print("One rank is down, other ranks continue")
    if not isinstance(pg1, int):
        torch.distributed.all_reduce(tensor1, group=pg1)
    torch.distributed.all_reduce(tensor0, group=pg)
    print("Suppose until now we find one rank is down, and we want to abort two collective operations")
    pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print("abort")
    if not isinstance(pg1, int):
        pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print(f"end {torch.distributed.get_rank()}")

But when we change the abort order, like

    if not isinstance(pg1, int):
        pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print("abort")
    pg._get_backend(torch.device(torch.cuda.current_device()))._abort()

It can abort successfully. So it looks like the abort order needs to match the order of the NCCL collectives. Do you have a fix for that? @KaimingOuyang
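A hypothetical helper sketch (nothing beyond the _abort() calls already used above): record the order in which process groups issue collectives and abort their NCCL backends in that same order, following the observation that the abort order needs to match the collective order.

issue_order = []

def tracked_all_reduce(t, group):
    if group not in issue_order:
        issue_order.append(group)
    torch.distributed.all_reduce(t, group=group)

def abort_in_issue_order():
    dev = torch.device(torch.cuda.current_device())
    for g in issue_order:
        g._get_backend(dev)._abort()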

@acphile
Author

acphile commented Nov 8, 2023

And for this case

def run_worker(rank, world_size):
    ndev = torch.cuda.device_count()
    device = torch.device(f'cuda:{rank % ndev}')
    torch.cuda.set_device(device)
    dist.init_process_group('nccl', rank=rank, world_size=world_size,
                            timeout=datetime.timedelta(seconds=1800)
                           )
    print(torch.distributed.get_world_size())
    pg = c10d.distributed_c10d._get_default_group()
    pg1 = torch.distributed.new_group(ranks=[2,3], timeout=datetime.timedelta(seconds=1800), backend="nccl")

    device_id = f"cuda:{rank % torch.cuda.device_count()}"
    tensor0 = torch.ones([1]).cuda(rank) * rank
    tensor = torch.ones([1]).cuda(rank) * rank
    tensor1 = torch.ones([2,3]).cuda(rank) * rank
    # for initialization purpose
    torch.distributed.all_reduce(tensor0, group=pg)
    if not isinstance(pg1, int):
        torch.distributed.all_reduce(tensor0, group=pg1)
    time.sleep(10)
    print(tensor0.item())
    # suppose before the following operations, one rank is down.
    if rank==2:
        exit()
    print("One rank is down, other ranks continue")
    if not isinstance(pg1, int):
        torch.distributed.all_reduce(tensor1, group=pg1)
    #torch.distributed.all_reduce(tensor0, group=pg)
    print("Suppose until now we find one rank is down, and we want to abort two collective operations")
    pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print("abort")
    if not isinstance(pg1, int):
        pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
    print(f"end {torch.distributed.get_rank()}")

Only rank 0 and rank 1 print end; the other ranks hang at pg._get_backend(torch.device(torch.cuda.current_device()))._abort().

@KaimingOuyang
Collaborator

No, it should abort successfully. Can you provide me the backtrace of every thread in ranks 1 and 3?

@acphile
Author

acphile commented Nov 8, 2023

No, it should abort successfully. Can you provide me the backtrace of every thread in ranks 1 and 3?

You mean that with your updates the above two cases should abort successfully? And for which case do you want the backtrace?

@KaimingOuyang
Collaborator

The case where you get a hang, i.e.

pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
print("abort")
if not isinstance(pg1, int):
    pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()

@acphile
Author

acphile commented Nov 8, 2023

For the case in #1013 (comment), it is a little weird: initially only rank 0 prints end, but after I attach gdb to the rank 1 process and then exit gdb, the other ranks begin to abort successfully.

@KaimingOuyang
Collaborator

Can you gdb into rank 1 and provide me the backtrace when it hangs?

On the other hand, to make sure it is not an OS issue on your side, please leave your program running for at least 10 seconds after calling abort and see whether all ranks abort successfully (see #992).
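A minimal sketch of this suggestion, applied to the script above: give every rank a grace period after the aborts before concluding that a rank is stuck.

pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
if not isinstance(pg1, int):
    pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
time.sleep(10)  # wait before deciding the abort hangs (see #992)
print(f"end {torch.distributed.get_rank()}")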

@acphile
Author

acphile commented Nov 8, 2023

It is for the case in #1013 (comment).
This is the backtrace of rank 1:

#0  __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=0, futex_word=0x7f97adfff158) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=0, futex_word=0x7f97adfff158) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f97adfff158, expected=expected@entry=0, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3  0x00007f9855893a41 in __pthread_cond_wait_common (abstime=0x0, clockid=0, mutex=0x7f97adfff108, cond=0x7f97adfff130) at ./nptl/pthread_cond_wait.c:503
#4  ___pthread_cond_wait (cond=0x7f97adfff130, mutex=0x7f97adfff108) at ./nptl/pthread_cond_wait.c:627
#5  0x00007f97e4c60f40 in ncclProxyGetPostedOps (added=<synthetic pointer>, proxyState=0x7f97731dacc0) at proxy.cc:721
#6  ncclProxyProgress (proxyState_=<optimized out>) at proxy.cc:877
#7  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#8  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 8 (Thread 0x7f97c0b5c640 (LWP 153573) "python3"):
#0  0x00007f9855918dbf in __GI___poll (fds=fds@entry=0x7f97c0b59800, nfds=nfds@entry=65, timeout=200) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f97e4c616c3 in poll (__timeout=<optimized out>, __nfds=65, __fds=0x7f97c0b59800) at /usr/include/x86_64-linux-gnu/bits/poll2.h:39
#2  ncclProxyService (_args=0x7f97731dacc0) at proxy.cc:1499
#3  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#4  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 7 (Thread 0x7f97c1fff640 (LWP 153565) "cuda-EvtHandlr"):
#0  0x00007f9855918dbf in __GI___poll (fds=0x7f9780000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f9810cb738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f9810d7a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f9810cb047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 6 (Thread 0x7f97c8df0640 (LWP 153562) "cuda-EvtHandlr"):
#0  0x00007f9855918dbf in __GI___poll (fds=0x7f9788000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f9810cb738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f9810d7a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f9810cb047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 5 (Thread 0x7f97c9a92640 (LWP 153560) "cuda-EvtHandlr"):
#0  0x00007f9855918dbf in __GI___poll (fds=0x7f9794000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f9810cb738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f9810d7a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f9810cb047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 4 (Thread 0x7f97ca78d640 (LWP 153555) "python3"):
#0  futex_wait (private=0, expected=2, futex_word=0x865a180) at ../sysdeps/nptl/futex-internal.h:146
#1  __GI___lll_lock_wait (futex=futex@entry=0x865a180, private=0) at ./nptl/lowlevellock.c:49
#2  0x00007f9855898002 in lll_mutex_lock_optimized (mutex=0x865a180) at ./nptl/pthread_mutex_lock.c:48
#3  ___pthread_mutex_lock (mutex=0x865a180) at ./nptl/pthread_mutex_lock.c:93
#4  0x00007f97fb634cf6 in c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > const&) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#5  0x00007f97fb6352a3 in c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#6  0x00007f97fb63af60 in c10d::ProcessGroupNCCL::workCleanupLoop() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#7  0x00007f97fb63b407 in c10d::ProcessGroupNCCL::ncclCommWatchdog() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8  0x00007f98128dc253 in ?? () from /home/ubuntu/anaconda3/envs/dev_pytorch/bin/../lib/libstdc++.so.6
#9  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#10 0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 3 (Thread 0x7f97cb57a640 (LWP 153550) "cuda-EvtHandlr"):
#0  0x00007f9855918dbf in __GI___poll (fds=0x7f97a8000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f9810cb738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f9810d7a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f9810cb047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 2 (Thread 0x7f97cbd7b640 (LWP 153546) "cuda-EvtHandlr"):
#0  0x00007f9855918dbf in __GI___poll (fds=0x71477b0, nfds=2, timeout=-1) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f9810cb738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f9810d7a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f9810cb047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f9855894ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f9855926a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 1 (Thread 0x7f9855b3c740 (LWP 153527) "python3"):
#0  __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=153573, futex_word=0x7f97c0b5c910) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=153573, futex_word=0x7f97c0b5c910) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f97c0b5c910, expected=153573, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3  0x00007f9855896624 in __pthread_clockjoin_ex (threadid=140289749927488, thread_return=0x0, clockid=0, abstime=0x0, block=<optimized out>) at ./nptl/pthread_join_common.c:105
#4  0x00007f97e4c4ff67 in commFree (comm=0xa1d27a0) at init.cc:182
#5  commCleanup (comm=0xa1d27a0) at init.cc:1828
#6  commReclaim (comm=comm@entry=0xa1d27a0) at init.cc:1959
#7  0x00007f97e4c5509a in ncclCommAbort (comm=0xa1d27a0) at init.cc:2027
#8  0x00007f97fb668bba in c10d::NCCLComm::ncclCommAbort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#9  0x00007f97fb634968 in c10d::abortCommsFromMap(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > > > >&, int, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#10 0x00007f97fb634b20 in c10d::ProcessGroupNCCL::abort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#11 0x00007f98104402a2 in pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}, void, c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}&&, void (*)(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#12 0x00007f980fbd9b2f in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#13 0x00000000005072d7 in cfunction_call (func=0x7f980431c900, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543
#14 0x00000000004f06ac in _PyObject_MakeTpCall (tstate=0x1eb7e20, callable=0x7f980431c900, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:191
#15 0x00000000005051f0 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6e38a10, callable=0x7f980431c900, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116
#16 _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6e38a10, callable=0x7f980431c900, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103
#17 method_vectorcall (method=<optimized out>, args=0x6e38a18, nargsf=<optimized out>, kwnames=0x0) at /usr/local/src/conda/python-3.9.17/Objects/classobject.c:53
#18 0x00000000004ec6d4 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6e38a18, callable=0x7f97d4ca27c0, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#19 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x6e38a18, callable=0x7f97d4ca27c0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#20 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#21 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0x6e38850, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489
#22 0x00000000004f8053 in _PyEval_EvalFrame (throwflag=0, f=0x6e38850, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#23 function_code_fastcall (tstate=0x1eb7e20, co=<optimized out>, args=<optimized out>, nargs=<optimized out>, globals=0x7f98557b8040) at /usr/local/src/conda/python-3.9.17/Objects/call.c:330
#24 0x00000000004e7d59 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x1f15080, callable=0x7f98557f30d0, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#25 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x1f15080, callable=0x7f98557f30d0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#26 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#27 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0x1f14f10, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520
#28 0x00000000004e6a8a in _PyEval_EvalFrame (throwflag=0, f=0x1f14f10, tstate=0x1eb7e20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#29 _PyEval_EvalCode (tstate=<optimized out>, _co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=<optimized out>, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329
#30 0x00000000004e6717 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x0, kwcount=0, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361
#31 0x00000000004e66c9 in PyEval_EvalCodeEx (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kws=<optimized out>, kwcount=0, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377
#32 0x000000000059398b in PyEval_EvalCode (co=co@entry=0x7f98552ea500, globals=globals@entry=0x7f98557b8040, locals=locals@entry=0x7f98557b8040) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:828
#33 0x00000000005c1217 in run_eval_code_obj (tstate=0x1eb7e20, co=0x7f98552ea500, globals=0x7f98557b8040, locals=0x7f98557b8040) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221
#34 0x00000000005bd220 in run_mod (mod=<optimized out>, filename=<optimized out>, globals=0x7f98557b8040, locals=0x7f98557b8040, flags=<optimized out>, arena=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242
#35 0x0000000000456537 in pyrun_file (fp=0x1eb5340, filename=0x7f98557098a0, start=<optimized out>, globals=0x7f98557b8040, locals=0x7f98557b8040, closeit=1, flags=0x7ffe759226f8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1140
#36 0x00000000005b6f02 in pyrun_simple_file (flags=0x7ffe759226f8, closeit=1, filename=0x7f98557098a0, fp=0x1eb5340) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:450
#37 PyRun_SimpleFileExFlags (fp=0x1eb5340, filename=<optimized out>, closeit=1, flags=0x7ffe759226f8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:483
#38 0x00000000005b447e in pymain_run_file (cf=0x7ffe759226f8, config=0x1eb8ab0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:379
#39 pymain_run_python (exitcode=0x7ffe759226f0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:604
#40 Py_RunMain () at /usr/local/src/conda/python-3.9.17/Modules/main.c:683
#41 0x0000000000587a39 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /usr/local/src/conda/python-3.9.17/Modules/main.c:1129
#42 0x00007f9855829d90 in __libc_start_call_main (main=main@entry=0x5879f0 <main>, argc=argc@entry=3, argv=argv@entry=0x7ffe75922928) at ../sysdeps/nptl/libc_start_call_main.h:58
#43 0x00007f9855829e40 in __libc_start_main_impl (main=0x5879f0 <main>, argc=3, argv=0x7ffe75922928, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7ffe75922918) at ../csu/libc-start.c:392
#44 0x00000000005878ee in _start ()

And for rank 3:

Thread 12 (Thread 0x7f68d9fff640 (LWP 153586) "python3"):
#0  __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=0, futex_word=0x7f68f0957158) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=0, futex_word=0x7f68f0957158) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f68f0957158, expected=expected@entry=0, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3  0x00007f69a0093a41 in __pthread_cond_wait_common (abstime=0x0, clockid=0, mutex=0x7f68f0957108, cond=0x7f68f0957130) at ./nptl/pthread_cond_wait.c:503
#4  ___pthread_cond_wait (cond=0x7f68f0957130, mutex=0x7f68f0957108) at ./nptl/pthread_cond_wait.c:627
#5  0x00007f692f260f40 in ncclProxyGetPostedOps (added=<synthetic pointer>, proxyState=0x7f68bb1da020) at proxy.cc:721
#6  ncclProxyProgress (proxyState_=<optimized out>) at proxy.cc:877
#7  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#8  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 11 (Thread 0x7f68f1158640 (LWP 153585) "python3"):
#0  0x00007f69a0118dbf in __GI___poll (fds=fds@entry=0x7f68f1155800, nfds=nfds@entry=65, timeout=200) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f692f2616c3 in poll (__timeout=<optimized out>, __nfds=65, __fds=0x7f68f1155800) at /usr/include/x86_64-linux-gnu/bits/poll2.h:39
#2  ncclProxyService (_args=0x7f68bb1da020) at proxy.cc:1499
#3  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#4  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 10 (Thread 0x7f68f1d5a640 (LWP 153577) "python3"):
#0  __futex_abstimed_wait_common64 (private=128, cancel=true, abstime=0x0, op=265, expected=0, futex_word=0x7f68f215b158) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=128, abstime=0x0, clockid=0, expected=0, futex_word=0x7f68f215b158) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x7f68f215b158, expected=expected@entry=0, clockid=clockid@entry=0, abstime=abstime@entry=0x0, private=private@entry=128) at ./nptl/futex-internal.c:139
#3  0x00007f69a0093a41 in __pthread_cond_wait_common (abstime=0x0, clockid=0, mutex=0x7f68f215b108, cond=0x7f68f215b130) at ./nptl/pthread_cond_wait.c:503
#4  ___pthread_cond_wait (cond=0x7f68f215b130, mutex=0x7f68f215b108) at ./nptl/pthread_cond_wait.c:627
#5  0x00007f692f260f40 in ncclProxyGetPostedOps (added=<synthetic pointer>, proxyState=0x7f68bb1dacc0) at proxy.cc:721
#6  ncclProxyProgress (proxyState_=<optimized out>) at proxy.cc:877
#7  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#8  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 9 (Thread 0x7f68f295c640 (LWP 153576) "python3"):
#0  0x00007f69a0118dbf in __GI___poll (fds=fds@entry=0x7f68f2959800, nfds=nfds@entry=65, timeout=200) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f692f2616c3 in poll (__timeout=<optimized out>, __nfds=65, __fds=0x7f68f2959800) at /usr/include/x86_64-linux-gnu/bits/poll2.h:39
#2  ncclProxyService (_args=0x7f68bb1dacc0) at proxy.cc:1499
#3  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#4  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 8 (Thread 0x7f68f3dff640 (LWP 153568) "cuda-EvtHandlr"):
#0  0x00007f69a0118dbf in __GI___poll (fds=0x7f68c8000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f695b2b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f695b37a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f695b2b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 7 (Thread 0x7f6910b5c640 (LWP 153563) "cuda-EvtHandlr"):
#0  0x00007f69a0118dbf in __GI___poll (fds=0x7f68d4000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f695b2b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f695b37a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f695b2b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 6 (Thread 0x7f69117fe640 (LWP 153559) "cuda-EvtHandlr"):
#0  0x00007f69a0118dbf in __GI___poll (fds=0x7f68e8000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f695b2b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f695b37a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f695b2b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 5 (Thread 0x7f6911fff640 (LWP 153554) "python3"):
#0  __futex_abstimed_wait_common64 (private=-1621048394, cancel=true, abstime=0x7f6911ffea00, op=137, expected=0, futex_word=0x6b198d8) at ./nptl/futex-internal.c:57
#1  __futex_abstimed_wait_common (cancel=true, private=-1621048394, abstime=0x7f6911ffea00, clockid=0, expected=0, futex_word=0x6b198d8) at ./nptl/futex-internal.c:87
#2  __GI___futex_abstimed_wait_cancelable64 (futex_word=futex_word@entry=0x6b198d8, expected=expected@entry=0, clockid=clockid@entry=1, abstime=abstime@entry=0x7f6911ffea00, private=private@entry=0) at ./nptl/futex-internal.c:139
#3  0x00007f69a00942dd in __pthread_cond_wait_common (abstime=0x7f6911ffea00, clockid=1, mutex=0x6b19888, cond=0x6b198b0) at ./nptl/pthread_cond_wait.c:503
#4  ___pthread_cond_clockwait64 (abstime=0x7f6911ffea00, clockid=1, mutex=0x6b19888, cond=0x6b198b0) at ./nptl/pthread_cond_wait.c:691
#5  ___pthread_cond_clockwait64 (cond=0x6b198b0, mutex=0x6b19888, clockid=1, abstime=0x7f6911ffea00) at ./nptl/pthread_cond_wait.c:679
#6  0x00007f6945c3ae9b in c10d::ProcessGroupNCCL::workCleanupLoop() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#7  0x00007f6945c3b407 in c10d::ProcessGroupNCCL::ncclCommWatchdog() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8  0x00007f695cedc253 in ?? () from /home/ubuntu/anaconda3/envs/dev_pytorch/bin/../lib/libstdc++.so.6
#9  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#10 0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 4 (Thread 0x7f6914d8d640 (LWP 153553) "python3"):
#0  futex_wait (private=0, expected=2, futex_word=0x87eee60) at ../sysdeps/nptl/futex-internal.h:146
#1  __GI___lll_lock_wait (futex=futex@entry=0x87eee60, private=0) at ./nptl/lowlevellock.c:49
#2  0x00007f69a0098002 in lll_mutex_lock_optimized (mutex=0x87eee60) at ./nptl/pthread_mutex_lock.c:48
#3  ___pthread_mutex_lock (mutex=0x87eee60) at ./nptl/pthread_mutex_lock.c:93
#4  0x00007f6945c34cf6 in c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > const&) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#5  0x00007f6945c352a3 in c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#6  0x00007f6945c3af60 in c10d::ProcessGroupNCCL::workCleanupLoop() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#7  0x00007f6945c3b407 in c10d::ProcessGroupNCCL::ncclCommWatchdog() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8  0x00007f695cedc253 in ?? () from /home/ubuntu/anaconda3/envs/dev_pytorch/bin/../lib/libstdc++.so.6
#9  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#10 0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 3 (Thread 0x7f6915b7a640 (LWP 153547) "cuda-EvtHandlr"):
#0  0x00007f69a0118dbf in __GI___poll (fds=0x7f68fc000c20, nfds=11, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f695b2b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f695b37a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f695b2b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 2 (Thread 0x7f691637b640 (LWP 153543) "cuda-EvtHandlr"):
#0  0x00007f69a0118dbf in __GI___poll (fds=0x72db800, nfds=2, timeout=-1) at ../sysdeps/unix/sysv/linux/poll.c:29
#1  0x00007f695b2b738f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f695b37a03f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f695b2b047f in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f69a0094ac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#5  0x00007f69a0126a40 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

Thread 1 (Thread 0x7f69a027b740 (LWP 153529) "python3"):
#0  0x00007ffe23de46e8 in ?? ()
#1  0x00007ffe23de484a in ?? ()
#2  0x00007f69a00e566d in __GI___clock_gettime (clock_id=<optimized out>, tp=<optimized out>) at ../sysdeps/unix/sysv/linux/clock_gettime.c:42
#3  0x00007f695b2ae944 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f695b178e2e in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00007f695b2929eb in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#6  0x00007f695b4e23e5 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#7  0x00007f695b32b5bd in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#8  0x00007f692f2cd885 in libcudart_static_141dba5462e92d2cffd1abc474df476c510a3a8c () from /home/ubuntu/nccl/build/lib/libnccl.so.2
#9  0x00007f692f32ecf8 in cudaStreamSynchronize () from /home/ubuntu/nccl/build/lib/libnccl.so.2
#10 0x00007f692f2978e6 in ncclStrongStreamSynchronize (ss=0x7f68bb0291d0) at misc/strongstream.cc:398
#11 0x00007f692f24f323 in commDestroySync (job_=job_@entry=0x7ffe23d38520) at init.cc:1795
#12 0x00007f692f24f7be in commReclaim (comm=comm@entry=0xa366e60) at init.cc:1924
#13 0x00007f692f25509a in ncclCommAbort (comm=0xa366e60) at init.cc:2027
#14 0x00007f6945c68bba in c10d::NCCLComm::ncclCommAbort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#15 0x00007f6945c34968 in c10d::abortCommsFromMap(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > > > >&, int, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#16 0x00007f6945c34b20 in c10d::ProcessGroupNCCL::abort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#17 0x00007f695aa402a2 in pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}, void, c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&)#61}&&, void (*)(c10::intrusive_ptr<c10d::ProcessGroupNCCL, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupNCCL> > const&, c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#18 0x00007f695a1d9b2f in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#19 0x00000000005072d7 in cfunction_call (func=0x7f6959d1c860, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543
#20 0x00000000004f06ac in _PyObject_MakeTpCall (tstate=0x204be20, callable=0x7f6959d1c860, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /usr/local/src/conda/python-3.9.17/Objects/call.c:191
#21 0x00000000005051f0 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6fcc580, callable=0x7f6959d1c860, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116
#22 _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6fcc580, callable=0x7f6959d1c860, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103
#23 method_vectorcall (method=<optimized out>, args=0x6fcc588, nargsf=<optimized out>, kwnames=0x0) at /usr/local/src/conda/python-3.9.17/Objects/classobject.c:53
#24 0x00000000004ec6d4 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x6fcc588, callable=0x7f6944cec580, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#25 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x6fcc588, callable=0x7f6944cec580) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#26 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#27 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0x6fcc3c0, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489
#28 0x00000000004f8053 in _PyEval_EvalFrame (throwflag=0, f=0x6fcc3c0, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#29 function_code_fastcall (tstate=0x204be20, co=<optimized out>, args=<optimized out>, nargs=<optimized out>, globals=0x7f699ff21040) at /usr/local/src/conda/python-3.9.17/Objects/call.c:330
#30 0x00000000004e7d59 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x20a9080, callable=0x7f699ff5c0d0, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118
#31 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x20a9080, callable=0x7f699ff5c0d0) at /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127
#32 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077
#33 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=0x20a8f10, throwflag=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520
#34 0x00000000004e6a8a in _PyEval_EvalFrame (throwflag=0, f=0x20a8f10, tstate=0x204be20) at /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40
#35 _PyEval_EvalCode (tstate=<optimized out>, _co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=<optimized out>, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329
#36 0x00000000004e6717 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x0, kwcount=0, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361
#37 0x00000000004e66c9 in PyEval_EvalCodeEx (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kws=<optimized out>, kwcount=0, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377
#38 0x000000000059398b in PyEval_EvalCode (co=co@entry=0x7f699fa45500, globals=globals@entry=0x7f699ff21040, locals=locals@entry=0x7f699ff21040) at /usr/local/src/conda/python-3.9.17/Python/ceval.c:828
#39 0x00000000005c1217 in run_eval_code_obj (tstate=0x204be20, co=0x7f699fa45500, globals=0x7f699ff21040, locals=0x7f699ff21040) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221
#40 0x00000000005bd220 in run_mod (mod=<optimized out>, filename=<optimized out>, globals=0x7f699ff21040, locals=0x7f699ff21040, flags=<optimized out>, arena=<optimized out>) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242
#41 0x0000000000456537 in pyrun_file (fp=0x2049340, filename=0x7f699fb66850, start=<optimized out>, globals=0x7f699ff21040, locals=0x7f699ff21040, closeit=1, flags=0x7ffe23d395c8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1140
#42 0x00000000005b6f02 in pyrun_simple_file (flags=0x7ffe23d395c8, closeit=1, filename=0x7f699fb66850, fp=0x2049340) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:450
#43 PyRun_SimpleFileExFlags (fp=0x2049340, filename=<optimized out>, closeit=1, flags=0x7ffe23d395c8) at /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:483
#44 0x00000000005b447e in pymain_run_file (cf=0x7ffe23d395c8, config=0x204cab0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:379
#45 pymain_run_python (exitcode=0x7ffe23d395c0) at /usr/local/src/conda/python-3.9.17/Modules/main.c:604
#46 Py_RunMain () at /usr/local/src/conda/python-3.9.17/Modules/main.c:683
#47 0x0000000000587a39 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /usr/local/src/conda/python-3.9.17/Modules/main.c:1129
#48 0x00007f69a0029d90 in __libc_start_call_main (main=main@entry=0x5879f0 <main>, argc=argc@entry=3, argv=argv@entry=0x7ffe23d397f8) at ../sysdeps/nptl/libc_start_call_main.h:58
#49 0x00007f69a0029e40 in __libc_start_main_impl (main=0x5879f0 <main>, argc=3, argv=0x7ffe23d397f8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7ffe23d397e8) at ../csu/libc-start.c:392
#50 0x00000000005878ee in _start ()

@wconstab
Copy link

Hmm no idea about the gdb part.

@KaimingOuyang
Copy link
Collaborator

@acphile I recall that exit() does not call abort for the pg, based on my investigation (#1013 (comment)). Could you please verify that all ranks have called abort?

@acphile
Copy link
Author

acphile commented Jan 17, 2024

The rank that calls exit() (rank 2) does indeed enter abort():

#0  0x00007f454aae57f8 in __GI___clock_nanosleep (clock_id=clock_id@entry=0, flags=flags@entry=0, 
    req=req@entry=0x7fffe3f0d330, rem=rem@entry=0x0) at ../sysdeps/unix/sysv/linux/clock_nanosleep.c:78
#1  0x00007f454aaea677 in __GI___nanosleep (req=req@entry=0x7fffe3f0d330, rem=rem@entry=0x0)
    at ../sysdeps/unix/sysv/linux/nanosleep.c:25
#2  0x00007f454ab1bf2f in usleep (useconds=<optimized out>) at ../sysdeps/posix/usleep.c:31
#3  0x00007f44d9e48d6a in groupLaunch (job_=<optimized out>) at group.cc:335
#4  0x00007f44d9e49a58 in ncclGroupEndInternal () at group.cc:421
#5  ncclGroupEndInternal () at group.cc:371
#6  0x00007f44d9e4a1cc in ncclGroupEnd () at group.cc:98
#7  0x00007f44f0c2a104 in c10d::ProcessGroupNCCL::abort(c10::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#8  0x00007f44f0c2de39 in c10d::ProcessGroupNCCL::~ProcessGroupNCCL() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#9  0x00007f44f0c2e30d in c10d::ProcessGroupNCCL::~ProcessGroupNCCL() () from /home/ubuntu/pytorch/torch/lib/libtorch_cuda.so
#10 0x00007f44fe855df8 in c10d::ProcessGroup::~ProcessGroup() () from /home/ubuntu/pytorch/torch/lib/libtorch_cpu.so
#11 0x00007f44fe855f1d in c10d::ProcessGroup::~ProcessGroup() () from /home/ubuntu/pytorch/torch/lib/libtorch_cpu.so
#12 0x00007f4505e0ff60 in pybind11::class_<c10d::ProcessGroup, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> >, c10d::PyProcessGroup>::dealloc(pybind11::detail::value_and_holder&) ()
   from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#13 0x00007f45055c16be in pybind11::detail::clear_instance(_object*) () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#14 0x00007f45055c25f5 in pybind11_object_dealloc () from /home/ubuntu/pytorch/torch/lib/libtorch_python.so
#15 0x00000000004e4277 in _Py_Dealloc (op=<optimized out>) at /usr/local/src/conda/python-3.9.17/Objects/object.c:2209

@KaimingOuyang
Copy link
Collaborator

If that's the case, what does

pg._get_backend(torch.device(torch.cuda.current_device()))._abort()

translate to?

Since you are using the master thread to abort everything, we need all ranks to call abort like:

ncclGroupStart();
pg._get_backend(torch.device(torch.cuda.current_device()))._abort()
print("abort")
if not isinstance(pg1, int):
    pg1._get_backend(torch.device(torch.cuda.current_device()))._abort()
ncclGroupEnd();

Can your implementation guarantee that pg and pg1 are aborted in the same group on all ranks?
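
For reference, a minimal sketch (not from this thread) of that pattern at the raw NCCL API level; comm0 and comm1 are placeholders for the communicators backing pg and pg1 on the current rank, and, as noted later in this thread, the concurrent "group abort" semantic is only officially supported starting with NCCL 2.22:

#include <nccl.h>

// comm0/comm1: placeholders for the communicators behind pg and pg1 on this rank.
void abort_both(ncclComm_t comm0, ncclComm_t comm1) {
  ncclGroupStart();
  ncclCommAbort(comm0);   // abort the communicator behind pg
  ncclCommAbort(comm1);   // abort the communicator behind pg1
  ncclGroupEnd();
}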

@acphile
Copy link
Author

acphile commented Jan 17, 2024

pg._get_backend(torch.device(torch.cuda.current_device()))._abort() should enter the implementation at https://github.com/pytorch/pytorch/blob/32f93b1c689954aa55a057061154c094a3eecb6f/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L817, which I modified as follows:

void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
  std::lock_guard<std::mutex> lock(mutex_);
  // Abort all communicators owned by this process group inside one NCCL group.
  ncclGroupStart();
  abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
  abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
  ncclGroupEnd();
}

Does it meet your requirement?

Or does it require that, for every rank, all the communicators from the different process groups be aborted in a single group block, like:

ncclGroupStart();
abort processgroup1 communicators
abort processgroup2 communicators
...
ncclGroupEnd();

If so, @wconstab do you have an API to do that?

@KaimingOuyang
Copy link
Collaborator

That means each group only covers one process group's abort (here, just pg), which causes the problem.
You have to put the pg and pg1 aborts in the same group operation; in your case this is required on every rank.

@wconstab
Copy link

We might need to add a 'global abort' API to ProcessGroup (it would be class-wide and do a start/end group around aborting all the comms for all the PGs). We should open an issue for this (I think @kwen2501 was opening one) so we can discuss API specifics.

@minsii
Copy link

minsii commented Feb 13, 2024

Hi, since all the commAbort-related discussion is here, I'm adding one more question for @KaimingOuyang regarding an out-of-order commAbort hang, seen with NCCL 2.17.1/2.18.3/2.19.4 (but fixed in NCCL 2.19.3; see more notes below).

The hanging program looks like this:

created pg0 -> created nccl comm0
created pg1 -> created nccl comm1

if (rank % 2) { // here we repro with odd/even ranks aborting in different order, in real workloads it is arbitrary on each rank
   abort(comm0)
   abort(comm1)
} else {
   abort(comm1)
   abort(comm0)
}
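
For concreteness, the same pattern written against the actual NCCL calls looks like the sketch below; comm0, comm1, and rank are placeholders (assuming both communicators were created earlier with ncclCommInitRank on the same set of ranks), and on the affected versions each rank is expected to hang inside its first ncclCommAbort:

#include <nccl.h>

// comm0/comm1 and rank are placeholders; both comms are assumed to have been
// created earlier with ncclCommInitRank on the same set of ranks.
void abort_out_of_order(int rank, ncclComm_t comm0, ncclComm_t comm1) {
  if (rank % 2) {
    ncclCommAbort(comm0);   // odd ranks hang here on the affected versions ...
    ncclCommAbort(comm1);
  } else {
    ncclCommAbort(comm1);   // ... while even ranks are stuck here
    ncclCommAbort(comm0);
  }
}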

After live debugging, we confirmed that all ranks hang in commAbort->commFree->pthread_join(service thread) (code pointer for NCCL 2.18.3, same for the other failing versions: https://github.com/NVIDIA/nccl/blob/v2.18.3-1/src/init.cc#L180), because the service thread (one thread per comm) is waiting for the other rank to close its socket, but that rank is still in the abort of the other comm.

We found that this case doesn't hang in NCCL 2.19.3 because it explicitly calls ncclProxyTryDetach rather than pthread_join in abort, which pthread_detaches the service thread after 5 seconds whether or not all connections are closed (code pointer for ncclProxyTryDetach: https://github.com/NVIDIA/nccl/blob/v2.19.3-1/src/proxy.cc#L1702).

In the later releases 2.19.4/2.20.3, however, ncclProxyTryDetach was reverted, so the above case hangs again.

I was wondering, is there any reason why we revert ncclProxyTryDetach?

@KaimingOuyang
Copy link
Collaborator

Thanks Min for digging out the root cause!
It all makes sense to me now. ncclProxyTryDetach was an attempt to solve the abort hang issue. However, we found that if the main thread exits after detaching, the CUDA driver can start to reclaim all memory while the proxy thread might still be running, which would cause a segmentation fault. So we reverted the feature in 2.19.4.

@minsii
Copy link

minsii commented Feb 13, 2024

Makes sense. Thanks for the explanation, Kaiming. So it looks like we cannot count 2.19.3 as a proper fix :-(

@shuqiangzhang
Copy link

We recently had some cases where multiple ranks on the same host experienced 'cuda failure out of memory' and ncclCommAbort hung forever on that host. Some logs from NCCL: init.cc:1908 NCCL WARN Cuda failure 'out of memory'; proxy.cc:1736 NCCL WARN Cuda failure 'out of memory'. Any suggestions on how to avoid the hang?

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue May 30, 2024
Imported from GitHub PR openxla/xla#13109

This introduces a flag for termination on NCCL async error. With the flag on, XLA will terminate the process on NCCL error. With the flag off, the existing behavior should remain unchanged.

The patch is motivated by several problems:

- Without this patch, the heartbeat monitor only checks communicators that are currently not in use by the running executable (because it obtains the communicators with TryAcquire). Since NCCL errors cause a hang in the running communicator, most failing communicators are locked, so their async errors just go undetected. As a result, XLA often hangs until Grpc timeout even in cases when ncclCommGetAsyncError would report an error.

- Ideally we would recover by aborting the faulty communicators, but that seems to be unreliable (aborts can cause hangs if NCCL currently hangs on a different communicator than the one being aborted). NCCL team is aware of this and working on a fix (NVIDIA/nccl#1013). At the moment, there does not seem to be a reliable fast recovery mechanism short of process termination.

We propose to expose a flag for terminating the process on failure so that there is some way to detect and recover from a NCCL failure. Once the comm-abort works reliably, we will use that and propagate the error to the API user.

The patch is based on a PoC from [email protected] and [email protected].
Copybara import of the project:

--
87bea4695582041f6efae5322185482e934b79b8 by Jaroslav Sevcik <[email protected]>:

Add flag for termination on nccl error

--
96532e4462828f0de86664dffec898bbc78859af by Jaroslav Sevcik <[email protected]>:

Comment, better name for the checking method

Merging this change closes #13109

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13109 from jaro-sevcik:terminate-on-nccl-error 96532e4462828f0de86664dffec898bbc78859af
PiperOrigin-RevId: 638198800
copybara-service bot pushed a commit to openxla/xla that referenced this issue Jun 4, 2024
Imported from GitHub PR #13109

This introduces a flag for termination on NCCL async error. With the flag on, XLA will terminate the process on NCCL error. With the flag off, the existing behavior should remain unchanged.

The patch is motivated by several problems:

- Without this patch, the heartbeat monitor only checks communicators that are currently not in use by the running executable (because it obtains the communicators with TryAcquire). Since NCCL errors cause a hang in the running communicator, most failing communicators are locked, so their async errors just go undetected. As a result, XLA often hangs until Grpc timeout even in cases when ncclCommGetAsyncError would report an error.

- Ideally we would recover by aborting the faulty communicators, but that seems to be unreliable (aborts can cause hangs if NCCL currently hangs on a different communicator than the one being aborted). NCCL team is aware of this and working on a fix (NVIDIA/nccl#1013). At the moment, there does not seem to be a reliable fast recovery mechanism short of process termination.

We propose to expose a flag for terminating the process on failure so that there is some way to detect and recover from a NCCL failure. Once the comm-abort works reliably, we will use that and propagate the error to the API user.

The patch is based on a PoC from [email protected] and [email protected].
Copybara import of the project:

--
858aeac by Jaroslav Sevcik <[email protected]>:

Add flag for termination on nccl error

Merging this change closes #13109

COPYBARA_INTEGRATE_REVIEW=#13109 from jaro-sevcik:terminate-on-nccl-error 858aeac
PiperOrigin-RevId: 640085317
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Oct 11, 2024
Thanks @eqy for reminding me of this RFC: #119797

This PR is meant to:
- provide a way to abort multiple PGs without deadlocking each other.
- provide a possibility to manually handle comm errors or timeouts (and potentially recovery of such).
One can find an example from: NVIDIA/nccl#1013

## How is it different from `destroy_process_group`?
`destroy_process_group` is meant for normal exit, while `_abort_process_group` is meant for bailout upon hangs or failures. Similar to `ncclCommDestroy` vs `ncclCommAbort`.

## What's new in `_abort_process_group`?
It added support for "group abort" semantic. The "group abort" semantic is capable of aborting multiple NCCL comms concurrently, avoiding deadlock in otherwise serialized `ncclCommAbort` executions. Details are in the [RFC](#119797) targeting [the hang issue in multi-comm case](NVIDIA/nccl#1013). `Group abort` semantic is added in NCCL 2.22.

## What's next?
Ideally, the watchdog's behavior should support "group abort" too. But this is hard to implement today due to a lack of "global view" by each PG's individual watchdog. A semi-big refactor may be needed to "uplift" the watchdogs to a global level or consolidate them into one (i.e. one dog watching multiple PGs).

In any case, it may not be a bad idea to experiment the "group abort" feature with a manual API first and then extend to the automatic mode (watchdog).

Pull Request resolved: #132291
Approved by: https://github.com/eqy