diff --git a/benchmark/benchmark_reshape_and_cache.py b/benchmark/benchmark_reshape_and_cache.py
index c6d96e7..6fdf56e 100644
--- a/benchmark/benchmark_reshape_and_cache.py
+++ b/benchmark/benchmark_reshape_and_cache.py
@@ -1,82 +1,161 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
+import itertools
 import random
-import time
+from typing import Optional
 
 import torch
-from tabulate import tabulate
+import triton
+from torch import Tensor
 
-from tests import register_ops as ops
-from tests.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
+from tests import register_ops as vllm_ops
+from tests.utils import (check_ipex_availability, create_kv_caches_with_random,
+                         parse_args)
 
+HAS_IPEX = check_ipex_availability()
 
-@torch.inference_mode()
-def run_benchmark(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
+if HAS_IPEX:
+    import intel_extension_for_pytorch as ipex
+
+
+def reshape_and_cache_vllm(
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    slot_mapping: Tensor,
     kv_cache_dtype: str,
-    num_iters: int,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+) -> None:
+    """vLLM's fused kernel for reshaping and caching K/V tensors."""
+    vllm_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                               slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+
+def reshape_and_cache_ipex(
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    slot_mapping: Tensor,
+    kv_cache_dtype: str,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+) -> None:
+    """IPEX native implementation using ipex.llm.modules.PagedAttention."""
+    if not HAS_IPEX:
+        raise RuntimeError("IPEX is not available")
+    assert kv_cache_dtype == "auto", "IPEX reshape_and_cache uses 'auto' mode"
+
+    ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache,
+                                                      value_cache,
+                                                      slot_mapping)
+
+
+def get_benchmark(
+    dtype: torch.dtype,
     device: str = "xpu",
-) -> float:
-    """Return latency (seconds) for given num_tokens."""
-
-    if kv_cache_dtype == "fp8" and head_size % 16:
-        raise ValueError(
-            "fp8 kv-cache requires head_size to be a multiple of 16.")
-
-    seed = 42
-    random.seed(seed)
-    torch.manual_seed(seed)
-    torch.set_default_device(device)
-
-    # create random key / value tensors [T, H, D].
-    key = torch.randn(num_tokens,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device=device)
-    value = torch.randn_like(key)
-
-    # prepare the slot mapping.
-    # each token is assigned a unique slot in the KV-cache.
-    num_slots = block_size * num_blocks
-    if num_tokens > num_slots:
-        raise ValueError(
-            "num_tokens cannot exceed the total number of cache slots")
-    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping_lst,
-                                dtype=torch.long,
-                                device=device)
-
-    num_layers = 1  # for simplicity, we use a single layer
-    key_caches, value_caches = create_kv_caches_with_random(
-        num_blocks,
-        block_size,
-        num_layers,
-        num_heads,
-        head_size,
-        kv_cache_dtype,
-        dtype,
-        device=device,
-    )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+):
+
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=[
+                "num_tokens", "num_heads", "head_size", "block_size",
+                "num_blocks"
+            ],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["vllm", "ipex"] if HAS_IPEX else ["vllm"],
+            line_names=["vLLM", "IPEX"] if HAS_IPEX else ["vLLM"],
+            styles=[("blue", "-"),
+                    ("red", "-")] if HAS_IPEX else [("blue", "-")],
+            ylabel="latency (us)",
+            plot_name="reshape_and_cache-benchmark",
+            args={},
+        ))
+    @torch.inference_mode()
+    def benchmark(num_tokens,
+                  num_heads,
+                  head_size,
+                  block_size,
+                  num_blocks,
+                  provider,
+                  kv_cache_dtype="auto"):
+
+        if kv_cache_dtype == "fp8" and head_size % 16:
+            raise ValueError(
+                "fp8 kv-cache requires head_size to be a multiple of 16.")
+
+        torch.manual_seed(42)
+        torch.set_default_device(device)
+
+        key = torch.randn(num_tokens,
+                          num_heads,
+                          head_size,
+                          dtype=dtype,
+                          device=device)
+        value = torch.randn_like(key)
+        num_slots = block_size * num_blocks
+        if num_tokens > num_slots:
+            raise ValueError(
+                "num_tokens cannot exceed the total number of cache slots")
+        slot_mapping_lst = random.sample(range(num_slots), num_tokens)
+        slot_mapping = torch.tensor(slot_mapping_lst,
+                                    dtype=torch.long,
+                                    device=device)
+
+        num_layers = 1  # for simplicity, we use a single layer
+        key_caches, value_caches = create_kv_caches_with_random(
+            num_blocks,
+            block_size,
+            num_layers,
+            num_heads,
+            head_size,
+            kv_cache_dtype,
+            dtype,
+            device=device,
+        )
+        key_cache, value_cache = key_caches[0], value_caches[0]
 
-    # compute per-kernel scaling factors for fp8 conversion (if used).
-    k_scale = (key.amax() / 64.0).to(torch.float32)
-    v_scale = (value.amax() / 64.0).to(torch.float32)
+        # compute per-kernel scaling factors for fp8 conversion (if used).
+        k_scale = (key.amax() / 64.0).to(torch.float32)
+        v_scale = (value.amax() / 64.0).to(torch.float32)
 
-    def run_xpu_benchmark(n_iters: int) -> float:
-        nonlocal key, value, key_cache, value_cache, slot_mapping
         torch.xpu.synchronize()
-        start = time.perf_counter()
-        for _ in range(n_iters):
-            ops.reshape_and_cache(
+        # Warm up
+        for _ in range(5):
+            if provider == "vllm":
+                reshape_and_cache_vllm(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+            elif provider == "ipex" and HAS_IPEX:
+                reshape_and_cache_ipex(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+
+        # Benchmark
+        quantiles = [0.5, 0.2, 0.8]
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: {
+                "vllm": reshape_and_cache_vllm,
+                "ipex": reshape_and_cache_ipex
+            }[provider](
                 key,
                 value,
                 key_cache,
@@ -85,96 +164,43 @@ def run_xpu_benchmark(n_iters: int) -> float:
                 kv_cache_dtype,
                 k_scale,
                 v_scale,
-            )
-        torch.xpu.synchronize()
-        end = time.perf_counter()
-        return (end - start) / n_iters
-
-    # warm-up
-    run_xpu_benchmark(3)
-
-    lat = run_xpu_benchmark(num_iters)
-
-    # free tensors to mitigate OOM when sweeping
-    del key, value, key_cache, value_cache, slot_mapping
-    torch.xpu.empty_cache()
-
-    return lat
-
-
-def main(args):
-    rows = []
-    for exp in range(1, 12):
-        n_tok = 2**exp
-        lat = run_benchmark(
-            num_tokens=n_tok,
-            num_heads=args.num_heads,
-            head_size=args.head_size,
-            block_size=args.block_size,
-            num_blocks=args.num_blocks,
-            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
-            kv_cache_dtype=args.kv_cache_dtype,
-            num_iters=args.iters,
-            device="xpu",
+            ),
+            quantiles=quantiles,
         )
-        rows.append([
-            n_tok,
-            args.num_heads,
-            args.head_size,
-            args.block_size,
-            args.num_blocks,
-            args.dtype,
-            args.kv_cache_dtype,
-            f"{lat * 1e6:.3f}",
-        ])
-    print(
-        tabulate(
-            rows,
-            headers=[
-                "num_tokens",
-                "num_heads",
-                "head_size",
-                "block_size",
-                "num_blocks",
-                "dtype",
-                "kv_cache_dtype",
-                "latency (us)",
-            ],
-        ))
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark
 
 
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--num-heads", type=int, default=8)
-    parser.add_argument(
-        "--head-size",
-        type=int,
-        choices=[64, 80, 96, 112, 120, 128, 192, 256],
-        default=128,
-    )
-    parser.add_argument("--block-size",
-                        type=int,
-                        choices=[16, 32, 64],
-                        default=64)
-    parser.add_argument("--num-blocks", type=int, default=1024)
-
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        choices=["half", "bfloat16"],
-        default="half",
-    )
-    parser.add_argument(
-        "--kv-cache-dtype",
-        type=str,
-        choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"],
-        default="auto",
+if __name__ == "__main__":
+    args = parse_args()
+
+    device = "xpu"
+
+    print("Benchmark Configuration:")
+    print(f" Num Heads: {args.head_num_range}")
+    print(f" Head Size: {args.head_size}")
+    print(f" Block Size: {args.block_size}")
+    print(f" Num Blocks: {args.num_blocks}")
+    print(f" Data Type: {args.dtype}")
+    print(" KV Cache Dtype: auto (IPEX & vLLM)")
+    print(f" Device: {device}")
+    if HAS_IPEX:
+        print(f"✅ IPEX {ipex.__version__} is available.")
+    else:
+        print("⚠️ IPEX not available. Only benchmarking vLLM.")
+
+    num_token_range = [2**i for i in range(1, 12)]
+    head_num_range = args.head_num_range
+    head_size_range = [args.head_size]
+    block_size_range = [args.block_size]
+    num_blocks_range = [args.num_blocks]
+    configs = list(
+        itertools.product(num_token_range, head_num_range, head_size_range,
                          block_size_range, num_blocks_range))
+
+    benchmark = get_benchmark(
+        dtype=args.dtype,
+        device=device,
     )
-
-    parser.add_argument("--iters", type=int, default=100)
-    args = parser.parse_args()
-
-    main(args)
+    benchmark.run(print_data=True, save_path=None)
diff --git a/benchmark/benchmark_rmsnorm.py b/benchmark/benchmark_rmsnorm.py
index 4513c07..d58742f 100644
--- a/benchmark/benchmark_rmsnorm.py
+++ b/benchmark/benchmark_rmsnorm.py
@@ -9,7 +9,7 @@ from torch import nn
 
 from tests import register_ops as vllm_ops
-from tests.utils import check_ipex_availability, get_model_config
+from tests.utils import check_ipex_availability, parse_args
 
 HAS_IPEX = check_ipex_availability()
 
@@ -268,112 +268,8 @@ def benchmark(head_num, batch_size, seq_len, provider):
     return benchmark
 
 
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=4,
-        help="Batch size",
-    )
-    parser.add_argument(
-        "--seq-len",
-        type=int,
-        default=128,
-        help="Sequence length",
-    )
-    parser.add_argument(
-        "--hidden-size",
-        type=int,
-        default=4096,
-        help="Hidden size (2nd dimension) of the sequence",
-    )
-    parser.add_argument(
-        "--intermediate-size",
-        type=int,
-        default=None,
-        help="Intermediate size for FFN layers",
-    )
-    parser.add_argument(
-        "--num-groups",
-        type=int,
-        default=None,
-        help="Number of expert groups for MoE models",
-    )
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        default=torch.bfloat16,
-        help="Data type from model config",
-    )
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        default=None,
-        help="Model name to load configuration from",
-    )
-    parser.add_argument("--head-num-range",
-                        type=int,
-                        nargs='+',
-                        default=[12, 32, 40, 48, 64, 96, 128],
-                        help=("Range of attention head numbers to test/use. "
" - "Default: 12 32 40 48 64 96 128")) - parser.add_argument( - "--tp-size", - type=int, - default=1, - help="Tensor parallelism size", - ) - parser.add_argument("--use-residual", - action="store_true", - help="Whether to use residual connection") - parser.add_argument( - "--save-path", - type=str, - default="./configs/rmsnorm/", - help="Path to save rmsnorm benchmark results", - ) - - args = parser.parse_args() - - if args.model_name: - model_config = get_model_config(args.model_name, args.tp_size) - - if args.hidden_size == 4096: - args.hidden_size = model_config["hidden_size"] - - if args.intermediate_size is None: - args.intermediate_size = model_config["intermediate_size"] - - if args.num_groups is None: - args.num_groups = model_config["num_groups"] - - if args.dtype is None: - args.dtype = model_config["dtype"] - - if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]: - model_heads = model_config.get("num_attention_heads", 32) - if model_heads not in args.head_num_range: - args.head_num_range.append(model_heads) - args.head_num_range.sort() - print( - f"Added model's head number {model_heads} to head_num_range" - ) - - print(f"Using model configuration from: {args.model_name}") - print(f"Updated hidden_size: {args.hidden_size}") - print(f"Updated intermediate_size: {args.intermediate_size}") - print(f"Updated num_groups: {args.num_groups}") - print(f"Updated head_num_range: {args.head_num_range}") - print(f"Updated dtype: {args.dtype}") - - return args - - if __name__ == "__main__": - import argparse - args = parse_args() print("Final configuration:") diff --git a/tests/utils.py b/tests/utils.py index 0559bac..4c98a49 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import argparse import random import unittest from collections.abc import Sequence @@ -350,4 +351,109 @@ def check_ipex_availability(): return True else: print("Warning: IPEX not available, skipping IPEX benchmarks") - return False \ No newline at end of file + return False + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument( + "--intermediate-size", + type=int, + default=None, + help="Intermediate size for FFN layers", + ) + parser.add_argument( + "--num-groups", + type=int, + default=None, + help="Number of expert groups for MoE models", + ) + parser.add_argument( + "--dtype", + type=str, + default=torch.bfloat16, + help="Data type from model config", + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Model name to load configuration from", + ) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--num-blocks", type=int, default=1024) + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"], + default="auto", + ) + parser.add_argument("--block-size", + type=int, + choices=[16, 32, 64], + default=64) + parser.add_argument("--head-num-range", + type=int, + nargs='+', + default=[12, 32, 40, 48, 64, 96, 128], + help=("Range of attention head numbers to test/use. 
" + "Default: 12 32 40 48 64 96 128")) + parser.add_argument( + "--tp-size", + type=int, + default=1, + help="Tensor parallelism size", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + if args.model_name: + model_config = get_model_config(args.model_name, args.tp_size) + + args.hidden_size = model_config["hidden_size"] + args.intermediate_size = model_config["intermediate_size"] + args.num_groups = model_config["num_groups"] + args.dtype = model_config["dtype"] + args.head_size = model_config["head_dim"] + args.head_num_range = [model_config.get("num_attention_heads", 32)] + + print(f"Using model configuration from: {args.model_name}") + print(f"Updated hidden_size: {args.hidden_size}") + print(f"Updated intermediate_size: {args.intermediate_size}") + print(f"Updated num_groups: {args.num_groups}") + print(f"Updated head_num_range: {args.head_num_range}") + print(f"Updated dtype: {args.dtype}") + + return args