
Commit 0e3f06f

bigPYJ1151, jikunshang, and zhouyuan authored
[Hardware][Intel] Add CPU inference backend (#3634)
Co-authored-by: Kunshang Ji <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
1 parent eb69d68 commit 0e3f06f

24 files changed: +2747 -5 lines

.buildkite/run-cpu-test.sh

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# This script builds the CPU docker image and runs the offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-check cpu-test python3 examples/offline_inference.py
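The container check above drives examples/offline_inference.py. For orientation, here is a minimal sketch of that kind of offline run using vLLM's public LLM/SamplingParams API; it is illustrative only (the model name is an arbitrary small model, and the script's exact contents are not reproduced in this diff):

# Minimal sketch of an offline-inference sanity check (not the literal
# contents of examples/offline_inference.py). The model name below is an
# illustrative choice, not mandated by this commit.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # a small model keeps the CPU check fast
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)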

.buildkite/test-template.j2

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,9 @@ steps:
     queue: amd
   command: bash .buildkite/run-amd-test.sh
 
+- label: "CPU Test"
+  command: bash .buildkite/run-cpu-test.sh
+
 - label: ":docker: build image"
   commands:
     - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)
 
 project(vllm_extensions LANGUAGES CXX)
 
+option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
+
 message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
 
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 
@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
 find_library(torch_python_LIBRARY torch_python PATHS
   "${TORCH_INSTALL_PREFIX}/lib")
 
+#
+# Forward the non-CUDA device extensions to external CMake scripts.
+#
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
+    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
+    if (VLLM_TARGET_DEVICE STREQUAL "cpu")
+        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
+    else()
+        message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
+    endif()
+    return()
+endif()
+
 #
 # Set up GPU language and check the torch version and warn if it isn't
 # what is expected.
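The new VLLM_TARGET_DEVICE option is what routes the build to cmake/cpu_extension.cmake. How the option gets set is not visible in this excerpt (the commit touches 24 files, and the accompanying setup.py changes are not shown), so the snippet below is only a hypothetical sketch of forwarding an environment variable to CMake:

# Hypothetical sketch: forwarding a VLLM_TARGET_DEVICE environment variable
# to CMake as -DVLLM_TARGET_DEVICE. The function name and call site are
# illustrative; the commit's real build-script changes are not shown here.
import os
import subprocess

def configure(build_dir: str = "build") -> None:
    target = os.environ.get("VLLM_TARGET_DEVICE", "cuda")
    os.makedirs(build_dir, exist_ok=True)
    subprocess.check_call([
        "cmake", "-S", ".", "-B", build_dir,
        f"-DVLLM_TARGET_DEVICE={target}",
    ])

if __name__ == "__main__":
    configure()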

Dockerfile.cpu

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# This vLLM Dockerfile is used to construct an image that can build and run vLLM on the x86 CPU platform.

FROM ubuntu:22.04

RUN apt-get update -y \
    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

RUN pip install --upgrade pip \
    && pip install wheel packaging ninja "setuptools>=49.4.0" numpy

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu

RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

CMD ["/bin/bash"]
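A quick, hypothetical smoke check that could be run inside the resulting container to confirm the CPU-only PyTorch wheel and the installed vllm package; it is not part of the commit:

# Hypothetical post-build smoke check: confirm a CPU-only PyTorch wheel is
# installed (no CUDA runtime baked in) and that the vllm package imports.
import torch
import vllm

assert torch.version.cuda is None, "expected the CPU wheel of PyTorch"
print("torch", torch.__version__, "| vllm", vllm.__version__)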

cmake/cpu_extension.cmake

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

#
# Define environment variables for special configurations
#
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
    set(ENABLE_AVX512BF16 ON)
endif()

include_directories("${CMAKE_SOURCE_DIR}/csrc")

#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
    "-fopenmp"
    "-DVLLM_CPU_EXTENSION")

execute_process(COMMAND cat /proc/cpuinfo
                RESULT_VARIABLE CPUINFO_RET
                OUTPUT_VARIABLE CPUINFO)

if (NOT CPUINFO_RET EQUAL 0)
    message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
endif()

function (find_isa CPUINFO TARGET OUT)
    string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
    if(NOT ISA_FOUND EQUAL -1)
        set(${OUT} ON PARENT_SCOPE)
    else()
        set(${OUT} OFF PARENT_SCOPE)
    endif()
endfunction()

find_isa(${CPUINFO} "avx512f" AVX512_FOUND)

if (AVX512_FOUND)
    list(APPEND CXX_COMPILE_FLAGS
        "-mavx512f"
        "-mavx512vl"
        "-mavx512bw"
        "-mavx512dq")

    find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
    if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
            CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
            list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
        else()
            message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
        endif()
    else()
        message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
    endif()
else()
    message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")


#
# Define extension targets
#

#
# _C extension
#
set(VLLM_EXT_SRC
    "csrc/cpu/activation.cpp"
    "csrc/cpu/attention.cpp"
    "csrc/cpu/cache.cpp"
    "csrc/cpu/layernorm.cpp"
    "csrc/cpu/pos_encoding.cpp"
    "csrc/cpu/pybind.cpp")

define_gpu_extension_target(
    _C
    DESTINATION vllm
    LANGUAGE CXX
    SOURCES ${VLLM_EXT_SRC}
    COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
    WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
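The ISA detection above simply searches /proc/cpuinfo for flag substrings via find_isa. An equivalent, purely illustrative pre-build check in Python (not part of the commit):

# Hypothetical pre-build check mirroring the find_isa() logic above: look for
# ISA flag substrings in /proc/cpuinfo (Linux only).
def has_isa(flag: str) -> bool:
    with open("/proc/cpuinfo") as f:
        return flag in f.read()

if __name__ == "__main__":
    print("avx512f    :", has_isa("avx512f"))
    print("avx512_bf16:", has_isa("avx512_bf16"))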

csrc/cpu/activation.cpp

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
#include "cpu_types.hpp"

namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
          bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
                       scalar_t *__restrict__ output) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

  TORCH_CHECK(d % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    for (int j = 0; j < d; j += VEC_ELEM_NUM) {
      int start = i * d;
      if constexpr (is_gated) {
        start *= 2;
      }

      const scalar_vec_t x(input + start + j);
      const vec_op::FP32Vec8 f32_x(x);
      vec_op::FP32Vec8 f32_ans = func(f32_x);

      if constexpr (is_gated) {
        const scalar_vec_t y(input + start + d + j);
        const vec_op::FP32Vec8 f32_y(y);
        f32_ans = f32_y * f32_ans;
      }

      const scalar_vec_t result(f32_ans);
      result.save(output + i * d + j);
    }
  }
}

FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 zeros(0.0);
  const vec_op::FP32Vec8 ones(1.0);
  return x / (ones + (zeros - x).exp());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 x3 = x * x * x;
  const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
  return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
  return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT1_2);
  const vec_op::FP32Vec8 w2(0.5);
  return x * w2 * (ones + (x * w1).er());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
  const vec_op::FP32Vec8 w2(0.5);
  const vec_op::FP32Vec8 w3(0.044715);
  const vec_op::FP32Vec8 x_3 = x * x * x;
  const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
  return x * w2 * (ones + inner.tanh());
}
}; // namespace

void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
        activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
                                                    input.data_ptr<scalar_t>(),
                                                    out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
      });
}

void gelu_and_mul(torch::Tensor &out,   // [..., d]
                  torch::Tensor &input) // [..., 2 * d]
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
        activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
                                                    input.data_ptr<scalar_t>(),
                                                    out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
      });
}

void gelu_tanh_and_mul(torch::Tensor &out,   // [..., d]
                       torch::Tensor &input) // [..., 2 * d]
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
        activation_kernel<scalar_t, gelu_tanh_act, true>(
            num_tokens, d, input.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
      });
}

void gelu_new(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_new_impl)
    activation_kernel<scalar_t, gelu_new_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_new_impl)
  });
}

void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_fast_impl)
    activation_kernel<scalar_t, gelu_fast_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
  });
}