feat: kv fp8 quant calibration for fa3 and flashinfer #935

Open · wants to merge 5 commits into base: main
98 changes: 98 additions & 0 deletions lightllm/common/basemodel/triton_kernel/destindex_copy_kv_fp8.py
@@ -0,0 +1,98 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_kv_per_head_fp8(
    K,
    Dest_loc,
    Out,
    scale,
    stride_k_bs,
    stride_k_h,
    stride_k_d,
    stride_o_bs,
    stride_o_h,
    stride_o_d,
    head_num,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
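    # One program per source token: load the per-head scales, divide K by them,
    # clamp to the fp8 e4m3 range, and scatter-store the row at its Dest_loc slot.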
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

    k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]

    # to fp8
    scale_ptrs = scale + offs_h
    scales = tl.load(scale_ptrs, mask=offs_h < head_num, other=1.0)
    k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
    k_scale = k / scales[:, None]
    k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)

    tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num)
    return


@torch.no_grad()
def destindex_copy_kv_fp8(K, DestLoc, scales, Out):
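    # Quantize K per head (K / scales) into Out at the DestLoc token slots as
    # float8_e4m3fn; when scales is None, fall back to a plain dtype cast.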
    if scales is None:
        Out[DestLoc] = K.to(torch.float8_e4m3fn)
        return

    seq_len = DestLoc.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    grid = (seq_len,)
    num_warps = 1

    _fwd_kernel_destindex_copy_kv_per_head_fp8[grid](
        K,
        DestLoc,
        Out,
        scales,
        K.stride(0),
        K.stride(1),
        K.stride(2),
        Out.stride(0),
        Out.stride(1),
        Out.stride(2),
        head_num,
        BLOCK_DMODEL=head_dim,
        BLOCK_HEAD=BLOCK_HEAD,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        num_warps=num_warps,
        num_stages=1,
    )


if __name__ == "__main__":
    import torch.nn.functional as F
    from lightllm.utils.vllm_utils import vllm_ops

    B, N_CTX, H, HEAD_DIM = 32, 1024, 16, 128
    dtype = torch.bfloat16
    NUM = B
    dest_loc = torch.arange(NUM).cuda() * 2
    kv = torch.randn((len(dest_loc), H, HEAD_DIM), dtype=dtype).cuda()
    out = torch.zeros((B * N_CTX, H, HEAD_DIM), dtype=torch.uint8).cuda()
    scale = kv.abs().amax(dim=(0, 2)).to(torch.float32) / 448
    destindex_copy_kv_fp8(kv, dest_loc, scale, out.view(torch.float8_e4m3fn))

    assert torch.allclose(
        out[:, :, :HEAD_DIM][dest_loc].view(torch.float8_e4m3fn).float() * scale.view(H, 1).expand(NUM, H, 1),
        kv.float(),
        atol=1e-5,
        rtol=1e-1,
    )
151 changes: 151 additions & 0 deletions lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py
@@ -0,0 +1,151 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _per_head_max_reduce_kernel(
    Q,
    Scales,
    StartLoc,
    stride_q_t,
    stride_q_h,
    stride_scales_b,
    FP8_MAX: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
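    # One program per (batch, head): scan the token range [start_loc, end_loc)
    # in BLOCK_T chunks, track the absolute max, and store max / FP8_MAX as the
    # per-(batch, head) scale (1.0 when the slice is all zeros).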
    b_id = tl.program_id(0)
    h_id = tl.program_id(1)

    max_val = 0.0

    start_loc = tl.load(StartLoc + b_id)
    end_loc = tl.load(StartLoc + b_id + 1)
    for t_offset in range(start_loc, end_loc, BLOCK_T):
        t_idx = t_offset + tl.arange(0, BLOCK_T)
        q_range = tl.arange(0, BLOCK_D)
        q_ptrs = Q + t_idx[:, None] * stride_q_t + h_id * stride_q_h + q_range[None, :]
        mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h)
        q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
        max_val = tl.maximum(tl.max(q_vals.abs()), max_val)

    scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0)
    scale_ptr = Scales + b_id * stride_scales_b + h_id
    tl.store(scale_ptr, scale)


@triton.jit
def _apply_quantization_kernel(
    Q,
    Q_out,
    BatchIds,
    Scales,
    stride_q_t,
    stride_q_h,
    stride_qout_t,
    stride_qout_h,
    stride_scales_b,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
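    # One program per (token, head): look up that token's per-(batch, head)
    # scale, divide, clamp to the fp8 e4m3 range, and store the quantized row.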
    t_id = tl.program_id(0)
    h_id = tl.program_id(1)

    batch_id = tl.load(BatchIds + t_id)
    scale_ptr = Scales + batch_id * stride_scales_b + h_id
    scale = tl.load(scale_ptr)

    q_range = tl.arange(0, BLOCK_D)
    q_ptrs = Q + t_id * stride_q_t + h_id * stride_q_h + q_range
    qout_ptrs = Q_out + t_id * stride_qout_t + h_id * stride_qout_h + q_range
    mask = q_range < stride_q_h
    q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
    q_scaled = q_vals / scale
    q_clamped = tl.clamp(q_scaled, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
    tl.store(qout_ptrs, q_clamped, mask=q_range < stride_qout_h)


@torch.no_grad()
def q_per_head_fp8_quant(q, seq_lens, b1_start_loc, scale_out=None, batch_ids=None):
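    # Per-(batch, head) fp8 quantization of q in two passes: a max-reduce over
    # each request's tokens, then an element-wise scale-and-cast. Note the
    # kernels use stride_q_h as the head-dim bound, so q is expected to be
    # contiguous (stride(1) == D).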
    T, H, D = q.shape
    B = seq_lens.shape[0]

    BLOCK_D = triton.next_power_of_2(D)
    BLOCK_T = 256
    num_warps = 4
    num_stages = 2

    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    if scale_out is None:
        scale_out = torch.empty((B, H), dtype=torch.float32, device=q.device)
    if batch_ids is None:
        batch_ids = torch.repeat_interleave(torch.arange(B, device=q.device), seq_lens)

    _per_head_max_reduce_kernel[(B, H)](
        q,
        scale_out,
        b1_start_loc,
        q.stride(0),
        q.stride(1),
        scale_out.stride(0),
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_T=BLOCK_T,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    _apply_quantization_kernel[(T, H)](
        q,
        q_out,
        batch_ids,
        scale_out,
        q.stride(0),
        q.stride(1),
        q_out.stride(0),
        q_out.stride(1),
        scale_out.stride(0),
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return q_out, scale_out


def ref_q_per_head_fp8_quant(q, seq_lens):
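    # Plain-PyTorch reference: per-(batch, head) abs-max via scatter_reduce_,
    # then scale, clamp, and cast. Used by the quick check below.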
    min_fp8 = torch.finfo(torch.float8_e4m3fn).min
    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
    B = seq_lens.size(0)
    device = q.device
    batch_ids = torch.repeat_interleave(torch.arange(B, device=device), seq_lens)
    max_per_time_head = q.abs().amax(dim=2)
    max_per_bh = torch.zeros((B, max_per_time_head.size(1)), device=device, dtype=max_per_time_head.dtype)
    max_per_bh.scatter_reduce_(
        0,
        batch_ids.unsqueeze(-1).expand(-1, max_per_time_head.size(1)),
        max_per_time_head,
        reduce="amax",
        include_self=False,
    )
    scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32)
    scale_expanded = scales[batch_ids].view(-1, scales.size(1), 1)
    q_q = (q / scale_expanded).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn)
    return q_q, scales


if __name__ == "__main__":
    B, T, H, D = 200, 1000, 4, 7 * 128
    seq_lens = torch.ones((B,), dtype=torch.int32).cuda() * T // B
    start_locs = torch.zeros(B + 1, dtype=torch.int32).cuda()
    start_locs[1:] = seq_lens.cumsum(dim=0)
    q = torch.randn((T, H, D), dtype=torch.float32).cuda()

    q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs)
    q_out1, scales1 = ref_q_per_head_fp8_quant(q, seq_lens)
    assert torch.allclose(scales, scales1, atol=1e-10, rtol=0)
    assert torch.allclose(q_out.int(), q_out1.int(), atol=1e-10, rtol=0)
6 changes: 6 additions & 0 deletions lightllm/common/calibration_fp8kv_mem_manager.py
@@ -0,0 +1,6 @@
from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class CalibrationFP8KVMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=False)
6 changes: 6 additions & 0 deletions lightllm/common/export_calibration_mem_manager.py
@@ -0,0 +1,6 @@
from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True)
8 changes: 8 additions & 0 deletions lightllm/common/mem_utils.py
@@ -1,5 +1,7 @@
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager
from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from lightllm.utils.log_utils import init_logger
@@ -20,6 +22,12 @@ def select_mem_manager_class(mode):
        logger.info("Model kv cache using mode triton int8kv")
    elif "triton_fp8kv" in mode:
        raise Exception("currently only for deepseek")
    elif "calibration_fp8kv" in mode:
        memory_manager_class = CalibrationFP8KVMemoryManager
        logger.info("Model kv cache using mode calibration fp8kv")
    elif "export_fp8kv_calibration" in mode:
        memory_manager_class = ExportCalibrationMemoryManager
        logger.info("Model kv cache using mode export fp8kv calibration")
    else:
        memory_manager_class = MemoryManager
        logger.info("Model kv cache using mode normal")
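For reference, a minimal sketch (not part of the diff) of how the two new mode strings map to memory-manager classes through select_mem_manager_class; how the mode list is populated at launch (for example via a --mode flag) is an assumption here, not shown in this PR.

from lightllm.common.mem_utils import select_mem_manager_class

# Mode strings added in this PR; each selects its memory-manager class.
mem_manager_cls = select_mem_manager_class(["calibration_fp8kv"])        # CalibrationFP8KVMemoryManager
export_cls = select_mem_manager_class(["export_fp8kv_calibration"])      # ExportCalibrationMemoryManager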