Commit a22dea5

[Model] Support MAP-NEO model (#5081)

MAP-NEO uses an attention head size of 192, which vLLM's paged-attention kernels did not previously support. This commit adds 192 to the supported head sizes in the CUDA and CPU attention launchers, the Python-side support list, and the corresponding benchmarks and tests.

Co-authored-by: Zhuohan Li <[email protected]>
1 parent 533c217 · commit a22dea5

File tree: 8 files changed (+18 −6 lines)


benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -170,7 +170,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
```

benchmarks/kernels/benchmark_rope.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -93,7 +93,7 @@ def benchmark_rope_kernels_multi_lora(
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
```

csrc/attention/attention_kernels.cu

Lines changed: 6 additions & 0 deletions

```diff
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V1(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V1(256);
       break;
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V2(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V2(256);
       break;
```
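
Each `case` in these launchers expands a `LAUNCH_*` macro that launches a kernel built for exactly that head size, so the supported sizes form a closed, compile-time set and a new size such as 192 has to be added to every launcher explicitly. A rough Python analogue of this dispatch pattern, purely illustrative (the names below are made up; the real dispatch is the C++ switch shown above):

```python
from typing import Callable, Dict

# Illustrative sketch only: mimics dispatching to per-head-size kernel
# specializations. In the real launcher, each switch case expands a macro
# that launches a kernel compiled for exactly that head size.
def _make_kernel(head_size: int) -> Callable[..., None]:
    def kernel(*args) -> None:
        print(f"paged attention kernel specialized for head_size={head_size}")
    return kernel

_KERNELS: Dict[int, Callable[..., None]] = {
    hs: _make_kernel(hs) for hs in (64, 80, 96, 112, 128, 192, 256)
}

def paged_attention_v1(head_size: int, *args) -> None:
    if head_size not in _KERNELS:
        raise ValueError(f"Unsupported head size: {head_size}")
    _KERNELS[head_size](*args)

paged_attention_v1(192)  # dispatches to the newly added 192 specialization
```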

csrc/cpu/attention.cpp

Lines changed: 6 additions & 0 deletions

```diff
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
     case 128:
       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
```

tests/kernels/test_attention.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 256
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
               ] if not is_hip() else [64, 80, 96, 112, 128]
 
 BLOCK_SIZES = [16, 32]
```
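
The test module parametrizes its kernel tests over `HEAD_SIZES`, so adding 192 to this list automatically extends attention-kernel test coverage to the new head size on non-ROCm builds. A hypothetical, self-contained illustration of that parametrization pattern (not the actual vLLM test):

```python
import pytest
import torch

HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]

# Hypothetical test: every entry in HEAD_SIZES, including the new 192,
# becomes its own pytest case.
@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_query_shape(head_size: int) -> None:
    query = torch.empty(7, 8, head_size)
    assert query.shape[-1] == head_size
```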

tests/kernels/test_cache.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -11,7 +11,7 @@
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]
 
 # Arbitrary values for testing
```

tests/kernels/test_pos_encoding.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -10,7 +10,7 @@
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
```

vllm/attention/ops/paged_attn.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -31,7 +31,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [64, 80, 96, 112, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
```
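
`PagedAttention.get_supported_head_sizes()` is the Python-side list that callers can consult before dispatching to the kernels above; with this change it advertises 192 as well. A hedged sketch of how a caller might use it (the `validate_head_size` helper is hypothetical, not vLLM's actual validation code):

```python
from typing import List


class PagedAttention:
    """Illustrative stand-in for vllm.attention.ops.paged_attn.PagedAttention."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 192, 256]


def validate_head_size(head_size: int) -> None:
    # Hypothetical helper: reject head sizes the paged-attention kernels
    # do not have a compiled specialization for.
    supported = PagedAttention.get_supported_head_sizes()
    if head_size not in supported:
        raise ValueError(f"Head size {head_size} is not supported. "
                         f"Supported head sizes: {supported}")


validate_head_size(192)  # passes after this commit; raised ValueError before
```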
