Commit d3c8f36

Revert "[Intel GPU] Make SDPA output has the same stride as Query. (pytorch#154340)"
This reverts commit 0f10df7. Reverted pytorch#154340 on behalf of https://github.com/etaf because the PR breaks the Hugging Face E2E run on XPU. ([comment](pytorch#154340 (comment)))
1 parent bb43ced commit d3c8f36

4 files changed: 4 additions, 68 deletions

aten/src/ATen/native/mkldnn/xpu/Attention.cpp (2 additions, 1 deletion)

@@ -190,7 +190,8 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);

-  at::Tensor output;
+  auto opts = query.options();
+  auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
   at::Tensor logsumexp, debug_attn_mask; // not supported

   at::native::onednn::gpu_float_sdpa(
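
With this revert, the XPU fused SDPA entry point allocates a plain contiguous output up front instead of leaving it undefined for gpu_float_sdpa to allocate with a query-matched layout. A minimal sketch of the stride difference being discussed, assuming a query that is a (B, H, S, D) view over (B, S, H, D) storage; shapes and names below are illustrative, not part of the diff:

    import torch
    import torch.nn.functional as F

    B, H, S, D = 2, 4, 8, 16

    # Query/key/value stored as (B, S, H, D) but viewed as (B, H, S, D).
    q = torch.randn(B, S, H, D).transpose(1, 2)
    k = torch.randn(B, S, H, D).transpose(1, 2)
    v = torch.randn(B, S, H, D).transpose(1, 2)

    out = F.scaled_dot_product_attention(q, k, v)

    print(q.stride())    # permuted strides, e.g. (S*H*D, D, H*D, 1)
    print(out.stride())  # on the XPU fused path this revert yields a plain
                         # contiguous (B, H, S, D) output rather than one that
                         # mirrors the query's strides; other backends vary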

aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp (1 addition, 41 deletions)

@@ -330,39 +330,6 @@ partition& find_or_create_graph_partition(
   }
   return *partition_;
 }
-
-void alloc_with_matching_layout(
-    const at::Tensor& q,
-    at::Tensor& output,
-    const std::vector<int64_t>& shape) {
-  TORCH_INTERNAL_ASSERT(
-      shape.size() == q.sizes().size(),
-      "OneDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
-
-  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
-    output = at::empty_like(q);
-    return;
-  }
-
-  // get the "fill order," which is just an argsort on the strides
-  std::vector<int> fill_order(shape.size());
-  std::iota(fill_order.begin(), fill_order.end(), 0);
-  const auto q_strides = q.strides();
-  std::stable_sort(
-      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
-        return q_strides[idx1] < q_strides[idx2];
-      });
-  std::vector<int64_t> ordered_strides(shape.size());
-  int64_t current_stride = 1;
-  for (const int dim_idx : fill_order) {
-    ordered_strides[dim_idx] = current_stride;
-    current_stride *= shape[dim_idx];
-  }
-  output = at::empty(at::IntArrayRef(shape), q.options())
-               .as_strided(
-                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
-}
-
 } // namespace

 namespace at::native::onednn {

@@ -380,14 +347,7 @@ void gpu_float_sdpa(
     std::optional<at::Tensor> attn_mask,
     bool is_causal,
     float softmax_scale,
-    Tensor& output) {
-  if (!output.defined()) {
-    // allocate output tensor with layout matched to query
-    std::vector<int64_t> output_shape = {
-        batch_size, num_head, seq_len_q, head_dim_v};
-    alloc_with_matching_layout(query, output, output_shape);
-  }
-
+    const Tensor& output) {
   auto& eng = GpuEngineManager::Instance().get_engine();
   auto& strm = GpuStreamManager::Instance().get_stream();
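
The removed alloc_with_matching_layout helper derived the output strides from the query's "fill order": an argsort of the query's strides, then filling strides innermost-first using the requested shape. A short Python sketch of that idea, kept only as a reference for the reverted behavior (the helper name below is illustrative, not from this commit):

    import torch

    def strides_matching_layout(ref: torch.Tensor, shape: list[int]) -> list[int]:
        # Fill order = dimensions sorted from smallest to largest stride of `ref`
        # (an argsort on the strides), mirroring the removed C++ helper.
        fill_order = sorted(range(len(shape)), key=lambda d: ref.stride(d))
        strides = [0] * len(shape)
        current = 1
        for d in fill_order:
            strides[d] = current
            current *= shape[d]
        return strides

    # A (B, H, S, D) view over (B, S, H, D) storage.
    q = torch.randn(2, 8, 4, 16).transpose(1, 2)    # shape (2, 4, 8, 16)
    out = torch.empty_strided([2, 4, 8, 16], strides_matching_layout(q, [2, 4, 8, 16]))
    print(out.stride() == q.stride())               # True: output shares q's fill order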

aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h (1 addition, 1 deletion)

@@ -178,5 +178,5 @@ void gpu_float_sdpa(
     std::optional<at::Tensor> attn_mask,
     bool is_causal,
     float softmax_scale,
-    Tensor& output);
+    const Tensor& output);
 } // namespace at::native::onednn

test/test_transformers.py (0 additions, 25 deletions)

@@ -4059,31 +4059,6 @@ def test_fused_attention_broadcasted_input(self, device):

         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

-    def test_attention_preserves_query_layout(self, device):
-
-        def test_attention(permute_order: list[list[int]]):
-            BHSqD = [4, 16, 256, 64]
-            BHSkvD = [4, 16, 512, 64]
-
-            shape_q = [BHSqD[idx] for idx in permute_order]
-            shape_kv = [BHSkvD[idx] for idx in permute_order]
-            reverse = [permute_order.index(idx) for idx in range(4)]
-            q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
-            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
-            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
-            self.assertEqual(q.shape, BHSqD)
-            self.assertEqual(k.shape, BHSkvD)
-            self.assertEqual(v.shape, BHSkvD)
-
-            out = F.scaled_dot_product_attention(q, k, v)
-            self.assertTrue(out.permute(permute_order).is_contiguous())
-
-        permutable = [0, 1, 2]
-        permute_orders = itertools.permutations(permutable)
-
-        for permute_order in permute_orders:
-            test_attention(list(permute_order) + [3])
-
     @parametrize("type", ["dense"])
     @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
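
The deleted test's double permutation is easy to misread: q/k/v are allocated contiguously in a permuted dimension order and then permuted back to (B, H, Sq, D) / (B, H, Skv, D), so q.permute(permute_order).is_contiguous() holds by construction, and the assertion checked that the output inherited the same property. A small device-agnostic sketch of that invariant, with one permutation picked arbitrarily for illustration:

    import torch

    BHSqD = [4, 16, 256, 64]
    order = [2, 0, 1, 3]                        # one of the orders the test iterated over
    reverse = [order.index(i) for i in range(4)]

    # Contiguous in the permuted order, then viewed back as (B, H, Sq, D).
    q = torch.randn([BHSqD[i] for i in order]).permute(reverse)

    print(list(q.shape) == BHSqD)               # True: logical shape is (B, H, Sq, D)
    print(q.is_contiguous())                    # False: storage follows the permuted order
    print(q.permute(order).is_contiguous())     # True: the property the test asserted for `out`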
