Skip to content

Commit ccc900b

Browse files
authored
Add cache ops (#3)
* add reshape_and_cache_flash_kernel Signed-off-by: Zhu, Zufang <[email protected]> Signed-off-by: Zhu <[email protected]> * add reshape_and_cache Signed-off-by: Zhu, Zufang <[email protected]> Signed-off-by: Zhu <[email protected]> * add benchmark and ut Signed-off-by: Zhu <[email protected]> * add vevtorization Signed-off-by: Zhu <[email protected]> * add ut and benchmark for reshape_and_cache Signed-off-by: Zhu <[email protected]> * refine test Signed-off-by: Zhu <[email protected]> * format by pre-commit Signed-off-by: Zhu <[email protected]> * update as review Signed-off-by: Zhu <[email protected]> --------- Signed-off-by: Zhu, Zufang <[email protected]> Signed-off-by: Zhu <[email protected]>
1 parent f83755f commit ccc900b

File tree

11 files changed

+1253
-19
lines changed

11 files changed

+1253
-19
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ endif()
146146

147147
if(VLLM_GPU_LANG STREQUAL "SYCL")
148148
set(VLLM_EXT_SRC
149+
"csrc/xpu/cache.cpp"
149150
"csrc/xpu/layernorm.cpp"
150151
"csrc/xpu/torch_bindings.cpp"
151152
)
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from __future__ import annotations
4+
5+
import random
6+
import time
7+
8+
import torch
9+
from tabulate import tabulate
10+
11+
from tests import register_ops as ops
12+
from tests.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
13+
14+
15+
@torch.inference_mode()
16+
def run_benchmark(
17+
num_tokens: int,
18+
num_heads: int,
19+
head_size: int,
20+
block_size: int,
21+
num_blocks: int,
22+
dtype: torch.dtype,
23+
kv_cache_dtype: str,
24+
num_iters: int,
25+
device: str = "xpu",
26+
) -> float:
27+
"""Return latency (seconds) for given num_tokens."""
28+
29+
if kv_cache_dtype == "fp8" and head_size % 16:
30+
raise ValueError(
31+
"fp8 kv-cache requires head_size to be a multiple of 16.")
32+
33+
seed = 42
34+
random.seed(seed)
35+
torch.manual_seed(seed)
36+
torch.set_default_device(device)
37+
38+
# create random key / value tensors [T, H, D].
39+
key = torch.randn(num_tokens,
40+
num_heads,
41+
head_size,
42+
dtype=dtype,
43+
device=device)
44+
value = torch.randn_like(key)
45+
46+
# prepare the slot mapping.
47+
# each token is assigned a unique slot in the KV-cache.
48+
num_slots = block_size * num_blocks
49+
if num_tokens > num_slots:
50+
raise ValueError(
51+
"num_tokens cannot exceed the total number of cache slots")
52+
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
53+
slot_mapping = torch.tensor(slot_mapping_lst,
54+
dtype=torch.long,
55+
device=device)
56+
57+
num_layers = 1 # for simplicity, we use a single layer
58+
key_caches, value_caches = create_kv_caches_with_random(
59+
num_blocks,
60+
block_size,
61+
num_layers,
62+
num_heads,
63+
head_size,
64+
kv_cache_dtype,
65+
dtype,
66+
device=device,
67+
)
68+
key_cache, value_cache = key_caches[0], value_caches[0]
69+
70+
# compute per-kernel scaling factors for fp8 conversion (if used).
71+
k_scale = (key.amax() / 64.0).to(torch.float32)
72+
v_scale = (value.amax() / 64.0).to(torch.float32)
73+
74+
def run_xpu_benchmark(n_iters: int) -> float:
75+
nonlocal key, value, key_cache, value_cache, slot_mapping
76+
torch.xpu.synchronize()
77+
start = time.perf_counter()
78+
for _ in range(n_iters):
79+
ops.reshape_and_cache(
80+
key,
81+
value,
82+
key_cache,
83+
value_cache,
84+
slot_mapping,
85+
kv_cache_dtype,
86+
k_scale,
87+
v_scale,
88+
)
89+
torch.xpu.synchronize()
90+
end = time.perf_counter()
91+
return (end - start) / n_iters
92+
93+
# warm-up
94+
run_xpu_benchmark(3)
95+
96+
lat = run_xpu_benchmark(num_iters)
97+
98+
# free tensors to mitigate OOM when sweeping
99+
del key, value, key_cache, value_cache, slot_mapping
100+
torch.xpu.empty_cache()
101+
102+
return lat
103+
104+
105+
def main(args):
106+
rows = []
107+
for exp in range(1, 12):
108+
n_tok = 2**exp
109+
lat = run_benchmark(
110+
num_tokens=n_tok,
111+
num_heads=args.num_heads,
112+
head_size=args.head_size,
113+
block_size=args.block_size,
114+
num_blocks=args.num_blocks,
115+
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
116+
kv_cache_dtype=args.kv_cache_dtype,
117+
num_iters=args.iters,
118+
device="xpu",
119+
)
120+
rows.append([
121+
n_tok,
122+
args.num_heads,
123+
args.head_size,
124+
args.block_size,
125+
args.num_blocks,
126+
args.dtype,
127+
args.kv_cache_dtype,
128+
f"{lat * 1e6:.3f}",
129+
])
130+
print(
131+
tabulate(
132+
rows,
133+
headers=[
134+
"num_tokens",
135+
"num_heads",
136+
"head_size",
137+
"block_size",
138+
"num_blocks",
139+
"dtype",
140+
"kv_cache_dtype",
141+
"latency (us)",
142+
],
143+
))
144+
145+
146+
if __name__ == "__main__":
147+
import argparse
148+
149+
parser = argparse.ArgumentParser()
150+
parser.add_argument("--num-heads", type=int, default=8)
151+
parser.add_argument(
152+
"--head-size",
153+
type=int,
154+
choices=[64, 80, 96, 112, 120, 128, 192, 256],
155+
default=128,
156+
)
157+
parser.add_argument("--block-size",
158+
type=int,
159+
choices=[16, 32, 64],
160+
default=64)
161+
parser.add_argument("--num-blocks", type=int, default=1024)
162+
163+
parser.add_argument(
164+
"--dtype",
165+
type=str,
166+
choices=["half", "bfloat16"],
167+
default="half",
168+
)
169+
170+
parser.add_argument(
171+
"--kv-cache-dtype",
172+
type=str,
173+
choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"],
174+
default="auto",
175+
)
176+
177+
parser.add_argument("--iters", type=int, default=100)
178+
args = parser.parse_args()
179+
180+
main(args)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from __future__ import annotations
4+
5+
import random
6+
import time
7+
8+
import torch
9+
from tabulate import tabulate
10+
11+
from tests import register_ops as ops
12+
from tests.utils import (STR_DTYPE_TO_TORCH_DTYPE,
13+
create_kv_caches_with_random_flash)
14+
15+
16+
@torch.inference_mode()
17+
def run_benchmark(
18+
num_tokens: int,
19+
num_heads: int,
20+
head_size: int,
21+
block_size: int,
22+
num_blocks: int,
23+
dtype: torch.dtype,
24+
kv_cache_dtype: str,
25+
num_iters: int,
26+
device: str = "xpu",
27+
) -> float:
28+
"""Return latency (seconds) for given num_tokens."""
29+
30+
if kv_cache_dtype == "fp8" and head_size % 16:
31+
raise ValueError(
32+
"fp8 kv-cache requires head_size to be a multiple of 16.")
33+
34+
seed = 42
35+
random.seed(seed)
36+
torch.manual_seed(seed)
37+
torch.set_default_device(device)
38+
39+
# create random key / value tensors [T, H, D].
40+
key = torch.randn(num_tokens,
41+
num_heads,
42+
head_size,
43+
dtype=dtype,
44+
device=device)
45+
value = torch.randn_like(key)
46+
47+
# prepare the slot mapping.
48+
# each token is assigned a unique slot in the KV-cache.
49+
num_slots = block_size * num_blocks
50+
if num_tokens > num_slots:
51+
raise ValueError(
52+
"num_tokens cannot exceed the total number of cache slots")
53+
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
54+
slot_mapping = torch.tensor(slot_mapping_lst,
55+
dtype=torch.long,
56+
device=device)
57+
58+
num_layers = 1 # for simplicity, we use a single layer
59+
key_caches, value_caches = create_kv_caches_with_random_flash(
60+
num_blocks,
61+
block_size,
62+
num_layers,
63+
num_heads,
64+
head_size,
65+
kv_cache_dtype,
66+
dtype,
67+
device=device,
68+
)
69+
key_cache, value_cache = key_caches[0], value_caches[0]
70+
71+
# compute per-kernel scaling factors for fp8 conversion (if used).
72+
k_scale = (key.amax() / 64.0).to(torch.float32)
73+
v_scale = (value.amax() / 64.0).to(torch.float32)
74+
75+
def run_xpu_benchmark(n_iters: int) -> float:
76+
nonlocal key, value, key_cache, value_cache, slot_mapping
77+
torch.xpu.synchronize()
78+
start = time.perf_counter()
79+
for _ in range(n_iters):
80+
ops.reshape_and_cache_flash(
81+
key,
82+
value,
83+
key_cache,
84+
value_cache,
85+
slot_mapping,
86+
kv_cache_dtype,
87+
k_scale,
88+
v_scale,
89+
)
90+
torch.xpu.synchronize()
91+
end = time.perf_counter()
92+
return (end - start) / n_iters
93+
94+
# warm-up
95+
run_xpu_benchmark(3)
96+
97+
lat = run_xpu_benchmark(num_iters)
98+
99+
# free tensors to mitigate OOM when sweeping
100+
del key, value, key_cache, value_cache, slot_mapping
101+
torch.xpu.empty_cache()
102+
103+
return lat
104+
105+
106+
def main(args):
107+
rows = []
108+
for exp in range(1, 12):
109+
n_tok = 2**exp
110+
lat = run_benchmark(
111+
num_tokens=n_tok,
112+
num_heads=args.num_heads,
113+
head_size=args.head_size,
114+
block_size=args.block_size,
115+
num_blocks=args.num_blocks,
116+
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
117+
kv_cache_dtype=args.kv_cache_dtype,
118+
num_iters=args.iters,
119+
device="xpu",
120+
)
121+
rows.append([
122+
n_tok,
123+
args.num_heads,
124+
args.head_size,
125+
args.block_size,
126+
args.num_blocks,
127+
args.dtype,
128+
args.kv_cache_dtype,
129+
f"{lat * 1e6:.3f}",
130+
])
131+
print(
132+
tabulate(
133+
rows,
134+
headers=[
135+
"num_tokens",
136+
"num_heads",
137+
"head_size",
138+
"block_size",
139+
"num_blocks",
140+
"dtype",
141+
"kv_cache_dtype",
142+
"latency (us)",
143+
],
144+
))
145+
146+
147+
if __name__ == "__main__":
148+
import argparse
149+
150+
parser = argparse.ArgumentParser()
151+
parser.add_argument("--num-heads", type=int, default=8)
152+
parser.add_argument(
153+
"--head-size",
154+
type=int,
155+
choices=[64, 80, 96, 112, 120, 128, 192, 256],
156+
default=128,
157+
)
158+
parser.add_argument("--block-size",
159+
type=int,
160+
choices=[16, 32, 64],
161+
default=64)
162+
parser.add_argument("--num-blocks", type=int, default=512)
163+
164+
parser.add_argument(
165+
"--dtype",
166+
type=str,
167+
choices=["half", "bfloat16"],
168+
default="half",
169+
)
170+
171+
parser.add_argument(
172+
"--kv-cache-dtype",
173+
type=str,
174+
choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"],
175+
default="auto",
176+
)
177+
178+
parser.add_argument("--iters", type=int, default=100)
179+
args = parser.parse_args()
180+
181+
main(args)

0 commit comments

Comments
 (0)