
Commit 31a500c

[Core] [N-gram SD Optimization][1/n] Propose tokens with a single KMP (#22437)
Signed-off-by: Jialin Ouyang <[email protected]>
1 parent 4e8614e commit 31a500c

File tree: 6 files changed (+389 −207 lines)

benchmarks/benchmark_block_pool.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Run a GC pass up front to minimize interference between runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_times = TimeCollector(TimeCollector.US)
        free_blocks_times = TimeCollector(TimeCollector.US)
        for _ in range(args.num_iteration):
            with get_blocks_times:
                blocks = block_pool.get_new_blocks(allocate_block)
            with free_blocks_times:
                block_pool.free_blocks(blocks)

        rows.append(
            [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
            + get_blocks_times.dump_avg_max()
            + free_blocks_times.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (us)",
                "Get Blocks\nMax (us)",
                "Free Blocks\nAvg (us)",
                "Free Blocks\nMax (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
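
For reference, the allocate/free cycle this benchmark times can be exercised directly. The following is a minimal sketch using only the BlockPool calls that appear in the file above; the pool size and block count here are illustrative, not the benchmark defaults.

    from vllm.v1.core.block_pool import BlockPool

    # Small pool for illustration; the benchmark defaults to 100000 blocks.
    pool = BlockPool(num_gpu_blocks=1024, enable_caching=True)

    blocks = pool.get_new_blocks(16)  # allocate 16 KV-cache blocks
    pool.free_blocks(blocks)          # return them to the pool for reuse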
benchmarks/benchmark_ngram_proposer.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer


def main(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        gc.collect()
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
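
The commit title refers to proposing tokens with a single KMP pass: rather than re-scanning the token history once per candidate n-gram size between prompt_lookup_min and prompt_lookup_max, one pass over the reversed history finds the longest reoccurring suffix directly. The sketch below is an illustrative reconstruction of that idea, not vLLM's actual NgramProposer internals; propose_ngram and prefix_function are hypothetical names, and ties between equal-length matches are broken by whichever the scan sees first, which need not match the real proposer's rule.

    # Illustrative single-pass n-gram lookup via the KMP failure function.
    # A sketch of the idea named in the commit title, not vLLM's implementation.
    def prefix_function(s: list[int]) -> list[int]:
        # Standard KMP failure function: pi[i] is the length of the longest
        # proper prefix of s[:i+1] that is also a suffix of s[:i+1].
        pi = [0] * len(s)
        for i in range(1, len(s)):
            k = pi[i - 1]
            while k > 0 and s[i] != s[k]:
                k = pi[k - 1]
            if s[i] == s[k]:
                k += 1
            pi[i] = k
        return pi


    def propose_ngram(tokens: list[int], min_n: int, max_n: int, k: int):
        # Reversing the history turns "longest suffix of tokens that reoccurs
        # earlier" into "longest prefix of rev that reoccurs later", which a
        # single failure-function pass computes for every position at once.
        n = len(tokens)
        rev = tokens[::-1]
        pi = prefix_function(rev)
        best_len, best_start = 0, -1
        for i in range(1, n):
            m = min(pi[i], max_n)  # cap the match at the largest n-gram size
            if m >= min_n and m > best_len:
                best_len = m
                best_start = i - pi[i] + 1  # start of this occurrence in rev
        if best_start < 0:
            return None  # no n-gram of length >= min_n reoccurs
        cont = n - best_start  # index just after the earlier match in tokens
        return tokens[cont:cont + k]

For example, with tokens = [1, 2, 3, 4, 1, 2, 3], min_n=2, max_n=3, k=2, the suffix [1, 2, 3] also occurs at positions 0-2, so the sketch proposes [4, 1], the tokens that followed that earlier occurrence.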

benchmarks/benchmark_utils.py

Lines changed: 53 additions & 2 deletions
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import argparse
 import json
 import math
 import os
-from typing import Any
+import time
+from types import TracebackType
+from typing import Any, Optional, Union
 
 
 def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
         cls=InfEncoder,
         default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
     )
+
+
+# Collect time and generate time metrics
+#
+# Example usage:
+#   collector = TimeCollector(TimeCollector.US)
+#   for _ in range(total_iteration):
+#       with collector:
+#           ...
+#   collector.dump_avg_max()
+class TimeCollector:
+    NS: int = 1
+    US: int = NS * 1000
+    MS: int = US * 1000
+    S: int = MS * 1000
+
+    def __init__(self, scale: int) -> None:
+        self.cnt: int = 0
+        self._sum: int = 0
+        self._max: Optional[int] = None
+        self.scale = scale
+        self.start_time: int = time.monotonic_ns()
+
+    def collect(self, v: int) -> None:
+        self.cnt += 1
+        self._sum += v
+        if self._max is None:
+            self._max = v
+        else:
+            self._max = max(self._max, v)
+
+    def avg(self) -> Union[float, str]:
+        return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
+
+    def max(self) -> Union[float, str]:
+        return self._max / self.scale if self._max is not None else "N/A"
+
+    def dump_avg_max(self) -> list[Union[float, str]]:
+        return [self.avg(), self.max()]
+
+    def __enter__(self) -> None:
+        self.start_time = time.monotonic_ns()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
+    ) -> None:
+        self.collect(time.monotonic_ns() - self.start_time)
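
The context-manager pattern documented on the class can be exercised standalone. A small demo follows; the sleep is a stand-in for whatever code is being measured, and it assumes benchmark_utils is importable (e.g. when run from the benchmarks/ directory).

    import time

    from benchmark_utils import TimeCollector

    collector = TimeCollector(TimeCollector.US)
    for _ in range(5):
        with collector:
            time.sleep(0.001)  # stand-in for the code under measurement
    # Prints [avg_us, max_us]; each entry is roughly 1000 us here.
    print(collector.dump_avg_max())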

benchmarks/kv_cache/benchmark_block_pool.py

Lines changed: 0 additions & 108 deletions
This file was deleted.
