Commit 0f10df7

LuFinch authored and pytorchmergebot committed
[Intel GPU] Make SDPA output have the same stride as Query. (pytorch#154340)
Fixes pytorch#153903. Currently the output tensor of SDPA on XPU is always allocated with contiguous strides, while the CPU/CUDA flash_attention and cudnn_attention backends allocate the output tensor with the same strides as Query. This PR aligns XPU's behavior with CUDA/CPU so that XPU works with modeling code written against CPU/CUDA. The function `alloc_with_matching_layout` is copied from the cudnn backend: https://github.com/pytorch/pytorch/blob/8c16d0e4047a8ac5885baf52e8779fb3e36f2987/aten/src/ATen/native/cudnn/MHA.cpp#L874

Pull Request resolved: pytorch#154340
Approved by: https://github.com/Skylion007, https://github.com/EikanWang, https://github.com/guangyey
1 parent 1e20745 commit 0f10df7
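For context, here is a minimal, hypothetical sketch of the behavior this commit targets. It is not part of the diff; the device string and tensor shapes are illustrative. On an Intel GPU build, the output of scaled_dot_product_attention should now inherit the query's memory layout rather than always being contiguous, matching CPU/CUDA.

import torch
import torch.nn.functional as F

# Illustrative only: assumes an XPU (Intel GPU) build of PyTorch.
device = "xpu"
# Query/key/value stored in (B, S, H, D) memory order but viewed as (B, H, S, D).
q = torch.randn(4, 256, 16, 64, device=device, dtype=torch.bfloat16).transpose(1, 2)
k = torch.randn(4, 512, 16, 64, device=device, dtype=torch.bfloat16).transpose(1, 2)
v = torch.randn(4, 512, 16, 64, device=device, dtype=torch.bfloat16).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)
# Before this change XPU always returned a contiguous (B, H, S, D) tensor;
# with it, the output strides follow the query's layout, as on CPU/CUDA.
print(q.stride(), out.stride())
print(out.transpose(1, 2).is_contiguous())  # expected: True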

File tree: 4 files changed (+68 −4 lines)

aten/src/ATen/native/mkldnn/xpu/Attention.cpp

Lines changed: 1 addition & 2 deletions
@@ -190,8 +190,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  auto opts = query.options();
-  auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
+  at::Tensor output;
   at::Tensor logsumexp, debug_attn_mask; // not supported
 
   at::native::onednn::gpu_float_sdpa(

aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp

Lines changed: 41 additions & 1 deletion
@@ -330,6 +330,39 @@ partition& find_or_create_graph_partition(
   }
   return *partition_;
 }
+
+void alloc_with_matching_layout(
+    const at::Tensor& q,
+    at::Tensor& output,
+    const std::vector<int64_t>& shape) {
+  TORCH_INTERNAL_ASSERT(
+      shape.size() == q.sizes().size(),
+      "OneDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
+
+  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
+    output = at::empty_like(q);
+    return;
+  }
+
+  // get the "fill order," which is just an argsort on the strides
+  std::vector<int> fill_order(shape.size());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  const auto q_strides = q.strides();
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
+        return q_strides[idx1] < q_strides[idx2];
+      });
+  std::vector<int64_t> ordered_strides(shape.size());
+  int64_t current_stride = 1;
+  for (const int dim_idx : fill_order) {
+    ordered_strides[dim_idx] = current_stride;
+    current_stride *= shape[dim_idx];
+  }
+  output = at::empty(at::IntArrayRef(shape), q.options())
+               .as_strided(
+                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
+}
+
 } // namespace
 
 namespace at::native::onednn {
@@ -347,7 +380,14 @@ void gpu_float_sdpa(
     std::optional<at::Tensor> attn_mask,
     bool is_causal,
     float softmax_scale,
-    const Tensor& output) {
+    Tensor& output) {
+  if (!output.defined()) {
+    // allocate output tensor with layout matched to query
+    std::vector<int64_t> output_shape = {
+        batch_size, num_head, seq_len_q, head_dim_v};
+    alloc_with_matching_layout(query, output, output_shape);
+  }
+
   auto& eng = GpuEngineManager::Instance().get_engine();
   auto& strm = GpuStreamManager::Instance().get_stream();
 

aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h

Lines changed: 1 addition & 1 deletion
@@ -178,5 +178,5 @@ void gpu_float_sdpa(
     std::optional<at::Tensor> attn_mask,
     bool is_causal,
     float softmax_scale,
-    const Tensor& output);
+    Tensor& output);
 } // namespace at::native::onednn

test/test_transformers.py

Lines changed: 25 additions & 0 deletions
@@ -4059,6 +4059,31 @@ def test_fused_attention_broadcasted_input(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
+    def test_attention_preserves_query_layout(self, device):
+
+        def test_attention(permute_order: list[list[int]]):
+            BHSqD = [4, 16, 256, 64]
+            BHSkvD = [4, 16, 512, 64]
+
+            shape_q = [BHSqD[idx] for idx in permute_order]
+            shape_kv = [BHSkvD[idx] for idx in permute_order]
+            reverse = [permute_order.index(idx) for idx in range(4)]
+            q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            self.assertEqual(q.shape, BHSqD)
+            self.assertEqual(k.shape, BHSkvD)
+            self.assertEqual(v.shape, BHSkvD)
+
+            out = F.scaled_dot_product_attention(q, k, v)
+            self.assertTrue(out.permute(permute_order).is_contiguous())
+
+        permutable = [0, 1, 2]
+        permute_orders = itertools.permutations(permutable)
+
+        for permute_order in permute_orders:
+            test_attention(list(permute_order) + [3])
+
     @parametrize("type", ["dense"])
     @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):