Skip to content

Commit 9bcf922

Browse files
LuminolTrussellb
andauthored
[Core] Add xxHash as a high-performance hash option for accelerating prefix caching (#29163)
Signed-off-by: LuminolT <[email protected]> Signed-off-by: Lumis Chen <[email protected]> Co-authored-by: Russell Bryant <[email protected]>
1 parent 5aa9b09 commit 9bcf922

File tree

7 files changed

+332
-8
lines changed

7 files changed

+332
-8
lines changed

benchmarks/benchmark_hash.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
5+
6+
This focuses on a single test payload shaped like the prefix-cache hash input:
7+
(32-byte bytes object, 32-int tuple)
8+
9+
Usage:
10+
python benchmarks/hash_micro_benchmark.py --iterations 20000
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import argparse
16+
import random
17+
import statistics
18+
import time
19+
from collections.abc import Callable, Iterable
20+
21+
from vllm.utils.hashing import sha256, xxhash
22+
23+
24+
def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
25+
"""Generate a deterministic test payload."""
26+
random.seed(seed)
27+
bytes_data = bytes(random.getrandbits(8) for _ in range(32))
28+
int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
29+
return (bytes_data, int_tuple)
30+
31+
32+
def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int):
33+
"""Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
34+
times: list[float] = []
35+
36+
# Warm-up to avoid first-run noise.
37+
for _ in range(200):
38+
func(data)
39+
40+
for _ in range(iterations):
41+
start = time.perf_counter()
42+
func(data)
43+
end = time.perf_counter()
44+
times.append(end - start)
45+
46+
avg = statistics.mean(times)
47+
std = statistics.stdev(times) if len(times) > 1 else 0.0
48+
return avg, std
49+
50+
51+
def _run_benchmarks(
52+
benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
53+
data: tuple,
54+
iterations: int,
55+
):
56+
"""Yield (name, avg, std) for each benchmark, skipping unavailable ones."""
57+
for name, func in benchmarks:
58+
try:
59+
avg, std = _benchmark_func(func, data, iterations)
60+
except ModuleNotFoundError as exc:
61+
print(f"Skipping {name}: {exc}")
62+
continue
63+
yield name, avg, std
64+
65+
66+
def builtin_hash(data: tuple) -> int:
67+
"""Wrapper for Python's built-in hash()."""
68+
return hash(data)
69+
70+
71+
def main() -> None:
72+
parser = argparse.ArgumentParser(description=__doc__)
73+
parser.add_argument(
74+
"--iterations",
75+
type=int,
76+
default=10_000,
77+
help="Number of measured iterations per hash function.",
78+
)
79+
parser.add_argument(
80+
"--seed", type=int, default=42, help="Random seed for test payload."
81+
)
82+
args = parser.parse_args()
83+
84+
data = _generate_test_data(args.seed)
85+
benchmarks = (
86+
("SHA256 (pickle)", sha256),
87+
("xxHash (pickle)", xxhash),
88+
("built-in hash()", builtin_hash),
89+
)
90+
91+
print("=" * 60)
92+
print("HASH FUNCTION MICRO BENCHMARK")
93+
print("=" * 60)
94+
print("Test data: (32-byte bytes object, 32-int tuple)")
95+
print(f"Iterations: {args.iterations:,}")
96+
print("=" * 60)
97+
98+
results = list(_run_benchmarks(benchmarks, data, args.iterations))
99+
builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None)
100+
101+
print("\nResults:")
102+
for name, avg, std in results:
103+
print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs")
104+
105+
if builtin_entry:
106+
_, builtin_avg, _ = builtin_entry
107+
print("\n" + "=" * 60)
108+
print("SUMMARY (relative to built-in hash())")
109+
print("=" * 60)
110+
for name, avg, _ in results:
111+
if name == "built-in hash()":
112+
continue
113+
speed_ratio = avg / builtin_avg
114+
print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()")
115+
else:
116+
print("\nBuilt-in hash() result missing; cannot compute speed ratios.")
117+
118+
119+
if __name__ == "__main__":
120+
main()
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
"""
5+
Simple benchmark to compare prefix-cache block hashing algorithms.
6+
7+
Example:
8+
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import argparse
14+
import random
15+
import statistics
16+
import sys
17+
import time
18+
from collections.abc import Callable, Iterable, Sequence
19+
20+
from vllm.utils.hashing import get_hash_fn_by_name
21+
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
22+
23+
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
24+
25+
26+
def _generate_blocks(
27+
num_blocks: int, block_size: int, vocab_size: int, seed: int
28+
) -> list[list[int]]:
29+
rng = random.Random(seed)
30+
return [
31+
[rng.randrange(vocab_size) for _ in range(block_size)]
32+
for _ in range(num_blocks)
33+
]
34+
35+
36+
def _hash_all_blocks(
37+
hash_fn: Callable[[object], bytes],
38+
blocks: Iterable[Sequence[int]],
39+
) -> float:
40+
parent_hash: BlockHash | None = None
41+
start = time.perf_counter()
42+
for block in blocks:
43+
parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None)
44+
end = time.perf_counter()
45+
return end - start
46+
47+
48+
def _benchmark(
49+
hash_algo: str,
50+
blocks: list[list[int]],
51+
trials: int,
52+
) -> tuple[float, float, float] | None:
53+
try:
54+
hash_fn = get_hash_fn_by_name(hash_algo)
55+
init_none_hash(hash_fn)
56+
timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
57+
except ModuleNotFoundError as exc:
58+
print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
59+
return None
60+
61+
avg = statistics.mean(timings)
62+
best = min(timings)
63+
# throughput: tokens / second
64+
tokens_hashed = len(blocks) * len(blocks[0])
65+
throughput = tokens_hashed / best
66+
return avg, best, throughput
67+
68+
69+
def main() -> None:
70+
parser = argparse.ArgumentParser(description=__doc__)
71+
parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
72+
parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
73+
parser.add_argument(
74+
"--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
75+
)
76+
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
77+
parser.add_argument(
78+
"--trials", type=int, default=5, help="Number of timed trials per algorithm."
79+
)
80+
parser.add_argument(
81+
"--algorithms",
82+
nargs="+",
83+
default=SUPPORTED_ALGOS,
84+
choices=SUPPORTED_ALGOS,
85+
help="Hash algorithms to benchmark.",
86+
)
87+
args = parser.parse_args()
88+
89+
blocks = _generate_blocks(
90+
args.num_blocks, args.block_size, args.vocab_size, args.seed
91+
)
92+
print(
93+
f"Benchmarking {len(args.algorithms)} algorithms on "
94+
f"{args.num_blocks} blocks (block size={args.block_size})."
95+
)
96+
97+
for algo in args.algorithms:
98+
result = _benchmark(algo, blocks, args.trials)
99+
if result is None:
100+
continue
101+
102+
avg, best, throughput = result
103+
print(
104+
f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
105+
f"throughput: {throughput / 1e6:.2f}M tokens/s"
106+
)
107+
108+
109+
if __name__ == "__main__":
110+
main()

docs/benchmarking/cli.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,35 @@ vllm bench serve \
670670

671671
</details>
672672

673+
### 🧪 Hashing Benchmarks
674+
675+
<details class="admonition abstract" markdown="1">
676+
<summary>Show more</summary>
677+
678+
Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production.
679+
680+
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
681+
682+
```bash
683+
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
684+
```
685+
686+
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
687+
688+
```bash
689+
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
690+
```
691+
692+
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants:
693+
694+
```bash
695+
uv pip install xxhash cbor2
696+
```
697+
698+
If an algorithm’s dependency is missing, the script will skip it and continue.
699+
700+
</details>
701+
673702
### ⚡ Request Prioritization Benchmark
674703

675704
<details class="admonition abstract" markdown="1">

tests/v1/engine/test_engine_args.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vllm.engine.arg_utils import EngineArgs
1010
from vllm.usage.usage_lib import UsageContext
1111
from vllm.utils.argparse_utils import FlexibleArgumentParser
12+
from vllm.utils.hashing import _xxhash
1213

1314

1415
def test_prefix_caching_from_cli():
@@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
4849
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
4950

5051

52+
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
53+
def test_prefix_caching_xxhash_from_cli():
54+
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
55+
56+
# set hash algorithm to xxhash (pickle)
57+
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"])
58+
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
59+
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
60+
61+
# set hash algorithm to xxhash_cbor
62+
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"])
63+
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
64+
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor"
65+
66+
5167
def test_defaults_with_usage_context():
5268
engine_args = EngineArgs(model="facebook/opt-125m")
5369
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)

vllm/config/cache.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
"fp8_ds_mla",
3131
]
3232
MambaDType = Literal["auto", "float32"]
33-
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
33+
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
3434
KVOffloadingBackend = Literal["native", "lmcache"]
3535

3636

@@ -77,9 +77,21 @@ class CacheConfig:
7777
"""Whether to enable prefix caching."""
7878
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
7979
"""Set the hash algorithm for prefix caching:\n
80-
- "sha256" uses Pickle for object serialization before hashing.\n
80+
- "sha256" uses Pickle for object serialization before hashing. This is the
81+
current default, as SHA256 is the most secure choice to avoid potential
82+
hash collisions.\n
8183
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
82-
serializes objects using canonical CBOR and hashes them with SHA-256."""
84+
serializes objects using canonical CBOR and hashes them with SHA-256.\n
85+
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
86+
non-cryptographic hashing. Requires the optional ``xxhash`` package.
87+
IMPORTANT: Use of a hashing algorithm that is not considered
88+
cryptographically secure theoretically increases the risk of hash collisions,
89+
which can cause undefined behavior or even leak private information in
90+
multi-tenant environments. Even if collisions are still very unlikely, it is
91+
important to consider your security risk tolerance against the performance
92+
benefits before turning this on.\n
93+
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
94+
reproducible hashing. Requires the optional ``xxhash`` package."""
8395
cpu_offload_gb: float = Field(default=0, ge=0)
8496
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
8597
no offloading. Intuitively, this argument can be seen as a virtual way to

vllm/utils/hashing.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111

1212
import cbor2
1313

14+
try:
15+
# It is important that this remains an optional dependency.
16+
# It would not be allowed in environments with strict security controls,
17+
# so it's best not to have it installed when not in use.
18+
import xxhash as _xxhash
19+
20+
if not hasattr(_xxhash, "xxh3_128_digest"):
21+
_xxhash = None
22+
except ImportError: # pragma: no cover
23+
_xxhash = None
24+
1425

1526
def sha256(input: Any) -> bytes:
1627
"""Hash any picklable Python object using SHA-256.
@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
4758
return hashlib.sha256(input_bytes).digest()
4859

4960

61+
def _xxhash_digest(input_bytes: bytes) -> bytes:
62+
if _xxhash is None:
63+
raise ModuleNotFoundError(
64+
"xxhash is required for the 'xxhash' prefix caching hash algorithms. "
65+
"Install it via `pip install xxhash`."
66+
)
67+
return _xxhash.xxh3_128_digest(input_bytes)
68+
69+
70+
def xxhash(input: Any) -> bytes:
71+
"""Hash picklable objects using xxHash."""
72+
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
73+
return _xxhash_digest(input_bytes)
74+
75+
76+
def xxhash_cbor(input: Any) -> bytes:
77+
"""Hash objects serialized with CBOR using xxHash."""
78+
input_bytes = cbor2.dumps(input, canonical=True)
79+
return _xxhash_digest(input_bytes)
80+
81+
5082
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
5183
"""Get a hash function by name, or raise an error if the function is not found.
5284
@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
6092
return sha256
6193
if hash_fn_name == "sha256_cbor":
6294
return sha256_cbor
95+
if hash_fn_name == "xxhash":
96+
return xxhash
97+
if hash_fn_name == "xxhash_cbor":
98+
return xxhash_cbor
6399

64100
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
65101

0 commit comments

Comments
 (0)