
NCCL hang when invoking Reduce and ncclSend/Recv concurrently #1192

Open
weberxie opened this issue Feb 23, 2024 · 4 comments

Comments

@weberxie

weberxie commented Feb 23, 2024

The code is:

import os
import argparse
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

# Environment variables set by torch.distributed.launch
LOCAL_RANK = int(os.environ['LOCAL_RANK'])
WORLD_SIZE = int(os.environ['WORLD_SIZE'])
WORLD_RANK = int(os.environ['RANK'])

torch.cuda.set_device(LOCAL_RANK)

def run(backend):
    print(f"world size: {WORLD_SIZE}, local_rank:{LOCAL_RANK}, rank: {WORLD_RANK}")

    bucket_size = 100 * WORLD_SIZE
    send_tensor = torch.ones(bucket_size, dtype=torch.float32, device=torch.cuda.current_device()) * LOCAL_RANK
    reduce_tensor = torch.ones(1, dtype=torch.float32, device=torch.cuda.current_device()) * (LOCAL_RANK + 1)

    if LOCAL_RANK == 0:
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
        send_handle = torch.distributed.isend(tensor=send_tensor, dst=1)
    elif LOCAL_RANK == (WORLD_SIZE - 1):
        recv_handle = torch.distributed.irecv(tensor=send_tensor, src=(WORLD_SIZE - 2))
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
    else:
        recv_handle = torch.distributed.irecv(tensor=send_tensor, src=LOCAL_RANK - 1)
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
        send_handle = torch.distributed.isend(tensor=send_tensor, dst=LOCAL_RANK + 1)

    #torch.cuda.synchronize()
    if LOCAL_RANK == 1:
        print(reduce_tensor)
        # print(recv_tensor)

def init_processes(backend):
    dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
    run(backend)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--backend", type=str, default="nccl", choices=['nccl', 'gloo'])
    args = parser.parse_args()
    init_processes(backend=args.backend)
 

and the launch cmd is:

CUDA_LAUNCH_BLOCKING=0 python3 -m torch.distributed.launch  --nproc_per_node=2 test.py
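
For what it's worth, the same two-process run can also be launched with the newer torchrun entry point (assuming a PyTorch version where torchrun sets the LOCAL_RANK, RANK and WORLD_SIZE environment variables that this script reads):

CUDA_LAUNCH_BLOCKING=0 torchrun --nproc_per_node=2 test.py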

The stack trace of gdb is:

Thread 1 (Thread 0x7f3f9a8ae7c0 (LWP 174454)):
#0  groupLaunch (job_=0x20ccd010) at group.cc:310
#1  0x00007f3f2e4131a8 in ncclGroupEndInternal () at group.cc:421
#2  ncclGroupEndInternal () at group.cc:376
#3  0x00007f3f2e413808 in ncclGroupEnd () at group.cc:106
#4  0x00007f3f62b873a5 in c10d::ProcessGroupNCCL::getNCCLComm(std::string const&, std::vector<c10::Device, std::allocator<c10::Device> > const&, c10d::OpType, int, bool) () from /usr/local/conda/lib/python3.9/site-packages/torch/lib/libtorch_cuda_cpp.so
#5  0x00007f3f62b94af2 in c10d::ProcessGroupNCCL::reduce(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&) () from /usr/local/conda/lib/python3.9/site-packages/torch/lib/libtorch_cuda_cpp.so

My understanding is that if there are enough SM resources for the two communication kernels to run concurrently, the program will not hang. Is that correct?

Any reply will be appreciated, thanks.

@sjeaugey
Member

send_handle = torch.distributed.irecv

This should probably be an isend.
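
i.e. a minimal sketch of the intended line (assuming the send goes to the next rank, as in the rank-0 branch):

send_handle = torch.distributed.isend(tensor=send_tensor, dst=LOCAL_RANK + 1)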

@weberxie
Author

@sjeaugey Thanks for your reply. I have updated the code, but I ran this program with only two GPUs, so it never enters the else branch, and the hang issue still exists.

@weberxie
Author

Update:

After replacing the individual P2P ops with batched P2P ops (batch_isend_irecv), the hang issue was resolved.

The code is:

import os
import argparse
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

# Environment variables set by torch.distributed.launch
LOCAL_RANK = int(os.environ['LOCAL_RANK'])
WORLD_SIZE = int(os.environ['WORLD_SIZE'])
WORLD_RANK = int(os.environ['RANK'])

torch.cuda.set_device(LOCAL_RANK)

def run(backend):
    print(f"world size: {WORLD_SIZE}, local_rank:{LOCAL_RANK}, rank: {WORLD_RANK}")

    bucket_size = 128 * WORLD_SIZE
    send_tensor = torch.ones(bucket_size, dtype=torch.float16, device=torch.cuda.current_device()) * LOCAL_RANK
    reduce_tensor = torch.ones(1, dtype=torch.float16, device=torch.cuda.current_device()) * (LOCAL_RANK + 1)

    torch.cuda.synchronize()
    if LOCAL_RANK == 0:
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
        p2p_handle = [torch.distributed.P2POp(torch.distributed.isend, send_tensor, 1)]
    elif LOCAL_RANK == (WORLD_SIZE - 1):
        p2p_handle = [torch.distributed.P2POp(torch.distributed.irecv, send_tensor, (WORLD_SIZE - 2))]
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
    else:
        p2p_handle = [torch.distributed.P2POp(torch.distributed.irecv, send_tensor, LOCAL_RANK - 1)]
        reduce_handle = torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)
        p2p_handle.append(torch.distributed.P2POp(torch.distributed.isend, send_tensor, LOCAL_RANK + 1))
    torch.distributed.batch_isend_irecv(p2p_handle)
    if LOCAL_RANK == WORLD_SIZE - 1:
        reduce_handle.wait()
        print(reduce_tensor)
        # print(recv_tensor)

def init_processes(backend):
    dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
    run(backend)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--backend", type=str, default="nccl", choices=['nccl', 'gloo'])
    args = parser.parse_args()
    init_processes(backend=args.backend)
    

So my question is: why does the batched P2P path (batch_isend_irecv) not hang, while directly calling the asynchronous P2P interface (isend/irecv) does?

@sjeaugey Do you have any insights on this? Thanks.

@sjeaugey
Member

Not sure. Perhaps because now, the reduce is always done first, and the p2p operation second (after the if/elif/else)?
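
A minimal, untested sketch of that ordering with the direct isend/irecv calls (reusing the tensors and setup from the script above; the assumption is that every rank enqueues the reduce before any P2P op, so all ranks issue the NCCL operations in the same order, which is effectively what happens when batch_isend_irecv launches the P2P ops after the if/elif/else):

# Hypothetical reordering to test the idea above: the reduce is enqueued
# first on every rank, then the point-to-point ops.
handles = [torch.distributed.reduce(reduce_tensor, WORLD_SIZE - 1, async_op=True)]

if LOCAL_RANK == 0:
    handles.append(torch.distributed.isend(tensor=send_tensor, dst=1))
elif LOCAL_RANK == WORLD_SIZE - 1:
    handles.append(torch.distributed.irecv(tensor=send_tensor, src=WORLD_SIZE - 2))
else:
    handles.append(torch.distributed.irecv(tensor=send_tensor, src=LOCAL_RANK - 1))
    handles.append(torch.distributed.isend(tensor=send_tensor, dst=LOCAL_RANK + 1))

# Wait for all outstanding operations before touching the tensors.
for h in handles:
    h.wait()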
