
Commit 7b07434

LuFinch authored and pytorchmergebot committed
[Intel GPU] Support f32 intermediate dtype, headdim size <=576 and f32 causal mask for SDPA (pytorch#152091)
In OneDNN v3.7, SDPA has the following defects:
1. The dtype of intermediate values is the same as QKV, while PyTorch uses FP32 intermediates to ensure better accuracy.
2. Only head dim sizes <= 256 are supported.
3. Implicit causal mask is not supported when QKV is FP32, so an attention mask has to be built explicitly with aten ops.
OneDNN v3.8 addresses these defects. Since these are tiny changes, I decided to put them in a single PR.
Pull Request resolved: pytorch#152091
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/drisspg
1 parent 4d93985 commit 7b07434
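For orientation, a minimal usage sketch of what this commit enables (not part of the diff; it assumes a PyTorch build with XPU support via torch.xpu, and the shapes are made up): fp32 Q/K/V with is_causal=True now route through the oneDNN path by building an explicit causal mask, and head dimensions up to 576 are accepted.

# Illustrative only; device availability and sizes are assumptions, not part of the commit.
import torch
import torch.nn.functional as F

device = "xpu" if torch.xpu.is_available() else "cpu"
# fp32 inputs with is_causal=True, head dim 512 (<= 576) is now accepted on the oneDNN path.
q = torch.randn(1, 2, 128, 512, device=device, dtype=torch.float32)
k = torch.randn(1, 2, 128, 512, device=device, dtype=torch.float32)
v = torch.randn(1, 2, 128, 512, device=device, dtype=torch.float32)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)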

File tree

4 files changed, +54 −52 lines


aten/src/ATen/native/mkldnn/xpu/Attention.cpp

Lines changed: 5 additions & 3 deletions
@@ -24,11 +24,13 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
     }
     return false;
   }
-  if (query_size_last > 256) {
+  constexpr int MAX_HEAD_DIM = 576;
+  if (query_size_last > MAX_HEAD_DIM) {
     if (debug) {
       TORCH_WARN(
-          "OneDNN attention requires q,k,v to have head dimension less than 256.",
-          " Got ",
+          "OneDNN attention requires q,k,v to have head dimension less than ",
+          MAX_HEAD_DIM,
+          ". Got ",
           query_size_last,
           " instead.");
     }
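A rough Python restatement of the head-dim gate above (illustrative only; head_dim_supported is a made-up name, not the dispatcher's actual logic): head dims up to 576 now pass, larger ones bail out to other backends.

MAX_HEAD_DIM = 576  # mirrors the constant introduced above

def head_dim_supported(query_size_last: int) -> bool:
    # The real check also emits a TORCH_WARN when it bails out in debug mode.
    return query_size_last <= MAX_HEAD_DIM

assert head_dim_supported(576)
assert not head_dim_supported(1024)  # the updated test at the end of this diff uses d = 1024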

aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp

Lines changed: 45 additions & 45 deletions
@@ -1,17 +1,25 @@
+#include <ATen/OpMathType.h>
 #include <ATen/native/mkldnn/xpu/detail/Attr.h>
 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
-
 #include <oneapi/dnnl/dnnl.hpp>
 
+namespace {
+
 using namespace at::native::onednn;
 using logical_tensor = dnnl::graph::logical_tensor;
 using data_type = logical_tensor::data_type;
 using dims = logical_tensor::dims;
 using op = dnnl::graph::op;
 using partition = dnnl::graph::partition;
 
-namespace {
+inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) {
+  return scalar_type == c10::ScalarType::Float ? data_type::f32
+      : scalar_type == c10::ScalarType::Half ? data_type::f16
+      : scalar_type == c10::ScalarType::BFloat16 ? data_type::bf16
+                                                 : data_type::undef;
+}
+
 struct SDPALogicalParams {
   enum class TensorID {
     query,
@@ -39,11 +47,7 @@ struct SDPALogicalParams {
      const std::optional<at::Tensor>& attn_mask_,
      const at::Tensor& output_,
      bool is_causal) {
-    const data_type dtype = // to logical_tensor data type
-        query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
-        : query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
-        : query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
-                                                            : data_type::undef;
+    const data_type dtype = to_logical_tensor_data_type(query_.scalar_type());
     TORCH_INTERNAL_ASSERT(
         (dtype != data_type::undef),
         "Only FP16/BF16/FP32 datatypes are currently supported");
@@ -84,22 +88,27 @@
         reshaped_key.strides().vec()};
     scale = {
         static_cast<size_t>(TensorID::scale),
-        dtype,
+        to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
         scalar_shape,
         logical_tensor::layout_type::strided,
         logical_tensor::property_type::constant};
     if (is_causal) {
       neg_inf = {
           static_cast<size_t>(TensorID::neg_inf),
-          dtype,
+          to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
           scalar_shape,
           logical_tensor::layout_type::strided,
           logical_tensor::property_type::constant};
     }
     if (attn_mask_.has_value()) {
+      const data_type mask_dtype =
+          to_logical_tensor_data_type(attn_mask_->scalar_type());
+      TORCH_INTERNAL_ASSERT(
+          (mask_dtype != data_type::undef),
+          "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask");
       attn_mask = {
           static_cast<size_t>(TensorID::attn_mask),
-          dtype,
+          mask_dtype,
          reshaped_attn_mask.sizes().vec(),
          reshaped_attn_mask.strides().vec()};
     }
@@ -147,7 +156,12 @@ partition create_sdpa_graph_partition(
   size_t lt_id = static_cast<size_t>(SDPALogicalParams::TensorID::end);
   size_t op_id = 0;
 
-  logical_tensor matmul_qk_out{lt_id++, dtype};
+  // OneDNN graph has optimized implementation for `f16` or `bf16` SDPA with
+  // `f32` intermediate data type on Intel Graphics Products with Intel(R) Xe
+  // Matrix Extensions (Intel(R) XMX) support, which means the
+  // Q/K/V tensors have bf16 or f16 data type while the output of the first
+  // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type.
+  logical_tensor matmul_qk_out{lt_id++, data_type::f32};
   op matmul_qk{
       op_id++,
       op::kind::MatMul,
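To make the comment above concrete, a hedged Python sketch of the intended numerics: Q/K/V stay in bf16 (or f16) while the first MatMul output, Scale, Mask, and the SoftMax input are kept in f32. This is a reference restatement with made-up shapes, not the oneDNN graph itself.

# Illustrative math reference for the f32 intermediate dtype.
import torch

q = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
k = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
v = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
scale = q.shape[-1] ** -0.5

scores = (q.float() @ k.float().transpose(-2, -1)) * scale  # first MatMul + Scale in f32
attn = torch.softmax(scores, dim=-1)                        # SoftMax consumes f32 input
out = (attn @ v.float()).to(torch.bfloat16)                 # second MatMul, cast back to bf16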
@@ -156,7 +170,7 @@
       "matmul_qk"};
   matmul_qk.set_attr<bool>(op::attr::transpose_b, true);
 
-  logical_tensor scaled_qk_out{lt_id++, dtype};
+  logical_tensor scaled_qk_out{lt_id++, data_type::f32};
   op scale_mul{
       op_id++,
       op::kind::Multiply,
@@ -181,7 +195,7 @@
   if (params.attn_mask.has_value()) {
     TORCH_INTERNAL_ASSERT(
         !is_causal, "Additive mask cannot use with is_causal.");
-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
     mask_add = {
         op_id++,
         op::kind::Add,
@@ -216,7 +230,7 @@
        {mask_gt_out.value()},
        "mask_gt"};
 
-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
     mask_select = {
        op_id++,
        op::kind::Select,
@@ -349,24 +363,16 @@ void gpu_float_sdpa(
        at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
   };
 
-  static bool driver_support_implict_causal = true;
-  if (attn_mask.has_value()) {
-    TORCH_INTERNAL_ASSERT(
-        !is_causal,
-        "scaled_dot_product_fused_attention_overrideable_xpu: "
-        "attn_mask cannot present with is_causal");
-  } else {
-    // Currenetly implict mask only supports square fp16 cases
-    const bool support_implict_causal = driver_support_implict_causal &&
-        (query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
-        seq_len_q == seq_len_k;
-    if (is_causal && !support_implict_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-    }
+  // OneDNN doesn't support fp32 ukernel for implicit causal mask,
+  // and the reference implementation is worse than aten math + explict causal
+  // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32
+  // ukernel for implicit causal mask.
+  if (is_causal && query.dtype() == at::kFloat) {
+    attn_mask = get_tril_mask();
+    is_causal = false;
   }
 
-  std::vector<logical_tensor> l_inputs, l_outputs;
+  std::vector<dnnl::graph::logical_tensor> l_inputs, l_outputs;
   std::optional<dnnl::graph::compiled_partition> compiled_partition;
 
   auto get_compiled_partition = [&]() {
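The fp32 fallback above swaps the implicit causal flag for an explicit lower-triangular mask. A hedged sketch of an equivalent additive mask (the actual get_tril_mask shape and broadcasting live in the C++ code; the sequence lengths here are made up):

# Illustrative explicit causal mask equivalent to is_causal=True.
import torch

seq_len_q, seq_len_k = 128, 128
tril = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril()
attn_mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.float32)
attn_mask = attn_mask.masked_fill(~tril, float("-inf"))
# Passing this additive mask with is_causal=False reproduces the causal behavior.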
@@ -388,24 +394,18 @@
     return compiled_partition;
   };
 
-  // maybe retry without causal mask
-  try {
-    compiled_partition = get_compiled_partition();
-  } catch (std::exception& e) {
-    if (is_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-      compiled_partition = get_compiled_partition();
-      driver_support_implict_causal = false;
-    } else {
-      throw e;
-    }
-  }
+  compiled_partition = get_compiled_partition();
 
-  Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
+  Tensor softmax_scale1 = at::full(
+      {},
+      softmax_scale,
+      query.options().dtype(at::toOpMathType(query.scalar_type())));
   std::optional<at::Tensor> neg_inf;
   if (is_causal) {
-    neg_inf = at::full({}, -INFINITY, query.options());
+    neg_inf = at::full(
+        {},
+        -INFINITY,
+        query.options().dtype(at::toOpMathType(query.scalar_type())));
   }
 
   std::vector<dnnl::graph::tensor> outputs = {
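In the hunk above, the softmax-scale and -inf scalars are now allocated in the query's op-math dtype (f32 for f16/bf16 queries) so they match the f32 intermediate logical tensors. A rough Python analogue of at::toOpMathType for the floating types involved; to_op_math_dtype and the 0.125 scale are hypothetical, for illustration only.

# Hedged Python analogue of the op-math promotion used for the scale/-inf scalars.
import torch

def to_op_math_dtype(dtype: torch.dtype) -> torch.dtype:
    # f16/bf16 compute in f32; f32 stays f32 (only these three are relevant here).
    return torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

softmax_scale = torch.full((), 0.125, dtype=to_op_math_dtype(torch.bfloat16))
neg_inf = torch.full((), float("-inf"), dtype=to_op_math_dtype(torch.bfloat16))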

cmake/Modules/FindMKLDNN.cmake

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ IF(NOT MKLDNN_FOUND)
   endif()
   ExternalProject_Add(xpu_mkldnn_proj
     GIT_REPOSITORY https://github.com/oneapi-src/oneDNN
-    GIT_TAG v3.7.1
+    GIT_TAG v3.8.1
     PREFIX ${XPU_MKLDNN_DIR_PREFIX}
     BUILD_IN_SOURCE 0
     CMAKE_ARGS -DCMAKE_C_COMPILER=icx

test/test_transformers.py

Lines changed: 3 additions & 3 deletions
@@ -4023,11 +4023,11 @@ def test_fused_attention_different_dk_dv(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
-    def test_onednn_attention_fail_d256(self, device):
-        # Test that onednn graph attention dispatching correctly bails out on d > 256
+    def test_onednn_attention_fail_d576(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 576
         b, h = 1, 2
         s_q, s_kv = 128, 128
-        d_qk, d_v = 512, 512
+        d_qk, d_v = 1024, 1024
 
         q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
         k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
