FA3 kvcache + split kv + gqa parallelization #1236

Draft · wants to merge 41 commits into base: main

Changes from all commits (41 commits)
e6841cb
Adding the flash3 kv cache API. Just compiling for now.
ganeshcolfax Aug 15, 2024
98610d6
start extending seqlen traits for kv cache
jayhshah Aug 15, 2024
671af33
added cache_batch_idx.
ganeshcolfax Aug 15, 2024
c671c3f
adding python interface.
ganeshcolfax Aug 15, 2024
faafc59
add test_kvcache.py.
ganeshcolfax Aug 15, 2024
e4fe20d
enable use of actual seqlen for kv cache
jayhshah Aug 16, 2024
bf50391
add new param to handle cache_batch_size
jayhshah Aug 16, 2024
a08ad4c
add semaphore for kv cache causal
jayhshah Aug 16, 2024
e33f107
add comparison with fa2.
ganeshcolfax Aug 16, 2024
bf64e86
change template parameter for SeqLenTraits for ease of further extension
jayhshah Aug 19, 2024
7c363c1
modify seqlentraits for gqa parallelism
jayhshah Aug 19, 2024
5576742
modify Ktraits for decoding QO layouts
jayhshah Aug 20, 2024
7e6cf1e
decouple types of seqlen traits q and k
jayhshah Aug 20, 2024
2433b2f
change logic of Q loads for gqa parallelization
jayhshah Aug 20, 2024
ee8b320
fix o strides
jayhshah Aug 20, 2024
6618ab5
complete gqa parallel changes for non-causal
jayhshah Aug 22, 2024
59cdccb
fix some errors
jayhshah Aug 22, 2024
b6e8f10
add causal logic
jayhshah Aug 26, 2024
7996455
add to kv cache api
jayhshah Aug 26, 2024
b2a09fd
add in lse writeout and store zero
jayhshah Sep 4, 2024
22bbff0
refactor for split kv
jayhshah Sep 5, 2024
fb84142
re-enable fp16/bf16 fwd
jayhshah Sep 5, 2024
12558b3
add 1 mma warpgroup option, enable splitkv for hdim 256
jayhshah Sep 6, 2024
42427a8
fix bug with finalize for split kv
jayhshah Sep 12, 2024
81d7bdb
delete unused files
jayhshah Sep 17, 2024
1c38e5b
add hid=64.
ganeshcolfax Sep 13, 2024
18cbd9c
change flash api for rebase
jayhshah Sep 18, 2024
c6b1c1f
avoid redundant compilation with combine kernel by only including nee…
jayhshah Sep 19, 2024
986247a
change Element to OutputType for template param in combine kernel. On…
jayhshah Sep 19, 2024
020ecf8
fix wrong tile size for hdim 64
jayhshah Sep 19, 2024
f07dcdd
revert OutputType change
jayhshah Sep 19, 2024
0375bad
changes for correct lse write out for splits=1 and splits > 1 case.
ganeshcolfax Sep 19, 2024
a52d64c
update parameters
jayhshah Sep 20, 2024
9dd6742
Merge branch 'fa3-kvcache-gqa' of github.com:Dao-AILab/flash-attentio…
jayhshah Sep 20, 2024
267628f
remove unused code
jayhshah Sep 20, 2024
8cb226b
added num_split_heuristics.
ganeshcolfax Sep 20, 2024
feacec5
Merge branch 'fa3-kvcache-gqa' of github.com:Dao-AILab/flash-attentio…
ganeshcolfax Sep 20, 2024
a5db3c1
add num_split_heuristics.
ganeshcolfax Sep 20, 2024
d9bd088
adding block_n and block_m for different headdim.
ganeshcolfax Sep 21, 2024
d7ca643
initialize semaphore when num splits != 1
jayhshah Sep 21, 2024
7876a02
add gqa decoding logic.
ganeshcolfax Sep 21, 2024
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 386 files
224 changes: 224 additions & 0 deletions hopper/combine.h
@@ -0,0 +1,224 @@

#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <class Element, class SmemLayout>
struct SharedStorageLSE {
cute::array_aligned<Element, cute::size_v<SmemLayout>> smem_lse;
};

// DON'T use Kernel_traits here, to avoid redundant compilation.
// template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
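// Combine kernel for split-KV attention: each split writes a partial output
// (Oaccum) and a partial log-sum-exp (LSEaccum); this kernel reduces them into
// the final O and LSE by rescaling each partial output with exp(lse_split - lse_final).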
template<typename Element, typename ElementAccum, int kHeadDim, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
__global__ void combine_attn_seqk_parallel(Params const params) {
// using Element = typename Kernel_traits::OutputType;
// using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = int64_t; // Kernel_traits::index_t
constexpr int kMaxSplits = 1 << Log_max_splits;
// constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNThreads = 128; //Kernel_traits::kNThreads;

static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
static_assert(kNThreads == 128, "We assume that each block has 128 threads");

// Shared memory.
// kBlockM + 1 instead of kBlockM to reduce bank conflicts.
//__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1];
extern __shared__ char smem_[];
using SharedStorage = SharedStorageLSE<ElementAccum, Shape<Int<kMaxSplits>, Int<kBlockM+1>>>;
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(smem_);
Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kMaxSplits>, Int<kBlockM+1>>{});

// The thread and block index.
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;

const index_t lse_size = params.b * params.h * params.seqlen_q;
//if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);

const index_t row_offset_lse = bidx * kBlockM;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
Shape<Int<kMaxSplits>, Int<kBlockM>>{},
make_stride(lse_size, _1{}));

// LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
// This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});

// This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
Layout flat_layout = make_layout(lse_size);
Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
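// Reading the composition inside-out: composition(orig_layout, flat_layout) turns a
// flat LSE row index into (q, h, b) coordinates, and remapped_layout then applies the
// strides of the actual (possibly transposed) unpadded LSE storage to get the offset.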

Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);

constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

// Read the LSE values from gmem and store them in shared memory, then transpose them.
constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
const int col = tidx % kBlockM;
ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
if (row < kMaxSplits) { sLSE(row,col) = lse; }
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
}
// if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
__syncthreads();
Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
// To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
// each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
// kBlockM rows, so each time we load we can load 128 / kBlockM rows).
// constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
// static_assert(kThreadsPerSplit <= 32);
static_assert(kRowsPerLoadTranspose <= 32);
static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
const int col = tidx / kRowsPerLoadTranspose;
//if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY;
}

// Compute the logsumexp of the LSE along the split dimension.
ElementAccum lse_max = lse_accum(0);
#pragma unroll
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
MaxOp<float> max_op;
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
float lse_sum = expf(lse_accum(0) - lse_max);
#pragma unroll
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
SumOp<float> sum_op;
lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
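// Stable log-sum-exp across splits: lse_logsum = lse_max + log(sum_s exp(lse_s - lse_max))
// = log(sum_s exp(lse_s)).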
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
if (params.unpadded_lse) {
const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
if (lse_offset < lse_size) {
gLSE_unpadded(lse_offset) = lse_logsum;
}
} else {
gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
}
}
//if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum);

// Store the scales exp(lse - lse_logsum) in shared memory.
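// These scales are the softmax weights across splits; for a given row they sum to 1.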
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
const int col = tidx / kRowsPerLoadTranspose;
if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); }
}
__syncthreads();

const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
clear(tOrO);

// Predicates
Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
//if (cute::thread0()) print_tensor (cOaccum);
// Repeat the partitioning with identity layouts
Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
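// tOpOaccum(k) marks head-dim columns that fall within params.d, for the case where
// the actual head dim is smaller than the padded tile width (d < d_rounded).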
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
}
// Load Oaccum in then scale and accumulate to O
for (int split = 0; split < params.num_splits; ++split) {
flash::copy</*Is_even_MN=*/false, Is_even_K>(
gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOrOaccum); ++m) {
int row = get<0>(tOcOaccum(0, m, 0));
ElementAccum lse_scale = sLSE(split,row);
#pragma unroll
for (int k = 0; k < size<2>(tOrOaccum); ++k) {
#pragma unroll
for (int i = 0; i < size<0>(tOrOaccum); ++i) {
tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
//tOrO(i, m, k) += tOrOaccum(i, m, k);
}
}
//if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); }
}
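// Advance to the next split's buffer: Oaccum is laid out in gmem as
// [num_splits, b * h * seqlen_q, d_rounded].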
tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
}
//if (cute::thread0()) { print_tensor(tOrO); }

Tensor rO = flash::convert_type<Element>(tOrO);
// Write to gO
#pragma unroll
for (int m = 0; m < size<1>(rO); ++m) {
const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
//if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
if (idx < params.b * params.h * params.seqlen_q) {
//print ("final2\n");
const int batch_idx = idx / (params.h * params.seqlen_q);
const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
// The index to the rows of Q
const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
+ head_idx * params.o_head_stride + row * params.o_row_stride;
#pragma unroll
for (int k = 0; k < size<2>(rO); ++k) {
if (Is_even_K || tOpOaccum(k)) {
const int col = get<1>(tOcOaccum(0, m, k));
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
// TODO: Should check if this is using vectorized store, but it seems pretty fast
copy(rO(_, m, k), gO);
//if (cute::thread0()) { print ("final\n"); print_tensor(gO); }
// if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
// reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
}
}
}
}
}

}
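The file above implements this reduction on-device. For reference, a minimal host-side sketch of the same per-row combine math (illustrative only; the names combine_splits, lse, o_accum, and lse_out are hypothetical, not part of this PR):

// Illustrative reference (not the kernel's API): combine per-split partial
// attention outputs for a single output row.
#include <cmath>
#include <limits>
#include <vector>

// lse:     per-split log-sum-exp values, one per split.
// o_accum: per-split partial outputs, each of length head_dim.
// Returns the combined output row and writes the global LSE to lse_out.
std::vector<float> combine_splits(const std::vector<float>& lse,
                                  const std::vector<std::vector<float>>& o_accum,
                                  float& lse_out) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    // Stable log-sum-exp across splits (mirrors lse_max / lse_sum above).
    float lse_max = neg_inf;
    for (float v : lse) lse_max = std::max(lse_max, v);
    if (lse_max == neg_inf) lse_max = 0.f;  // all splits saw no keys
    float lse_sum = 0.f;
    for (float v : lse) lse_sum += std::exp(v - lse_max);
    lse_out = (lse_sum == 0.f) ? std::numeric_limits<float>::infinity()
                               : std::log(lse_sum) + lse_max;
    // Weight each split's partial output by exp(lse_s - lse_out); the weights
    // sum to 1, so this is an exact softmax renormalization across splits.
    std::vector<float> o(o_accum[0].size(), 0.f);
    for (size_t s = 0; s < o_accum.size(); ++s) {
        const float scale = std::exp(lse[s] - lse_out);
        for (size_t d = 0; d < o.size(); ++d) o[d] += scale * o_accum[s][d];
    }
    return o;
}

Note that when every split's LSE is -inf, lse_out becomes +inf and every scale becomes exp(-inf) = 0, so the combined output is zero, matching the kernel's store-zero behavior.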