Commit 80aa7e9

[Hardware][Intel] Optimize CPU backend and add more performance tips (#4971)
Co-authored-by: Jianan Gu <[email protected]>
1 parent bd43973 commit 80aa7e9

6 files changed: +165 -13 lines


Dockerfile.cpu

Lines changed: 6 additions & 2 deletions
@@ -3,9 +3,13 @@
 FROM ubuntu:22.04 AS cpu-test-1

 RUN apt-get update -y \
-    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
+    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

+RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc
+
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
+
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja "setuptools>=49.4.0" numpy

@@ -21,6 +25,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

 WORKDIR /workspace/

-RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

 CMD ["/bin/bash"]

README.md

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ vLLM is flexible and easy to use with:
 - Tensor parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
-- Support NVIDIA GPUs and AMD GPUs
+- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs
 - (Experimental) Prefix caching support
 - (Experimental) Multi-lora support

docs/source/getting_started/cpu-installation.rst

Lines changed: 21 additions & 2 deletions
@@ -10,6 +10,7 @@ Table of contents:
 #. :ref:`Requirements <cpu_backend_requirements>`
 #. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
 #. :ref:`Build from source <build_cpu_backend_from_source>`
+#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
 #. :ref:`Performance tips <cpu_backend_performance_tips>`

 .. _cpu_backend_requirements:
@@ -18,7 +19,7 @@ Requirements
 ------------

 * OS: Linux
-* Compiler: gcc/g++>=12.3.0 (recommended)
+* Compiler: gcc/g++>=12.3.0 (optional, recommended)
 * Instruction set architecture (ISA) requirement: AVX512 is required.

 .. _cpu_backend_quick_start_dockerfile:
@@ -41,7 +42,7 @@ Quick start using Dockerfile
 Build from source
 -----------------

-- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:
+- First, install recommended compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:

 .. code-block:: console
@@ -70,13 +71,31 @@ Build from source

 - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.

+.. _ipex_guidance:
+
+Intel Extension for PyTorch
+---------------------------
+
+- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
+
+- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.
+
 .. _cpu_backend_performance_tips:

 Performance tips
 -----------------

 - vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.

+- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:
+
+.. code-block:: console
+
+    $ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
+    $ find / -name *libtcmalloc* # find the dynamic link library path
+    $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
+    $ python examples/offline_inference.py # run vLLM
+
 - vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.

 - If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
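Putting the performance tips above together, a minimal offline-inference sketch for the CPU backend could look like the following. The model name and the 40 GB cache size are illustrative assumptions, not values taken from this commit; the environment variable is set before the engine is constructed so the backend picks it up.

# offline_cpu_inference.py -- illustrative sketch, not part of this commit.
import os

# Reserve 40 GB for the KV cache (illustrative; size it to your machine).
os.environ.setdefault("VLLM_CPU_KVCACHE_SPACE", "40")

from vllm import LLM, SamplingParams

# Any small Hugging Face model works for a smoke test; this one is an assumption.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)

outputs = llm.generate(["vLLM on a CPU backend"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)

Running this with LD_PRELOAD pointing at TCMalloc, as in the console block above, combines both tips.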

requirements-cpu.txt

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@
 -r requirements-common.txt

 # Dependencies for x86_64 CPUs
-torch == 2.3.0+cpu
+torch == 2.3.1+cpu
 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
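Since IPEX is optional (the attention backend falls back to vLLM's native paged attention ops when the import fails, as the torch_sdpa.py diff below shows), a quick hedged way to see which path a given environment will take is a sketch like this:

import torch

print("torch:", torch.__version__)  # expect a +cpu build, e.g. 2.3.1+cpu

try:
    import intel_extension_for_pytorch as ipex
    print("IPEX:", ipex.__version__)  # the CPU backend can use the IPEX kernels
except ImportError:
    print("IPEX not installed; falling back to vLLM's native CPU paged attention")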

vllm/attention/backends/torch_sdpa.py

Lines changed: 16 additions & 7 deletions
@@ -8,8 +8,16 @@

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
-from vllm.attention.ops.paged_attn import (PagedAttention,
-                                           PagedAttentionMetadata)
+from vllm.attention.ops.paged_attn import PagedAttentionMetadata
+from vllm.utils import is_cpu
+
+if is_cpu():
+    try:
+        from vllm.attention.ops.ipex_attn import PagedAttention
+    except ImportError:
+        from vllm.attention.ops.paged_attn import PagedAttention
+else:
+    from vllm.attention.ops.paged_attn import PagedAttention


 class TorchSDPABackend(AttentionBackend):
@@ -197,13 +205,14 @@ def forward(
                                     attn_metadata.attn_bias):
                end = start + seq_len
                sub_out = scaled_dot_product_attention(
-                   query[:, start:end, :],
-                   key[:, start:end, :],
-                   value[:, start:end, :],
+                   query[None, :, start:end, :],
+                   key[None, :, start:end, :],
+                   value[None, :, start:end, :],
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=not self.need_mask,
-                   scale=self.scale).movedim(query.dim() - 2, 0)
+                   scale=self.scale).squeeze(0).movedim(
+                       query.dim() - 2, 0)
                output[start:end, :, :] = sub_out
                start = end
            else:
@@ -248,7 +257,7 @@ def _make_alibi_bias(

         num_heads = alibi_slopes.shape[0]
         bias = bias[None, :].repeat((num_heads, 1, 1))
-        bias.mul_(alibi_slopes[:, None, None])
+        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
         inf_mask = torch.empty(
             (1, seq_len, seq_len),
             dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
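The shape handling in forward() is the subtle part of this file's change: each head-major slice now gains a leading batch dimension before torch.nn.functional.scaled_dot_product_attention and drops it again before being written back token-major. A standalone sketch of that call pattern with toy sizes (plain PyTorch only, causal masking instead of the ALiBi bias used above):

import torch
from torch.nn.functional import scaled_dot_product_attention

# Toy sizes; the real values come from the model config.
num_heads, seq_len, head_dim = 8, 16, 64

# torch_sdpa keeps q/k/v head-major at this point: (num_heads, num_tokens, head_dim).
query = torch.randn(num_heads, seq_len, head_dim)
key = torch.randn(num_heads, seq_len, head_dim)
value = torch.randn(num_heads, seq_len, head_dim)

# Add a leading batch dim (as query[None, ...] does in the diff), run SDPA,
# then squeeze the batch dim and move the token dim back to the front.
sub_out = scaled_dot_product_attention(
    query[None, :, :, :],
    key[None, :, :, :],
    value[None, :, :, :],
    dropout_p=0.0,
    is_causal=True,
    scale=head_dim**-0.5).squeeze(0).movedim(query.dim() - 2, 0)

# Token-major result, ready to be copied into output[start:end, :, :].
assert sub_out.shape == (seq_len, num_heads, head_dim)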

vllm/attention/ops/ipex_attn.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+from typing import Dict, List, Optional, Tuple
+
+import intel_extension_for_pytorch.llm.modules as ipex_modules
+import torch
+
+from vllm import _custom_ops as ops
+
+
+class PagedAttention:
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 80, 96, 112, 128, 256]
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        *args,
+    ) -> Tuple[int, ...]:
+        return (2, num_blocks, block_size * num_kv_heads * head_size)
+
+    @staticmethod
+    def split_kv_cache(
+        kv_cache: torch.Tensor,
+        num_kv_heads: int,
+        head_size: int,
+        *args,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_blocks = kv_cache.shape[1]
+
+        key_cache = kv_cache[0]
+        key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
+        value_cache = kv_cache[1]
+        value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
+        return key_cache, value_cache
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        kv_scale: float,
+        *args,
+    ) -> None:
+        ipex_modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache,
+            slot_mapping.flatten().int())
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_context_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        kv_scale: float,
+        *args,
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+        block_size = value_cache.shape[2]
+        head_mapping = torch.arange(
+            0,
+            num_kv_heads,
+            device="cpu",
+            dtype=torch.int32,
+        ).view(num_kv_heads,
+               1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
+        ipex_modules.PagedAttention.single_query_cached_kv_attention(
+            output, query.contiguous(), key_cache, value_cache, head_mapping,
+            scale, block_tables, context_lens, block_size, max_context_len,
+            alibi_slopes)
+
+        return output
+
+    @staticmethod
+    def forward_prefix(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        subquery_start_loc: torch.Tensor,
+        prompt_lens_tensor: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_subquery_len: int,
+        alibi_slopes: Optional[torch.Tensor],
+        *args,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+        *args,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+        *args,
+    ) -> None:
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
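The cache layout this new backend uses is compact: per layer, the KV cache is a single tensor of shape (2, num_blocks, block_size * num_kv_heads * head_size), and split_kv_cache only re-views that flat buffer. A small standalone sketch of the round trip, using toy sizes and plain PyTorch (it mirrors get_kv_cache_shape and split_kv_cache above rather than importing the IPEX-dependent module):

import torch

# Toy sizes; real values come from the model and cache config.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# Shape produced by PagedAttention.get_kv_cache_shape(...).
kv_cache = torch.zeros(2, num_blocks, block_size * num_kv_heads * head_size)

# Same views that PagedAttention.split_kv_cache(...) returns.
key_cache = kv_cache[0].view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1].view(num_blocks, num_kv_heads, -1, head_size)

# Each cache views the flat buffer as (num_blocks, num_kv_heads, block_size,
# head_size); writes through the view land in the original kv_cache tensor.
assert key_cache.shape == (num_blocks, num_kv_heads, block_size, head_size)
key_cache[0, 0, 0, 0] = 1.0
assert kv_cache[0, 0, 0] == 1.0

Note that value_cache.shape[2] is the block_size that forward_decode reads back before handing everything to the IPEX kernel.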
