
Commit d51e6fa

jianan-gu and chunyuan-w authored and committed
[inductor][cpp] Add FlexAttention support for CPU inference (pytorch#141453)
This PR adds FlexAttention inference support for the inductor backend of torch.compile on CPU (supported precisions: bf16 and fp32). Building on the existing CPP template, it implements a FlexAttention CPP template that covers a broad range of attention variants while delivering optimized performance on CPU. With this change, users can transparently extend their FlexAttention usage to CPU through torch.compile, with both functional coverage and good performance.

For unit tests, this PR enables a subset of the critical tests for CPU (inference only):

```
pytest test/inductor/test_flex_attention.py

TestFlexAttention
  # common functions:
  run_test
  preprocess_paged_attention
  run_paged_attention
  run_test_with_paged_attention
  run_test_with_call
  run_dynamic_test
  run_automatic_dynamic_test
  # test functions:
  test_builtin_score_mods
  test_builtin_score_mods_automatic_dynamic
  test_builtin_score_mods_different_seqlen
  test_builtin_score_mods_different_block_size
  test_kv_batch_broadcast
  test_GQA
  test_cpu_error_message_return_lse
  test_validate_cpu_dtype_error_message

TestPagedAttention
  # test function:
  test_paged_builtin_score_mods
```

The remaining unit tests in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py` involve a larger change (1500+ LOC) that would make this PR hard to review; they will be enabled and refactored for the CPU device in a separate PR. Further optimizations are also planned in follow-up PRs, including:

- Block-sparse computation
- Flash-decoding tuning

Pull Request resolved: pytorch#141453
Approved by: https://github.com/drisspg, https://github.com/leslie-fang-intel

Co-authored-by: Wu, Chunyuan <[email protected]>
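For context, a minimal sketch (not part of this commit) of how the feature might be exercised from user code on CPU, assuming a PyTorch build that includes this PR; the causal `score_mod` is the standard FlexAttention example and the tensor shapes are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal_mask(score, b, h, q_idx, kv_idx):
    # Standard score_mod example: mask out future key positions.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

# CPU tensors in bf16, one of the two precisions this PR supports (fp32 also works).
q, k, v = (torch.randn(1, 8, 128, 64, dtype=torch.bfloat16) for _ in range(3))

compiled_flex = torch.compile(flex_attention)  # lowered via the inductor CPP template on CPU
with torch.no_grad():  # inference-only support in this PR
    out = compiled_flex(q, k, v, score_mod=causal_mask)
```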
1 parent 5ba61d7 · commit d51e6fa

File tree: 6 files changed (+1794, -175 lines)


aten/src/ATen/native/CPUBlas.cpp

Lines changed: 70 additions & 14 deletions
@@ -1125,6 +1125,9 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
     if (dtype == ScalarType::Half) {
       static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
       return fp16_support;
+    } else if (dtype == ScalarType::Float) {
+      static bool fp32_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx2;
+      return fp32_support;
     } else if (dtype == ScalarType::BFloat16) {
       static bool bf16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core;
       return bf16_support;
@@ -1192,18 +1195,29 @@ void brgemm(
     int64_t ld_b,
     int64_t ld_c,
     const bool add_C,
-    const at::Half* A,
-    const at::Half* B,
-    float* C) {
+    const float* A,
+    const float* B,
+    float* C,
+    bool is_vnni) {
+
+  TORCH_CHECK(!is_vnni,
+      "Float Brgemm does not support vnni layout.");
+
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  if (Brgemm::device_check(ScalarType::Half)) {
-    Brgemm::call<at::Half, at::Half, float>(
+  if (Brgemm::device_check(ScalarType::Float)) {
+    Brgemm::call<float, float, float>(
         M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
     return;
   }
 #endif
-  TORCH_CHECK(false,
-      "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
+  // fallback path
+  auto beta = add_C ? 1 : 0;
+  gemm(
+      at::native::TransposeType::NoTranspose,
+      at::native::TransposeType::NoTranspose,
+      N, M, K, 1,
+      B, ld_b, A, ld_a,
+      beta, C, ld_c);
 }
 
 void brgemm(
@@ -1216,22 +1230,64 @@ void brgemm(
     const bool add_C,
     const at::BFloat16* A,
     const at::BFloat16* B,
-    float* C) {
+    float* C,
+    bool is_vnni) {
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  if (Brgemm::device_check(ScalarType::BFloat16)) {
+  if (is_vnni && Brgemm::device_check(ScalarType::BFloat16)) {
     Brgemm::call<at::BFloat16, at::BFloat16, float>(
         M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
     return;
   }
 #endif
-  TORCH_CHECK(false,
-      "BFloat16 Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512 is supported");
+  // fallback path
+  TORCH_CHECK(!is_vnni,
+      "BFloat16 Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
+  auto beta = add_C ? 1 : 0;
+  gemm(
+      at::native::TransposeType::NoTranspose,
+      at::native::TransposeType::NoTranspose,
+      N, M, K, 1,
+      B, ld_b, A, ld_a,
+      beta, C, ld_c);
+}
+
+void brgemm(
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t ld_a,
+    int64_t ld_b,
+    int64_t ld_c,
+    const bool add_C,
+    const at::Half* A,
+    const at::Half* B,
+    float* C,
+    bool is_vnni) {
+#if defined(ONEDNN_UKERNEL_ENABLED)
+  if (is_vnni && Brgemm::device_check(ScalarType::Half)) {
+    Brgemm::call<at::Half, at::Half, float>(
+        M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
+    return;
+  }
+#endif
+  // fallback path
+  TORCH_CHECK(!is_vnni,
+      "Half Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx_fp16` is supported");
+  auto beta = add_C ? 1 : 0;
+  gemm(
+      at::native::TransposeType::NoTranspose,
+      at::native::TransposeType::NoTranspose,
+      N, M, K, 1,
+      B, ld_b, A, ld_a,
+      beta, C, ld_c);
 }
 
-void brgemm_release() {
+void brgemm_release(bool is_vnni) {
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  dnnl::ukernel::brgemm::release_hw_context();
-  Brgemm::get_current() = nullptr;
+  if (is_vnni) {
+    dnnl::ukernel::brgemm::release_hw_context();
+    Brgemm::get_current() = nullptr;
+  }
 #endif
 }
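One note on the fallback above: `gemm` follows the column-major BLAS convention, so the row-major product C = A·B is obtained by swapping the operands and the M/N dimensions (hence the `N, M, K` and `B, A` ordering), with `beta` selecting whether the existing C buffer is accumulated. Below is a minimal NumPy sketch of the semantics the fallback is expected to produce; the function name and shapes are illustrative, not part of the PR:

```python
import numpy as np

def brgemm_fallback_reference(A, B, C, add_C=False):
    # Reference semantics of the fallback: C_out = beta * C + A @ B,
    # with beta = 1 when add_C else 0, all buffers row-major, fp32 accumulation.
    beta = 1.0 if add_C else 0.0
    return beta * C + A.astype(np.float32) @ B.astype(np.float32)

# Example: (4 x 16) @ (16 x 8), accumulating into an existing C buffer.
A = np.random.rand(4, 16).astype(np.float32)
B = np.random.rand(16, 8).astype(np.float32)
C = np.ones((4, 8), dtype=np.float32)
C = brgemm_fallback_reference(A, B, C, add_C=True)
```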

aten/src/ATen/native/CPUBlas.h

Lines changed: 19 additions & 4 deletions
@@ -204,7 +204,8 @@ TORCH_API void brgemm(
     const bool add_C,
     const at::Half* A,
     const at::Half* B,
-    float* C);
+    float* C,
+    bool is_vnni = true);
 
 TORCH_API void brgemm(
     int64_t M,
@@ -216,10 +217,24 @@ TORCH_API void brgemm(
     const bool add_C,
     const at::BFloat16* A,
     const at::BFloat16* B,
-    float* C);
+    float* C,
+    bool is_vnni = true);
+
+TORCH_API void brgemm(
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t ld_a,
+    int64_t ld_b,
+    int64_t ld_c,
+    const bool add_C,
+    const float* A,
+    const float* B,
+    float* C,
+    bool is_vnni = false);
 
 // Release brgemm hardware context
-TORCH_API void brgemm_release();
+TORCH_API void brgemm_release(bool is_vnni = true);
 
 // Pack B matrix to get better performance if needed
 void pack(
@@ -233,6 +248,6 @@ void pack(
     void* out);
 
 // Whether pack is supported in the platform.
-bool could_pack(ScalarType dt_in);
+TORCH_API bool could_pack(ScalarType dt_in);
 
 } // namespace at::native::cpublas
