diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index fc9621e885dc..e6f4e9e6b971 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -170,7 +170,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py
index 9188e811e298..00e55f6060b5 100644
--- a/benchmarks/kernels/benchmark_rope.py
+++ b/benchmarks/kernels/benchmark_rope.py
@@ -93,7 +93,7 @@ def benchmark_rope_kernels_multi_lora(
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 45edc3252380..8f89f89786c3 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V1(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V1(256);
       break;
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V2(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V2(256);
       break;
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index 438e9bdb19f5..ed8cfbd421f0 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
     case 128:
       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index fdf313262ca9..8bc4766fc93c 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -28,7 +28,7 @@
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 256
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
               ] if not is_hip() else [64, 80, 96, 112, 128]
 
 BLOCK_SIZES = [16, 32]
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 9f0cb60dc16e..29572cfa5749 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -11,7 +11,7 @@
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]  # Arbitrary values for testing
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 076730cdbae0..fbabc02bf9a9 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -10,7 +10,7 @@
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index e119fdcf1111..a214f40d1651 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -31,7 +31,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [64, 80, 96, 112, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
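A quick way to sanity-check the change after rebuilding the kernels is to query the Python-side support list touched above. The snippet below is a minimal sketch, not part of this diff: it assumes the diff is applied to a vLLM checkout with the compiled extensions available, and it only uses the import path and method name shown in vllm/attention/ops/paged_attn.py.

# sanity_check_head_size_192.py -- hypothetical helper, not included in this diff
from vllm.attention.ops.paged_attn import PagedAttention

# After this change, 192 should be listed between 128 and 256.
supported = PagedAttention.get_supported_head_sizes()
print(supported)
assert 192 in supported

The new head size can also be exercised end to end through the tooling updated above, e.g. python benchmarks/kernels/benchmark_paged_attention.py --head-size 192 (the flag and the 192 choice both appear in this diff), and the updated HEAD_SIZES parametrizations in tests/kernels/ pick it up automatically.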