CAR revert (opendatahub-io#140)
* Per @iotamudelta's suggestion, until the deadlock issue is better understood:
Revert "Make CAR ROCm 6.1 compatible. (opendatahub-io#137)"

This reverts commit 4d2dda6.

* Per @iotamudelta's suggestion, until the deadlock issue is better understood:
Revert "Optimize custom all reduce (opendatahub-io#130)"

This reverts commit 636ff01.
gshtras authored Aug 15, 2024
1 parent 4d2dda6 commit e7c3a5c
Showing 3 changed files with 77 additions and 116 deletions.
170 changes: 73 additions & 97 deletions csrc/custom_all_reduce.cuh
@@ -43,12 +43,7 @@ struct __align__(16) RankData { const void* ptrs[8]; };
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
#endif

struct __align__(16) RankSignals {
#ifndef USE_ROCM
volatile
#endif
Signal* signals[8];
};
struct __align__(16) RankSignals { volatile Signal* signals[8]; };

// like std::array, but aligned
template <typename T, int sz>
@@ -141,28 +136,25 @@ DINLINE O downcast(array_t<float, O::size> val) {
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes (CUDA-only).
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
#ifdef USE_ROCM
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED);
while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED) < flag);
}
__syncthreads();
}
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -173,38 +165,36 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
while (!self_sg->start[blockIdx.x][threadIdx.x]);
}
__syncthreads();
}
#endif
}

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
#ifdef USE_ROCM
DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// reset flag for next time
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED));
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
flag);
}
if constexpr (!final_sync) __syncthreads();
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
@@ -221,8 +211,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
while (!self_sg->end[blockIdx.x][threadIdx.x]);
}
if constexpr (!final_sync) __syncthreads();
}
#endif
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
@@ -237,11 +227,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result, int rank, int size) {
volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
@@ -257,22 +244,15 @@ __global__ void __launch_bounds__(512, 1)
}

template <typename P>
DINLINE P* get_tmp_buf(
#ifndef USE_ROCM
volatile
#endif
Signal* sg) {
DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result, int rank, int size) {
volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
@@ -475,41 +455,37 @@ class CustomAllreduce {
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
#ifdef USE_ROCM
int threads = 512, int block_limit = 18){
#else
int threads = 512, int block_limit = 36) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));

RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));

RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}

size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
@@ -528,27 +504,27 @@ class CustomAllreduce {
break; \
}

switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
}

~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
}
}; // namespace vllm
};
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
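
The csrc/custom_all_reduce.cuh hunks above restore the volatile-signal spin barrier in place of the reverted flag-counter synchronization: each rank resets its end slot, posts a 1 into its slot on every peer, and spins on its own start slots before a __syncthreads(). The following is a minimal, single-block sketch of that restored handshake, assuming hypothetical names (ToySignal, toy_start_sync, kMaxRanks) rather than the file's real Signal/RankSignals types:

#include <cuda_runtime.h>

constexpr int kMaxRanks = 8;

// Per-GPU mailbox: one start/end slot for each peer rank (single block assumed).
struct ToySignal {
  int start[kMaxRanks];
  int end[kMaxRanks];
};

// First synchronization of an all-reduce step: no visibility guarantees are
// needed for prior memory accesses, so plain volatile accesses suffice here.
template <int ngpus>
__device__ void toy_start_sync(volatile ToySignal* peers[kMaxRanks],
                               volatile ToySignal* self, int rank) {
  if (threadIdx.x < ngpus) {
    // reset our end slot so the next end-barrier starts from a clean state
    self->end[threadIdx.x] = 0;
    // post our arrival into our slot on every peer (one p2p write per rank)
    peers[threadIdx.x]->start[rank] = 1;
    // spin until the corresponding peer has posted into our mailbox
    while (!self->start[threadIdx.x]) {
    }
  }
  __syncthreads();
}

The matching end-of-step barrier in the restored code is symmetric: it runs __syncthreads() first, resets the start slots, posts to the peers' end slots, and spins on its own end slots.
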
7 changes: 0 additions & 7 deletions csrc/custom_all_reduce_test.cu
@@ -330,17 +330,10 @@ int main(int argc, char** argv) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
#ifdef USE_ROCM
for (int sz = 512; sz <= (8 << 22); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
}
#else
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}
#endif

cudaProfilerStop();
MPICHECK(MPI_Finalize());
return EXIT_SUCCESS;
}
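
The restored sweep above exercises CustomAllreduce::allreduce with the restored defaults of 512 threads and block_limit = 36 (the reverted ROCm path had used 18 and swept larger sizes). Below is a small host-side sketch of the launch-shape arithmetic shown in the header diff; the packed width d = 8 is an assumed example (16-byte packs of half), not taken from the file:

#include <algorithm>
#include <cstdio>

int main() {
  const int threads = 512;     // restored default thread count
  const int block_limit = 36;  // restored default (the reverted ROCm path used 18)
  const int d = 8;             // assumed packed width: 16 bytes / sizeof(half)
  int size = (1 << 20) + 8 * 47;  // one of the benchmark sweep sizes
  size /= d;                      // convert to packed elements (must divide evenly)
  int blocks = std::min(block_limit, (size + threads - 1) / threads);
  std::printf("grid: %d blocks x %d threads for %d packed elements\n", blocks,
              threads, size);
  return 0;
}
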
16 changes: 4 additions & 12 deletions vllm/distributed/parallel_state.py
@@ -199,18 +199,10 @@ def initialize_model_parallel(
if _ENABLE_CUSTOM_ALL_REDUCE:
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)

# max size defaults to 8 MiB, increase to 16 MiB on ROCm
# due to later crossover
if is_hip():
_TP_CA_COMMUNICATOR = CustomAllreduce(group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
max_size=2 * 8192 * 1024)
else:
_TP_CA_COMMUNICATOR = CustomAllreduce(
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)
_TP_CA_COMMUNICATOR = CustomAllreduce(
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)

# Build the pipeline model-parallel groups.
global _PP_DEVICE_GROUP, _PP_CPU_GROUP
