@@ -1,17 +1,25 @@
+#include <ATen/OpMathType.h>
 #include <ATen/native/mkldnn/xpu/detail/Attr.h>
 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
-
 #include <oneapi/dnnl/dnnl.hpp>

+namespace {
+
 using namespace at::native::onednn;
 using logical_tensor = dnnl::graph::logical_tensor;
 using data_type = logical_tensor::data_type;
 using dims = logical_tensor::dims;
 using op = dnnl::graph::op;
 using partition = dnnl::graph::partition;

-namespace {
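+// Map an ATen scalar type to the corresponding oneDNN graph logical_tensor
+// data type; unsupported types yield data_type::undef so callers can assert.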
+inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) {
+  return scalar_type == c10::ScalarType::Float ? data_type::f32
+      : scalar_type == c10::ScalarType::Half ? data_type::f16
+      : scalar_type == c10::ScalarType::BFloat16 ? data_type::bf16
+                                                 : data_type::undef;
+}
+
 struct SDPALogicalParams {
   enum class TensorID {
     query,
@@ -39,11 +47,7 @@ struct SDPALogicalParams {
       const std::optional<at::Tensor>& attn_mask_,
       const at::Tensor& output_,
       bool is_causal) {
-    const data_type dtype = // to logical_tensor data type
-        query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
-        : query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
-        : query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
-        : data_type::undef;
+    const data_type dtype = to_logical_tensor_data_type(query_.scalar_type());
     TORCH_INTERNAL_ASSERT(
         (dtype != data_type::undef),
         "Only FP16/BF16/FP32 datatypes are currently supported");
@@ -84,22 +88,27 @@ struct SDPALogicalParams {
         reshaped_key.strides().vec()};
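+    // scale (and neg_inf below) are declared in the opmath type of the query
+    // (f32 when the inputs are f16/bf16) so that they match the f32
+    // intermediate computation described in create_sdpa_graph_partition.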
     scale = {
         static_cast<size_t>(TensorID::scale),
-        dtype,
+        to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
         scalar_shape,
         logical_tensor::layout_type::strided,
         logical_tensor::property_type::constant};
     if (is_causal) {
       neg_inf = {
           static_cast<size_t>(TensorID::neg_inf),
-          dtype,
+          to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
           scalar_shape,
           logical_tensor::layout_type::strided,
           logical_tensor::property_type::constant};
     }
     if (attn_mask_.has_value()) {
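+      // attn_mask may use a different floating-point type than Q/K/V, so its
+      // dtype is validated separately instead of reusing `dtype` above.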
+      const data_type mask_dtype =
+          to_logical_tensor_data_type(attn_mask_->scalar_type());
+      TORCH_INTERNAL_ASSERT(
+          (mask_dtype != data_type::undef),
+          "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask");
       attn_mask = {
           static_cast<size_t>(TensorID::attn_mask),
-          dtype,
+          mask_dtype,
           reshaped_attn_mask.sizes().vec(),
           reshaped_attn_mask.strides().vec()};
     }
@@ -147,7 +156,12 @@ partition create_sdpa_graph_partition(
   size_t lt_id = static_cast<size_t>(SDPALogicalParams::TensorID::end);
   size_t op_id = 0;

-  logical_tensor matmul_qk_out{lt_id++, dtype};
+  // oneDNN graph has an optimized implementation for f16/bf16 SDPA with an
+  // f32 intermediate data type on Intel Graphics Products with Intel(R) Xe
+  // Matrix Extensions (Intel(R) XMX) support: Q/K/V keep their f16/bf16 data
+  // type, while the outputs of the first MatMul, Scale, and Mask, and the
+  // input of SoftMax, are in f32.
+  logical_tensor matmul_qk_out{lt_id++, data_type::f32};
   op matmul_qk{
       op_id++,
       op::kind::MatMul,
@@ -156,7 +170,7 @@ partition create_sdpa_graph_partition(
       "matmul_qk"};
   matmul_qk.set_attr<bool>(op::attr::transpose_b, true);

-  logical_tensor scaled_qk_out{lt_id++, dtype};
+  logical_tensor scaled_qk_out{lt_id++, data_type::f32};
   op scale_mul{
       op_id++,
       op::kind::Multiply,
@@ -181,7 +195,7 @@ partition create_sdpa_graph_partition(
   if (params.attn_mask.has_value()) {
     TORCH_INTERNAL_ASSERT(
         !is_causal, "Additive mask cannot use with is_causal.");
-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
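+    // The masked scores stay in f32 as well, per the intermediate data types
+    // noted above.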
     mask_add = {
         op_id++,
         op::kind::Add,
@@ -216,7 +230,7 @@ partition create_sdpa_graph_partition(
        {mask_gt_out.value()},
        "mask_gt"};

-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
     mask_select = {
         op_id++,
         op::kind::Select,
@@ -349,24 +363,16 @@ void gpu_float_sdpa(
         at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
   };

-  static bool driver_support_implict_causal = true;
-  if (attn_mask.has_value()) {
-    TORCH_INTERNAL_ASSERT(
-        !is_causal,
-        "scaled_dot_product_fused_attention_overrideable_xpu: "
-        "attn_mask cannot present with is_causal");
-  } else {
-    // Currenetly implict mask only supports square fp16 cases
-    const bool support_implict_causal = driver_support_implict_causal &&
-        (query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
-        seq_len_q == seq_len_k;
-    if (is_causal && !support_implict_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-    }
+  // oneDNN does not provide an fp32 ukernel for the implicit causal mask, and
+  // its reference implementation performs worse than ATen math plus an
+  // explicit causal mask. Fall back to an explicit causal mask until oneDNN
+  // v3.9, which adds an fp32 ukernel for the implicit causal mask.
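+  // The explicit path is equivalent: get_tril_mask() builds an additive
+  // lower-triangular mask whose masked positions are -inf.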
+  if (is_causal && query.dtype() == at::kFloat) {
+    attn_mask = get_tril_mask();
+    is_causal = false;
   }

-  std::vector<logical_tensor> l_inputs, l_outputs;
+  std::vector<dnnl::graph::logical_tensor> l_inputs, l_outputs;
   std::optional<dnnl::graph::compiled_partition> compiled_partition;

   auto get_compiled_partition = [&]() {
@@ -388,24 +394,18 @@ void gpu_float_sdpa(
     return compiled_partition;
   };

-  // maybe retry without causal mask
-  try {
-    compiled_partition = get_compiled_partition();
-  } catch (std::exception& e) {
-    if (is_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-      compiled_partition = get_compiled_partition();
-      driver_support_implict_causal = false;
-    } else {
-      throw e;
-    }
-  }
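+  // No retry path is needed anymore: the fp32 causal case, which previously
+  // triggered the catch-and-recompile fallback, is now rewritten to an
+  // explicit mask before the first compilation attempt.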
+  compiled_partition = get_compiled_partition();
 
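+  // Scalar constants are created in the opmath type of the query (f32 for
+  // f16/bf16 inputs) to match the logical tensors declared in
+  // SDPALogicalParams.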
-  Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
+  Tensor softmax_scale1 = at::full(
+      {},
+      softmax_scale,
+      query.options().dtype(at::toOpMathType(query.scalar_type())));
   std::optional<at::Tensor> neg_inf;
   if (is_causal) {
-    neg_inf = at::full({}, -INFINITY, query.options());
+    neg_inf = at::full(
+        {},
+        -INFINITY,
+        query.options().dtype(at::toOpMathType(query.scalar_type())));
   }

   std::vector<dnnl::graph::tensor> outputs = {