Commit

[Model] Support MAP-NEO model (#5081)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
xingweiqu and zhuohan123 committed May 31, 2024
1 parent 533c217 · commit a22dea5
Showing 8 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -170,7 +170,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rope.py
@@ -93,7 +93,7 @@ def benchmark_rope_kernels_multi_lora(
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
6 changes: 6 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V1(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V1(256);
       break;
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V2(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V2(256);
       break;
6 changes: 6 additions & 0 deletions csrc/cpu/attention.cpp
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
     case 128:
       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
2 changes: 1 addition & 1 deletion tests/kernels/test_attention.py
@@ -28,7 +28,7 @@
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 256
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]
 
 BLOCK_SIZES = [16, 32]
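The kernel tests are parametrized over HEAD_SIZES, so adding 192 here makes every attention test also run at the new head dimension on CUDA, while the ROCm FlashAttention path keeps its 128 cap noted in the comment above. The following is a rough sketch of how such a parametrization drives the tensor shapes in a test; it is illustrative only, not the body of test_attention.py, and the values below are arbitrary.

import pytest
import torch

# Illustrative copy of the updated list; the real tests also import is_hip()
# and trim this to head sizes <= 128 on ROCm.
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]


@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_query_key_shapes(head_size: int) -> None:
    # Every parametrized case allocates tensors whose last dimension is the
    # head size, so the new 192 entry now exercises that shape as well.
    num_seqs, num_heads = 3, 8
    query = torch.randn(num_seqs, num_heads, head_size)
    key = torch.randn(num_seqs, num_heads, head_size)
    scale = head_size ** -0.5
    attn_scores = scale * torch.einsum("qhd,khd->hqk", query, key)
    assert attn_scores.shape == (num_heads, num_seqs, num_seqs)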
2 changes: 1 addition & 1 deletion tests/kernels/test_cache.py
@@ -11,7 +11,7 @@
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]
 
 # Arbitrary values for testing
2 changes: 1 addition & 1 deletion tests/kernels/test_pos_encoding.py
@@ -10,7 +10,7 @@
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
@@ -31,7 +31,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [64, 80, 96, 112, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
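PagedAttention.get_supported_head_sizes() is the list that attention backends consult before routing a model through the paged-attention kernels, so extending it is what unlocks head dimension 192 end to end. Below is a minimal sketch of how a caller might use it for validation; the check_head_size helper is hypothetical (not part of this commit) and assumes vllm is installed.

from vllm.attention.ops.paged_attn import PagedAttention


def check_head_size(hidden_size: int, num_attention_heads: int) -> int:
    """Hypothetical helper: derive a model's head size and verify that the
    paged-attention kernels can handle it."""
    head_size = hidden_size // num_attention_heads
    supported = PagedAttention.get_supported_head_sizes()
    if head_size not in supported:
        raise ValueError(f"Head size {head_size} is not supported. "
                         f"Supported head sizes: {supported}")
    return head_size


# A configuration with hidden size 3072 and 16 attention heads yields a head
# size of 192, the value this commit adds to the supported list.
print(check_head_size(3072, 16))  # -> 192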
