Skip to content

Commit 20852c8

Browse files
authored
[CPU] Refactor CPU WNA16 (#28826)
Signed-off-by: jiang1.li <[email protected]>
1 parent 40b6b38 commit 20852c8

File tree

22 files changed

+1656
-78
lines changed

22 files changed

+1656
-78
lines changed

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,11 @@ function cpu_tests() {
7373
pytest -x -s -v \
7474
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs"
7575

76-
# Note: disable it until supports V1
77-
# Run AWQ test
78-
# docker exec cpu-test-"$NUMA_NODE" bash -c "
79-
# set -e
80-
# pytest -x -s -v \
81-
# tests/quantization/test_ipex_quant.py"
76+
# Run AWQ/GPTQ test
77+
docker exec cpu-test-"$NUMA_NODE" bash -c "
78+
set -e
79+
pytest -x -s -v \
80+
tests/quantization/test_cpu_wna16.py"
8281

8382
# Run multi-lora tests
8483
docker exec cpu-test-"$NUMA_NODE" bash -c "

cmake/cpu_extension.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ set(VLLM_EXT_SRC
375375
if (AVX512_FOUND AND NOT AVX512_DISABLED)
376376
set(VLLM_EXT_SRC
377377
"csrc/cpu/shm.cpp"
378+
"csrc/cpu/cpu_wna16.cpp"
378379
${VLLM_EXT_SRC})
379380
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
380381
set(VLLM_EXT_SRC

csrc/cpu/cpu_attn_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef CPU_ATTN_HPP
22
#define CPU_ATTN_HPP
33

4-
#include <unistd.h>
54
#include <type_traits>
65
#include <cstddef>
76

@@ -12,6 +11,7 @@
1211
#include "cpu_types.hpp"
1312
#include "scratchpad_manager.h"
1413
#include "cpu_attn_macros.h"
14+
#include "utils.hpp"
1515

1616
namespace cpu_attention {
1717
enum class ISA { AMX, VEC, VEC16 };

csrc/cpu/cpu_types_x86.hpp

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
104104
// Loads 256 bits (16 FP16 lanes) from ptr with a streaming (non-temporal)
// load. The unnamed bool is never read; it only selects this overload.
explicit FP16Vec16(bool, void* ptr)
105105
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
106106

107+
// Broadcasts v.x into all 16 lanes of the register (v.x is presumably the
// raw 16-bit encoding of the half value — confirm against c10::Half).
explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
108+
107109
explicit FP16Vec16(const FP32Vec16&);
108110

109111
// Writes the full 256-bit register to ptr using an unaligned store.
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
141143
// Loads 256 bits (16 BF16 lanes) from ptr with a streaming (non-temporal)
// load. The unnamed bool is never read; it only selects this overload.
explicit BF16Vec16(bool, void* ptr)
142144
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
143145

146+
// Broadcasts v.x into all 16 lanes of the register (v.x is presumably the
// raw 16-bit encoding of the bfloat16 value — confirm against c10::BFloat16).
explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
147+
144148
explicit BF16Vec16(const FP32Vec16&);
145149

146150
// Writes the full 256-bit register to ptr using an unaligned store.
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
350354

351355
// Wraps an existing 512-bit float register as-is; no conversion is done.
explicit FP32Vec16(__m512 data) : reg(data) {}
352356

357+
// De-packs the sixteen 4-bit values held in the 8 bytes of `value` and maps
// each one through `lut`. Counting bytes from the least-significant one,
// output lane 2*i comes from the low nibble of byte i and lane 2*i+1 from
// the high nibble of byte i; each nibble (0..15) selects one of the 16
// floats in `lut`.
358+
explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
359+
// Split every byte into its low nibble and its high nibble (the latter is
// still sitting in bits 4..7 of the byte at this point).
int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
360+
int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
361+
int64_t value_0 = value & mask_0;
362+
int64_t value_1 = value & mask_1;
363+
// Place each 64-bit half in the low part of a 128-bit register.
__m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
364+
__m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
365+
// Zero-extend the 8 bytes to 8 x 16-bit lanes so that every source byte
// has room for two byte-sized nibble values.
vec_0 = _mm_cvtepu8_epi16(vec_0);
366+
vec_1 = _mm_cvtepu8_epi16(vec_1);
367+
// Move the high nibble from bits 4..7 to bits 8..11, i.e. into the upper
// byte of its 16-bit lane.
vec_1 = _mm_slli_epi16(vec_1, 4);
368+
// After the OR, the 16 bytes each hold exactly one nibble: low nibble in
// the lower byte, high nibble in the upper byte of every 16-bit lane.
__m128i vec = _mm_or_si128(vec_0, vec_1);
369+
// Widen the 16 byte-sized indices to 32-bit lane indices.
__m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
370+
// Table lookup: result lane j takes the element of lut selected by index j.
reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
371+
}
372+
353373
explicit FP32Vec16(const FP32Vec4& data)
354374
: reg((__m512)_mm512_inserti32x4(
355375
_mm512_inserti32x4(
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
426446

427447
// Returns the lowest lane (element 0) of the register, which is what
// _mm512_cvtss_f32 extracts. NOTE(review): the name says "last" while the
// intrinsic reads the first lane — confirm this is what callers expect.
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
428448

429-
// Sums one contiguous group of lanes: the lanes are viewed as consecutive
// groups of `group_size` elements and the members of group `idx` are added
// together. `group_size` must divide VEC_ELEM_NUM.
template <int group_size>
430-
float reduce_sub_sum(int idx) {
431-
static_assert(VEC_ELEM_NUM % group_size == 0);
432-
// Mask with the low `group_size` bits set.
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
433-
// Slide the run of ones so it covers the lanes of group `idx`.
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
434-
// Masked horizontal add: lanes outside the mask do not contribute.
return _mm512_mask_reduce_add_ps(mask, reg);
435-
}
436-
437449
// Writes all 16 floats of the register to ptr using an unaligned store.
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
438450

439451
void save(float* ptr, const int elem_num) const {
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
755767
// Writes the 16 floats of vec to ptr with a non-temporal (streaming) store.
// Per the Intel intrinsic reference, _mm512_stream_ps requires ptr to be
// 64-byte aligned.
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
756768
_mm512_stream_ps((float*)ptr, vec.reg);
757769
}
770+
771+
// Interleaves two 16-lane BF16 vectors and writes the resulting 32 16-bit
// values (64 bytes) to ptr with an unaligned store. Each 32-bit word holds
// vec0[i] in its low half and vec1[i] in its high half, so on little-endian
// x86 the memory reads vec0[0], vec1[0], vec0[1], vec1[1], ...
// Declared `inline` (like the neighbouring non_temporal_save helpers) rather
// than `static`: this lives in a header, and `static` would give every
// translation unit its own copy and can trigger unused-function warnings.
inline void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
772+
void* ptr) {
773+
// Zero-extend each 16-bit lane to 32 bits.
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
774+
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
775+
// Move vec1's values into the upper 16 bits of every 32-bit word.
vec_1 = _mm512_slli_epi32(vec_1, 16);
776+
// Combine: low half from vec0, high half from vec1.
vec_0 = _mm512_or_si512(vec_0, vec_1);
777+
_mm512_storeu_epi32(ptr, vec_0);
778+
}
779+
780+
// Interleaves two 16-lane FP16 vectors and writes the resulting 32 16-bit
// values (64 bytes) to ptr with an unaligned store. Each 32-bit word holds
// vec0[i] in its low half and vec1[i] in its high half, so on little-endian
// x86 the memory reads vec0[0], vec1[0], vec0[1], vec1[1], ...
// Declared `inline` (like the neighbouring non_temporal_save helpers) rather
// than `static`: this lives in a header, and `static` would give every
// translation unit its own copy and can trigger unused-function warnings.
inline void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
781+
void* ptr) {
782+
// Zero-extend each 16-bit lane to 32 bits.
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
783+
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
784+
// Move vec1's values into the upper 16 bits of every 32-bit word.
vec_1 = _mm512_slli_epi32(vec_1, 16);
785+
// Combine: low half from vec0, high half from vec1.
vec_0 = _mm512_or_si512(vec_0, vec_1);
786+
_mm512_storeu_epi32(ptr, vec_0);
787+
}
788+
758789
#endif
759790

760791
// Issues a full memory fence (MFENCE) via _mm_mfence.
inline void mem_barrier() { _mm_mfence(); }

0 commit comments

Comments
 (0)