-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[Core] Add xxHash as a high-performance hash option for accelerating prefix caching #29163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
37a4069
c6ae405
4e231a7
cefd2f6
4af648c
8b0f2f5
19d6402
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """ | ||
| Micro benchmark comparing built-in hash(), SHA-256, and xxHash. | ||
|
|
||
| This focuses on a single test payload shaped like the prefix-cache hash input: | ||
| (32-byte bytes object, 32-int tuple) | ||
|
|
||
| Usage: | ||
| python benchmarks/hash_micro_benchmark.py --iterations 20000 | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import random | ||
| import statistics | ||
| import time | ||
| from collections.abc import Callable, Iterable | ||
|
|
||
| from vllm.utils.hashing import sha256, xxhash | ||
|
|
||
|
|
||
| def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]: | ||
| """Generate a deterministic test payload.""" | ||
| random.seed(seed) | ||
| bytes_data = bytes(random.getrandbits(8) for _ in range(32)) | ||
| int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32)) | ||
| return (bytes_data, int_tuple) | ||
|
|
||
|
|
||
| def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int): | ||
| """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.""" | ||
| times: list[float] = [] | ||
|
|
||
| # Warm-up to avoid first-run noise. | ||
| for _ in range(200): | ||
| func(data) | ||
|
|
||
| for _ in range(iterations): | ||
| start = time.perf_counter() | ||
| func(data) | ||
| end = time.perf_counter() | ||
| times.append(end - start) | ||
|
|
||
| avg = statistics.mean(times) | ||
| std = statistics.stdev(times) if len(times) > 1 else 0.0 | ||
| return avg, std | ||
|
|
||
|
|
||
| def _run_benchmarks( | ||
| benchmarks: Iterable[tuple[str, Callable[[tuple], object]]], | ||
| data: tuple, | ||
| iterations: int, | ||
| ): | ||
| """Yield (name, avg, std) for each benchmark, skipping unavailable ones.""" | ||
| for name, func in benchmarks: | ||
| try: | ||
| avg, std = _benchmark_func(func, data, iterations) | ||
| except ModuleNotFoundError as exc: | ||
| print(f"Skipping {name}: {exc}") | ||
| continue | ||
| yield name, avg, std | ||
|
|
||
|
|
||
| def builtin_hash(data: tuple) -> int: | ||
| """Wrapper for Python's built-in hash().""" | ||
| return hash(data) | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser(description=__doc__) | ||
| parser.add_argument( | ||
| "--iterations", | ||
| type=int, | ||
| default=10_000, | ||
| help="Number of measured iterations per hash function.", | ||
| ) | ||
| parser.add_argument( | ||
| "--seed", type=int, default=42, help="Random seed for test payload." | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| data = _generate_test_data(args.seed) | ||
| benchmarks = ( | ||
| ("SHA256 (pickle)", sha256), | ||
| ("xxHash (pickle)", xxhash), | ||
| ("built-in hash()", builtin_hash), | ||
| ) | ||
|
|
||
| print("=" * 60) | ||
| print("HASH FUNCTION MICRO BENCHMARK") | ||
| print("=" * 60) | ||
| print("Test data: (32-byte bytes object, 32-int tuple)") | ||
| print(f"Iterations: {args.iterations:,}") | ||
| print("=" * 60) | ||
|
|
||
| results = list(_run_benchmarks(benchmarks, data, args.iterations)) | ||
| builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None) | ||
|
|
||
| print("\nResults:") | ||
| for name, avg, std in results: | ||
| print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs") | ||
|
|
||
| if builtin_entry: | ||
| _, builtin_avg, _ = builtin_entry | ||
| print("\n" + "=" * 60) | ||
| print("SUMMARY (relative to built-in hash())") | ||
| print("=" * 60) | ||
| for name, avg, _ in results: | ||
| if name == "built-in hash()": | ||
| continue | ||
| speed_ratio = avg / builtin_avg | ||
| print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()") | ||
| else: | ||
| print("\nBuilt-in hash() result missing; cannot compute speed ratios.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| """ | ||
| Simple benchmark to compare prefix-cache block hashing algorithms. | ||
|
|
||
| Example: | ||
| python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import random | ||
| import statistics | ||
| import sys | ||
| import time | ||
| from collections.abc import Callable, Iterable, Sequence | ||
|
|
||
| from vllm.utils.hashing import get_hash_fn_by_name | ||
| from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash | ||
|
|
||
| SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor") | ||
|
|
||
|
|
||
| def _generate_blocks( | ||
| num_blocks: int, block_size: int, vocab_size: int, seed: int | ||
| ) -> list[list[int]]: | ||
| rng = random.Random(seed) | ||
| return [ | ||
| [rng.randrange(vocab_size) for _ in range(block_size)] | ||
| for _ in range(num_blocks) | ||
| ] | ||
|
|
||
|
|
||
| def _hash_all_blocks( | ||
| hash_fn: Callable[[object], bytes], | ||
| blocks: Iterable[Sequence[int]], | ||
| ) -> float: | ||
| parent_hash: BlockHash | None = None | ||
| start = time.perf_counter() | ||
| for block in blocks: | ||
| parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None) | ||
| end = time.perf_counter() | ||
| return end - start | ||
|
|
||
|
|
||
| def _benchmark( | ||
| hash_algo: str, | ||
| blocks: list[list[int]], | ||
| trials: int, | ||
| ) -> tuple[float, float, float] | None: | ||
| try: | ||
| hash_fn = get_hash_fn_by_name(hash_algo) | ||
| init_none_hash(hash_fn) | ||
| timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)] | ||
| except ModuleNotFoundError as exc: | ||
| print(f"Skipping {hash_algo}: {exc}", file=sys.stderr) | ||
| return None | ||
|
|
||
| avg = statistics.mean(timings) | ||
| best = min(timings) | ||
| # throughput: tokens / second | ||
| tokens_hashed = len(blocks) * len(blocks[0]) | ||
| throughput = tokens_hashed / best | ||
| return avg, best, throughput | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser(description=__doc__) | ||
| parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.") | ||
| parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.") | ||
| parser.add_argument( | ||
| "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)." | ||
| ) | ||
| parser.add_argument("--seed", type=int, default=0, help="Random seed.") | ||
| parser.add_argument( | ||
| "--trials", type=int, default=5, help="Number of timed trials per algorithm." | ||
| ) | ||
| parser.add_argument( | ||
| "--algorithms", | ||
| nargs="+", | ||
| default=SUPPORTED_ALGOS, | ||
| choices=SUPPORTED_ALGOS, | ||
| help="Hash algorithms to benchmark.", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| blocks = _generate_blocks( | ||
| args.num_blocks, args.block_size, args.vocab_size, args.seed | ||
| ) | ||
| print( | ||
| f"Benchmarking {len(args.algorithms)} algorithms on " | ||
| f"{args.num_blocks} blocks (block size={args.block_size})." | ||
| ) | ||
|
|
||
| for algo in args.algorithms: | ||
| result = _benchmark(algo, blocks, args.trials) | ||
| if result is None: | ||
| continue | ||
|
|
||
| avg, best, throughput = result | ||
| print( | ||
| f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s " | ||
| f"throughput: {throughput / 1e6:.2f}M tokens/s" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct | |||
| ninja # Required for xgrammar, rocm, tpu, xpu | ||||
| pybase64 # fast base64 implementation | ||||
| cbor2 # Required for cross-language serialization of hashable objects | ||||
| xxhash # Required for fast hashing for prefix caching | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer this not be included in our main requirements file, given that it's optional. In fact, it MUST remain optional to avoid getting flagged in environments with strict security controls.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| setproctitle # Used to set process names for better debugging and monitoring | ||||
| openai-harmony >= 0.0.3 # Required for gpt-oss | ||||
| anthropic == 0.71.0 | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,7 +30,7 @@ | |||||||||||||||||||||||||||||||||||
| "fp8_ds_mla", | ||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||
| MambaDType = Literal["auto", "float32"] | ||||||||||||||||||||||||||||||||||||
| PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] | ||||||||||||||||||||||||||||||||||||
| PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] | ||||||||||||||||||||||||||||||||||||
| KVOffloadingBackend = Literal["native", "lmcache"] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -79,7 +79,11 @@ class CacheConfig: | |||||||||||||||||||||||||||||||||||
| """Set the hash algorithm for prefix caching:\n | ||||||||||||||||||||||||||||||||||||
| - "sha256" uses Pickle for object serialization before hashing.\n | ||||||||||||||||||||||||||||||||||||
| - "sha256_cbor" provides a reproducible, cross-language compatible hash. It | ||||||||||||||||||||||||||||||||||||
| serializes objects using canonical CBOR and hashes them with SHA-256.""" | ||||||||||||||||||||||||||||||||||||
| serializes objects using canonical CBOR and hashes them with SHA-256.\n | ||||||||||||||||||||||||||||||||||||
| - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster, | ||||||||||||||||||||||||||||||||||||
| non-cryptographic hashing. Requires the optional ``xxhash`` package.\n | ||||||||||||||||||||||||||||||||||||
| - "xxhash_cbor" combines canonical CBOR serialization with xxHash for | ||||||||||||||||||||||||||||||||||||
| reproducible hashing. Requires the optional ``xxhash`` package.""" | ||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
| cpu_offload_gb: float = Field(default=0, ge=0) | ||||||||||||||||||||||||||||||||||||
| """The space in GiB to offload to CPU, per GPU. Default is 0, which means | ||||||||||||||||||||||||||||||||||||
| no offloading. Intuitively, this argument can be seen as a virtual way to | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,14 @@ | |||||||||||
|
|
||||||||||||
| import cbor2 | ||||||||||||
|
|
||||||||||||
| try: | ||||||||||||
| import xxhash as _xxhash | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be helpful to add a comment here that it is important for this to remain an optional dependency, as it would be considered problematic to include at all in environments with strict security controls.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| if not hasattr(_xxhash, "xxh3_128_digest"): | ||||||||||||
| _xxhash = None | ||||||||||||
| except ImportError: # pragma: no cover | ||||||||||||
| _xxhash = None | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def sha256(input: Any) -> bytes: | ||||||||||||
| """Hash any picklable Python object using SHA-256. | ||||||||||||
|
|
@@ -46,6 +54,27 @@ def sha256_cbor(input: Any) -> bytes: | |||||||||||
| return hashlib.sha256(input_bytes).digest() | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def _xxhash_digest(input_bytes: bytes) -> bytes: | ||||||||||||
| if _xxhash is None: | ||||||||||||
| raise ModuleNotFoundError( | ||||||||||||
| "xxhash is required for the 'xxhash' prefix caching hash algorithms. " | ||||||||||||
| "Install it via `pip install xxhash`." | ||||||||||||
| ) | ||||||||||||
| return _xxhash.xxh3_128_digest(input_bytes) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def xxhash(input: Any) -> bytes: | ||||||||||||
| """Hash picklable objects using xxHash.""" | ||||||||||||
| input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) | ||||||||||||
| return _xxhash_digest(input_bytes) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def xxhash_cbor(input: Any) -> bytes: | ||||||||||||
| """Hash objects serialized with CBOR using xxHash.""" | ||||||||||||
| input_bytes = cbor2.dumps(input, canonical=True) | ||||||||||||
| return _xxhash_digest(input_bytes) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: | ||||||||||||
| """Get a hash function by name, or raise an error if the function is not found. | ||||||||||||
|
|
@@ -59,5 +88,9 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: | |||||||||||
| return sha256 | ||||||||||||
| if hash_fn_name == "sha256_cbor": | ||||||||||||
| return sha256_cbor | ||||||||||||
| if hash_fn_name == "xxhash": | ||||||||||||
| return xxhash | ||||||||||||
| if hash_fn_name == "xxhash_cbor": | ||||||||||||
| return xxhash_cbor | ||||||||||||
|
|
||||||||||||
| raise ValueError(f"Unsupported hash function: {hash_fn_name}") | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it's worth mentioning these scripts in the documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out! It will be helpful for evaluating the hash algorithms used in prefix caching. I will update the relevant documentation accordingly.