Skip to content

Commit 37a4069

Browse files
committed
[FEAT] Add xxhash and xxhash_cbor algorithms for prefix caching and update tests
Signed-off-by: LuminolT <[email protected]>
1 parent 8ac3a41 commit 37a4069

File tree

5 files changed

+159
-7
lines changed

5 files changed

+159
-7
lines changed

benchmarks/hash_perf_demo.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
"""
5+
Simple benchmark to compare prefix-cache block hashing algorithms.
6+
7+
Example:
8+
python benchmarks/hash_perf_demo.py --num-blocks 20000 --block-size 32
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import argparse
14+
import random
15+
import statistics
16+
import sys
17+
import time
18+
from collections.abc import Callable, Iterable, Sequence
19+
20+
from vllm.utils.hashing import get_hash_fn_by_name
21+
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
22+
23+
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
24+
25+
26+
def _generate_blocks(
    num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
    """Build ``num_blocks`` pseudo-random token blocks of ``block_size`` ids.

    Token ids are drawn with ``randrange(vocab_size)`` from a private
    ``random.Random(seed)`` instance, so identical arguments always produce
    identical blocks and the module-level random state is not touched.
    """
    token_source = random.Random(seed)
    generated: list[list[int]] = []
    # Blocks are filled one after another so the draw order (and therefore
    # the output for a given seed) is fixed.
    for _ in range(num_blocks):
        generated.append(
            [token_source.randrange(vocab_size) for _ in range(block_size)]
        )
    return generated
34+
35+
36+
def _hash_all_blocks(
    hash_fn: Callable[[object], bytes],
    blocks: Iterable[Sequence[int]],
) -> float:
    """Time one pass of chained hashing over ``blocks``.

    Every block is passed to ``hash_block_tokens`` together with the value
    returned for the previous block (``None`` for the first one) and
    ``extra_keys=None``.

    Returns:
        Elapsed seconds for the whole pass, measured with
        ``time.perf_counter``.
    """
    previous: BlockHash | None = None
    started = time.perf_counter()
    for tokens in blocks:
        previous = hash_block_tokens(hash_fn, previous, tokens, extra_keys=None)
    return time.perf_counter() - started
46+
47+
48+
def _benchmark(
    hash_algo: str,
    blocks: list[list[int]],
    trials: int,
) -> tuple[float, float, float] | None:
    """Time ``trials`` hashing passes over ``blocks`` for one algorithm.

    Args:
        hash_algo: Name handed to ``get_hash_fn_by_name``.
        blocks: Token blocks to hash; may be empty or of uneven length.
        trials: Number of timed passes; must be at least 1.

    Returns:
        ``(avg_seconds, best_seconds, tokens_per_second)``, where the
        throughput is derived from the best trial, or ``None`` when the
        algorithm is skipped because a module it needs is not installed.

    Raises:
        ValueError: If ``trials`` is smaller than 1.
    """
    # Fail early with a clear message; previously this only surfaced after
    # the hash setup, as statistics.StatisticsError (a ValueError subclass).
    if trials < 1:
        raise ValueError(f"trials must be at least 1, got {trials}")

    try:
        # The timed passes stay inside the try: a missing optional module may
        # only be reported when the hash function is first called.
        hash_fn = get_hash_fn_by_name(hash_algo)
        init_none_hash(hash_fn)
        timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
    except ModuleNotFoundError as exc:
        print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
        return None

    avg = statistics.mean(timings)
    best = min(timings)
    # Count tokens block by block: unlike len(blocks) * len(blocks[0]) this
    # works for an empty block list and for blocks of unequal length.
    tokens_hashed = sum(len(tokens) for tokens in blocks)
    if best > 0:
        # throughput: tokens / second
        throughput = tokens_hashed / best
    else:
        # The clock could not resolve the run; avoid dividing by zero.
        throughput = float("inf") if tokens_hashed else 0.0
    return avg, best, throughput
67+
68+
69+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the hashing benchmark."""
    cli = argparse.ArgumentParser(description=__doc__)
    # (flag, default, help) for every integer option, in display order.
    int_options = (
        ("--num-blocks", 10000, "Block count."),
        ("--block-size", 32, "Tokens per block."),
        ("--vocab-size", 32000, "Token id range [0, vocab_size)."),
        ("--seed", 0, "Random seed."),
        ("--trials", 5, "Number of timed trials per algorithm."),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)
    cli.add_argument(
        "--algorithms",
        nargs="+",
        default=SUPPORTED_ALGOS,
        choices=SUPPORTED_ALGOS,
        help="Hash algorithms to benchmark.",
    )
    return cli


def main() -> None:
    """Parse the CLI options, generate blocks and report each algorithm.

    One summary line (average time, best time, throughput) is printed per
    algorithm; algorithms for which ``_benchmark`` returns ``None`` are
    skipped.
    """
    args = _build_arg_parser().parse_args()

    blocks = _generate_blocks(
        args.num_blocks, args.block_size, args.vocab_size, args.seed
    )
    print(
        f"Benchmarking {len(args.algorithms)} algorithms on "
        f"{args.num_blocks} blocks (block size={args.block_size})."
    )

    for algo in args.algorithms:
        outcome = _benchmark(algo, blocks, args.trials)
        if outcome is not None:
            avg, best, throughput = outcome
            print(
                f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
                f"throughput: {throughput/1e6:.2f}M tokens/s"
            )
105+
106+
107+
if __name__ == "__main__":
108+
main()

tests/v1/engine/test_engine_args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ def test_prefix_caching_from_cli():
4242
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
4343
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
4444

45+
# set hash algorithm to xxhash (pickle)
46+
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"])
47+
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
48+
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
49+
50+
# set hash algorithm to xxhash_cbor
51+
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"])
52+
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
53+
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor"
54+
4555
# an invalid hash algorithm raises an error
4656
parser.exit_on_error = False
4757
with pytest.raises(ArgumentError):

vllm/config/cache.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
"fp8_ds_mla",
3131
]
3232
MambaDType = Literal["auto", "float32"]
33-
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
33+
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
3434
KVOffloadingBackend = Literal["native", "lmcache"]
3535

3636

@@ -79,7 +79,11 @@ class CacheConfig:
7979
"""Set the hash algorithm for prefix caching:\n
8080
- "sha256" uses Pickle for object serialization before hashing.\n
8181
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
82-
serializes objects using canonical CBOR and hashes them with SHA-256."""
82+
serializes objects using canonical CBOR and hashes them with SHA-256.\n
83+
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
84+
non-cryptographic hashing. Requires the optional ``xxhash`` package.\n
85+
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
86+
reproducible hashing. Requires the optional ``xxhash`` package."""
8387
cpu_offload_gb: float = Field(default=0, ge=0)
8488
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
8589
no offloading. Intuitively, this argument can be seen as a virtual way to

vllm/utils/hashing.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from typing import Any
1010

1111
import cbor2
12+
try:
13+
import xxhash as _xxhash
14+
except ImportError: # pragma: no cover
15+
_xxhash = None
1216

1317

1418
def sha256(input: Any) -> bytes:
@@ -46,6 +50,27 @@ def sha256_cbor(input: Any) -> bytes:
4650
return hashlib.sha256(input_bytes).digest()
4751

4852

53+
def _xxhash_digest(input_bytes: bytes) -> bytes:
    """Return the ``xxh3_128_digest`` of ``input_bytes``.

    Raises:
        ModuleNotFoundError: If the optional ``xxhash`` import failed at
            module load, leaving the module-level ``_xxhash`` as ``None``.
    """
    if _xxhash is not None:
        return _xxhash.xxh3_128_digest(input_bytes)
    raise ModuleNotFoundError(
        "xxhash is required for the 'xxhash' prefix caching hash algorithms. "
        "Install it via `pip install xxhash`."
    )
60+
61+
62+
def xxhash(input: Any) -> bytes:
    """Pickle ``input`` with the highest protocol and return its xxHash digest."""
    return _xxhash_digest(pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL))
66+
67+
68+
def xxhash_cbor(input: Any) -> bytes:
    """Serialize ``input`` as canonical CBOR and return its xxHash digest."""
    return _xxhash_digest(cbor2.dumps(input, canonical=True))
72+
73+
4974
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
5075
"""Get a hash function by name, or raise an error if the function is not found.
5176
@@ -59,5 +84,9 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
5984
return sha256
6085
if hash_fn_name == "sha256_cbor":
6186
return sha256_cbor
87+
if hash_fn_name == "xxhash":
88+
return xxhash
89+
if hash_fn_name == "xxhash_cbor":
90+
return xxhash_cbor
6291

6392
raise ValueError(f"Unsupported hash function: {hash_fn_name}")

vllm/v1/core/kv_cache_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm import envs
1313
from vllm.config import VllmConfig
1414
from vllm.logger import init_logger
15-
from vllm.utils.hashing import sha256_cbor
15+
from vllm.utils.hashing import sha256_cbor, xxhash_cbor
1616
from vllm.utils.math_utils import cdiv
1717
from vllm.utils.mem_constants import GiB_bytes
1818
from vllm.v1.kv_cache_interface import (
@@ -83,18 +83,19 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
8383
#
8484
# The function `init_none_hash` initializes this variable globally.
8585
NONE_HASH: BlockHash
86+
_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor})
8687

8788

8889
def init_none_hash(hash_fn: Callable[[Any], bytes]):
8990
global NONE_HASH
9091

9192
hash_seed = os.getenv("PYTHONHASHSEED")
92-
if hash_seed is None and hash_fn is sha256_cbor:
93+
if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS:
9394
logger.warning(
9495
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
95-
"block-hashes when using sha256_cbor as the hash function."
96-
"Consider setting PYTHONHASHSEED to a fixed value for "
97-
"reproducibility."
96+
"block-hashes when using CBOR-based hash functions such as "
97+
"sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a "
98+
"fixed value for reproducibility."
9899
)
99100

100101
if hash_seed is None:

0 commit comments

Comments
 (0)