[BugFix] Some fixes for custom allreduce kernels (#2760)
hanzhi713 authored and joerunde committed Mar 22, 2024
1 parent 1c0a1bc commit aceb4a7
Showing 6 changed files with 232 additions and 250 deletions.
10 changes: 5 additions & 5 deletions csrc/custom_all_reduce.cu
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

@@ -62,9 +62,9 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
if (inp_size % 16 != 0) return false;
if (!_is_weak_contiguous(inp)) return false;
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
// 4 PCIe GPUs use the 2-stage allreduce, which is only faster than NCCL when
// size <= 512k
return world_size <= 4 && inp_size <= 512 * 1024;
// for 4 or more non-NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return false;
}
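The revised gate above is the behavioral fix in this file: with more than 2 ranks and no full NVLink mesh, the custom kernel now always defers to NCCL. A minimal host-side restatement of the rule (my sketch; the `gate` name and the 8 MiB cap are illustrative, not from this commit):

```cpp
#include <cassert>

// Size/topology gate only; the tensor alignment and contiguity checks
// from should_custom_ar are stripped out for clarity.
bool gate(long inp_size, long max_size, int world_size, bool full_nvlink) {
  if (inp_size % 16 != 0) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  return false;  // >2 ranks without full NVLink: NCCL wins, per this commit
}

int main() {
  const long max_size = 8 * 1024 * 1024;          // assumed cap
  assert(gate(1 << 20, max_size, 2, false));      // 2 GPUs over PCIe: ok
  assert(gate(1 << 20, max_size, 4, true));       // 4 GPUs, full NVLink: ok
  assert(!gate(256 * 1024, max_size, 4, false));  // 4 PCIe GPUs: now rejected
  return 0;
}
```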

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
@@ -126,7 +126,7 @@ void dispose(fptr_t _fa) {
delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }
int meta_size() { return sizeof(vllm::Signal); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
227 changes: 75 additions & 152 deletions csrc/custom_all_reduce.cuh
@@ -23,29 +23,17 @@

namespace vllm {

constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal {
alignas(64) union {
uint64_t flag;
unsigned char data[8];
} start;
alignas(64) union {
uint64_t flag;
unsigned char data[8];
} end;
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
};

struct Metadata {
alignas(128) Signal sg;
alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);
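For orientation (my annotation, not part of the diff): the new Signal replaces the old Metadata's single pair of 64-bit byte-flag words with one flag matrix per synchronization phase, giving each of up to kMaxBlocks thread blocks its own cache-line-aligned row of per-rank flags. If pasted below the struct inside namespace vllm, these checks should hold:

```cpp
// Sanity checks for the new layout, assuming kMaxBlocks == 64 and <= 8 ranks.
static_assert(sizeof(Signal) == 2 * kMaxBlocks * 8 * sizeof(uint32_t),
              "two kMaxBlocks x 8 flag matrices, 4096 bytes total");
static_assert(alignof(Signal) == 128, "both flag arrays are alignas(128)");
```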

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
volatile Signal *signals[8];
};
struct __align__(16) RankSignals { volatile Signal *signals[8]; };

// like std::array, but aligned
template <typename T, int sz>
@@ -135,70 +123,49 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
auto m = std::numeric_limits<uint64_t>::max();
return m >> ((8 - ngpus) * 8);
}
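To make the removed byte-flag scheme concrete (my worked example): each rank wrote 0xFF into its own byte of a shared 64-bit word, and the waiter spun until the word equaled compute_flag(ngpus), i.e. until the low ngpus bytes were all set:

```cpp
// compute_flag masks off the high (8 - ngpus) bytes of an all-ones word.
static_assert(compute_flag(2) == 0xFFFFull);              // 2 ranks arrived
static_assert(compute_flag(4) == 0xFFFFFFFFull);          // 4 ranks arrived
static_assert(compute_flag(8) == 0xFFFFFFFFFFFFFFFFull);  // 8 ranks arrived
```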

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
constexpr auto FLAG = compute_flag(ngpus);
if (blockIdx.x == 0) {
if (threadIdx.x < ngpus)
// simultaneously write to the corresponding byte of all other ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start.data[rank] = 255;
else if (threadIdx.x == 32)
// reset
meta->sg.end.flag = 0;
}
if (threadIdx.x == 0) {
while (meta->sg.start.flag != FLAG)
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
;
}
__syncthreads();
}
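A host-side analog of this flag protocol (my sketch; it uses std::atomic, which the kernel deliberately avoids since peer atomics are not supported over PCIe). Each participant publishes its arrival to every peer, spins on the flags peers wrote into its own row, and resets the opposite phase's flags at a point where nobody can still be waiting on them:

```cpp
#include <atomic>

constexpr int kRanks = 4;
// flags[owner][writer]: a writer sets its column in every owner's row;
// each owner spins on its own row. One matrix per phase, as in Signal.
// Static storage zero-initializes all flags.
std::atomic<int> start_flags[kRanks][kRanks];
std::atomic<int> end_flags[kRanks][kRanks];

void start_barrier(int rank) {
  for (int peer = 0; peer < kRanks; peer++) {
    // Reset end flags for the next end_barrier; no one can be spinning on
    // them here, since everyone already left the previous end_barrier.
    end_flags[rank][peer].store(0, std::memory_order_relaxed);
    // Publish arrival to every peer: one store per peer, like the p2p write.
    start_flags[peer][rank].store(1, std::memory_order_release);
  }
  for (int peer = 0; peer < kRanks; peer++)
    while (!start_flags[rank][peer].load(std::memory_order_acquire)) {
    }
}
// end_barrier mirrors this with the roles of the two matrices swapped.
```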

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
constexpr auto FLAG = compute_flag(ngpus);
__syncthreads();
__shared__ int num;
if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
__syncthreads();

// Only the last completing block can perform the end synchronization
// This ensures that when the final busy wait ends, all ranks must have
// finished reading each other's buffer.
if (num == gridDim.x - 1) {
if (threadIdx.x == 32) {
// reset in a different warp
meta->counter = 0;
meta->sg.start.flag = 0;
} else if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding byte of all other ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end.data[rank] = 255;
}
// if this is the final sync, only one block needs it
// because kernel exit can serve as sync
if constexpr (final_sync) {
if (threadIdx.x == 0) {
while (meta->sg.end.flag != FLAG)
;
}
}
}
if constexpr (!final_sync) {
if (threadIdx.x == 0) {
while (meta->sg.end.flag != FLAG)
;
}
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. It might be that the hardware provides a stronger guarantee than
// the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
}
if constexpr (!final_sync) __syncthreads();
}

template <typename P, int ngpus, typename A>
@@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
volatile Metadata *meta, T *__restrict__ result,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
start_sync<ngpus>(sg, meta, rank);
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P *)result)[idx] =
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, meta, rank);
end_sync<ngpus, true>(sg, self_sg, rank);
}
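The comment about accumulation order is load-bearing: floating-point addition is not associative, so if each rank summed the peer buffers starting from its own pointer, different ranks could produce different bit patterns. A minimal host-side illustration (mine, not from the diff):

```cpp
#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  printf("%g\n", (a + b) + c);  // prints 1
  printf("%g\n", a + (b + c));  // prints 0: c is absorbed by b's magnitude
  return 0;
}
```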

template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
return (P *)(((Metadata *)sg) + 1);
return (P *)(((Signal *)sg) + 1);
}
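get_tmp_buf encodes the layout convention described later in the class comment: each rank's registered region is a Signal header immediately followed by scratch space. A hypothetical allocation of one rank's region (names and error handling are mine, not from this commit):

```cpp
#include <cuda_runtime.h>

void *alloc_signal_region(size_t tmp_bytes) {
  void *ptr = nullptr;
  // Signal header first, scratch buffer right behind it, matching
  // get_tmp_buf's "(Signal *)sg + 1" pointer arithmetic.
  cudaMalloc(&ptr, sizeof(vllm::Signal) + tmp_bytes);
  cudaMemset(ptr, 0, sizeof(vllm::Signal));  // flags must start cleared
  return ptr;
}
```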

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
volatile Metadata *meta, T *__restrict__ result,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
Expand All @@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1)
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P *ptrs[ngpus];
P *tmps[ngpus];
#pragma unroll
@@ -257,75 +225,28 @@
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, meta, rank);
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
// Maybe TODO: replace this with per-block release-acquire
// can save about 1-2us (not a lot though)
end_sync<ngpus>(sg, meta, rank);
end_sync<ngpus>(sg, self_sg, rank);

// stage 2: allgather
for (int idx = tid; idx < part; idx += stride) {
// stage 2: allgather. Note: it's important to match the tid between the two
// stages, because visibility across devices is only guaranteed between
// threads that have the same tid. If thread i computes the sum of start + i
// in the first stage, then thread i also gathers start + i from all ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int dst_idx = ((rank + i) % ngpus) * part + idx;
((P *)result)[dst_idx] = tmps[i][idx];
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx];
}
}
}
// process the last larger partition
int remaining = size - part * ngpus;
if (tid < remaining) {
int dst_idx = tid + part * ngpus;
((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
}

// faster than this
// for (int idx = tid; idx < size; idx += stride) {
//   int target_rank = idx / part;
//   if (target_rank == ngpus) target_rank -= 1;
//   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
// }
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
volatile Metadata *meta,
T *__restrict__ result, int rank,
int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
constexpr int hg = ngpus / 2;
// Actually not quite half butterfly.
// This is an all-to-all within each group containing half of the ranks
// followed by cross-group add. Equivalent to half butterfly when there
// are 4 GPUs, a common case for PCIe cards like T4 and A10.
const P *ptrs[hg];
{
int start = rank - rank % hg;
#pragma unroll
for (int i = 0; i < hg; i++) {
ptrs[i] = (const P *)_dp->ptrs[i + start];
}
}
start_sync<ngpus>(sg, meta, rank);
for (int idx = tid; idx < size; idx += stride) {
tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
}
end_sync<ngpus>(sg, meta, rank);

auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
// do the cross group reduction
for (int idx = tid; idx < size; idx += stride) {
auto tmp = tmp_out[idx];
packed_assign_add(tmp, src[idx]);
((P *)result)[idx] = tmp;
}
}
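To sanity-check the new partitioning (my model, not part of the commit): with size = 10 and ngpus = 4, part = 2 and largest_part = 4, so rank 3 owns the 4-element tail [6, 10). In the allgather, the guard `gather_from_rank == ngpus - 1 || idx < part` admits indices >= part only for the last rank's partition, which gives every output index exactly one writer:

```cpp
#include <cassert>
#include <vector>

int main() {
  const int size = 10, ngpus = 4;
  const int part = size / ngpus;                 // 2
  const int largest_part = part + size % ngpus;  // 4, owned by the last rank
  std::vector<int> writes(size, 0);
  const int rank = 0;  // each rank writes the full result; model one of them
  for (int idx = 0; idx < largest_part; idx++)
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = (rank + i) % ngpus;
      if (gather_from_rank == ngpus - 1 || idx < part)
        writes[gather_from_rank * part + idx]++;
    }
  for (int k = 0; k < size; k++) assert(writes[k] == 1);  // disjoint cover
  return 0;
}
```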

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
@@ -341,7 +262,7 @@ class CustomAllreduce {
// below are device pointers
RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_;
Metadata *meta_;
Signal *self_sg_;

// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
@@ -352,32 +273,32 @@ class CustomAllreduce {
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
*
* There's a total of sizeof(Metadata) bytes of prefix before the actual data,
* There's a total of sizeof(Signal) bytes of prefix before the actual data,
* so meta + 1 points to the actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
meta_(meta),
self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Metadata *rank_meta;
Signal *rank_sg;
if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_meta = (Metadata *)handle;
rank_sg = (Signal *)handle;
} else {
rank_meta = meta_;
rank_sg = self_sg_;
}
sg_.signals[i] = &rank_meta->sg;
sg_.signals[i] = rank_sg;
}
}

@@ -492,6 +413,10 @@ class CustomAllreduce {
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));

RankData *ptrs;
cudaStreamCaptureStatus status;
@@ -512,9 +437,9 @@
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus> \
<<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
@@ -526,8 +451,6 @@
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} else { \
KL(ngpus, cross_device_reduce_half_butterfly); \
} \
break; \
}
@@ -556,7 +479,7 @@ class CustomAllreduce {
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
int, int, int);
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
} // namespace vllm