
[Common] Optimize KV cache related kernels #1914


Open · wants to merge 3 commits into main · changes from all commits

116 changes: 76 additions & 40 deletions transformer_engine/common/fused_attn/kv_cache.cu
@@ -10,6 +10,8 @@
namespace transformer_engine {
namespace kv_cache {

constexpr int block_size = 1024;

template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat
actual_b = i + 1;
}
}
bool flag = (batch_indices[0] != 0);
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
if (flag || ((batch_indices[batch_idx] - batch_indices[0]) != batch_idx)) {
int num_tokens = (cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]) -
(cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]);
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
}
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
int num_elts = max(num_elts_k, num_elts_v);
for (int token_idx = blockIdx.x; token_idx < num_tokens; token_idx += gridDim.x) {
int src_offset = batch_indices[batch_idx] * max_seq_len + token_idx;
int des_offset = batch_idx * max_seq_len + token_idx;
dtype *k_cache_src_offset = k_cache + src_offset * num_elts_k;
dtype *k_cache_des_offset = k_cache + des_offset * num_elts_k;
dtype *v_cache_src_offset = v_cache + src_offset * num_elts_v;
dtype *v_cache_des_offset = v_cache + des_offset * num_elts_v;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
if (i < num_elts_k) {
*(k_cache_des_offset + i) = *(k_cache_src_offset + i);
}
if (i < num_elts_v) {
*(v_cache_des_offset + i) = *(v_cache_src_offset + i);
}
}
}
}
}
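Note on the copy pattern above (a minimal sketch, hypothetical kernel name and signature, not part of this PR): the rewritten loop strides once over max(num_elts_k, num_elts_v) and guards each store, so one block moves a token's K row and V row in a single pass instead of two back-to-back loops.

template <typename dtype>
__global__ void copy_token_kv(dtype *k_dst, const dtype *k_src, dtype *v_dst, const dtype *v_src,
                              int num_elts_k, int num_elts_v) {
  // One pass over the larger of the two row sizes; each store is bounds-checked,
  // so K (h_kv * d_k elements) and V (h_kv * d_v elements) are copied together.
  int num_elts = max(num_elts_k, num_elts_v);
  for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
    if (i < num_elts_k) k_dst[i] = k_src[i];
    if (i < num_elts_v) v_dst[i] = v_src[i];
  }
}

When d_k and d_v match, the guards are always true and the loop degenerates to a plain interleaved copy; when they differ, the extra iterations only touch the larger of the two rows.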
@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (new_token_offset + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (new_token_offset + i) * h_kv * d_v + j);
dtype *new_token_id_k = new_k + (batch_idx * max_ctx_len + i) * num_elts_k;
dtype *new_token_id_v = new_v + (batch_idx * max_ctx_len + i) * num_elts_v;
dtype *token_id_k =
k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
}
}
}
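As a side note, a hypothetical helper (name and signature invented here, not in the PR) spelling out the paged addressing that the copy loops repeat inline: the new token's logical position in its sequence is cached_len - new_len + i, its page comes from the per-sequence page list, and the slot is the remainder within that page. The non-paged branch reuses the same formula with page_idx = batch_idx, which presumably relies on page_size covering a whole sequence in that case.

// Hypothetical helper, equivalent to the inline address math above.
__device__ __forceinline__ int kv_slot(const int *page_list, bool is_non_paged, int batch_idx,
                                       int cached_len, int new_len, int i, int page_size) {
  int pos = cached_len - new_len + i;                                    // logical position in the sequence
  int page_idx = is_non_paged ? batch_idx : page_list[pos / page_size];  // physical page
  return page_idx * page_size + pos % page_size;                         // token slot inside the cache
}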
@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
dtype *new_token_id_k = new_k + (i * b + batch_idx) * num_elts_k;
dtype *new_token_id_v = new_v + (i * b + batch_idx) * num_elts_v;
dtype *token_id_k =
k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
}
}
}
@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
dtype *new_token_id_k = new_k + (cu_new_lens[batch_idx] + i) * num_elts_k;
dtype *new_token_id_v = new_v + (cu_new_lens[batch_idx] + i) * num_elts_v;
dtype *token_id_k =
k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
}
}
}
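The three format branches above differ only in where token i of sequence batch_idx sits in new_k/new_v. Assuming the second and third branches are the NVTE_SBHD and NVTE_THD paths (which the source offsets suggest), a hypothetical summary of the source indexing, in tokens, before scaling by num_elts_k or num_elts_v:

// Hypothetical helper, not in the PR: source token index into new_k / new_v per layout.
__device__ __forceinline__ int new_token_index(NVTE_QKV_Format qkv_format, int batch_idx, int i,
                                               int b, int max_ctx_len, const int *cu_new_lens) {
  if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) return batch_idx * max_ctx_len + i;  // padded [b, s, h, d]
  if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) return i * b + batch_idx;            // padded [s, b, h, d]
  return cu_new_lens[batch_idx] + i;  // packed THD: offset by cumulative new lengths
}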
@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reindex_kv_cache_kernel<<<max_seq_len, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
dim3 grid_size(b, max_ctx_len);
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
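For context on the launch-side change (a sketch with argument lists elided, not the full launcher): the copy kernel now gets a two-dimensional grid, so blockIdx.x walks sequences, blockIdx.y walks new tokens, and the block's threads stride over the h_kv * max(d_k, d_v) elements of one token, while the reindex kernel goes from the fixed <<<16, 256>>> launch to max_seq_len blocks of block_size (1024) threads, one block per shifted token per grid step.

// Sketch of the new launch shapes only; b, max_ctx_len, max_seq_len, and stream as in the launcher above.
dim3 grid_size(b, max_ctx_len);  // x: sequence index, y: new token within the sequence
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(/* ... */);
reindex_kv_cache_kernel<<<max_seq_len, block_size, 0, stream>>>(/* ... */);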