Commit 0f10df7
[Intel GPU] Make SDPA output have the same stride as Query. (pytorch#154340)
Fixes [pytorch#153903](pytorch#153903).
Currently, the output tensor of SDPA on XPU is always allocated with contiguous strides, whereas the CPU/CUDA flash_attention and cudnn_attention paths allocate the output tensor with the same strides as Query.
This PR aligns XPU's behavior with CUDA/CPU so that XPU is compatible with modeling code written for CPU/CUDA.
The function `alloc_with_matching_layout` is copied from the cuDNN implementation: https://github.com/pytorch/pytorch/blob/8c16d0e4047a8ac5885baf52e8779fb3e36f2987/aten/src/ATen/native/cudnn/MHA.cpp#L874
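The idea of allocating the output with the same layout as Query can be illustrated with a small pure-Python sketch. This is not the code from the PR: the helper name `matching_strides` is hypothetical, and the approach shown (argsort the Query strides, then assign dense strides to the output shape in that order) is an assumption about how such a layout-matching allocation might compute strides.

```python
def matching_strides(query_strides, out_shape):
    """Hypothetical helper: dense strides for out_shape that follow
    the dimension ordering implied by query_strides."""
    if len(query_strides) != len(out_shape):
        raise ValueError("query and output must have the same number of dimensions")
    # Argsort on the query strides: the innermost dimension (smallest stride)
    # comes first. sorted() is stable, so ties keep their original order.
    fill_order = sorted(range(len(out_shape)), key=lambda d: query_strides[d])
    strides = [0] * len(out_shape)
    running = 1
    for dim in fill_order:
        strides[dim] = running
        running *= out_shape[dim]
    return tuple(strides)


# Query with logical shape (B=2, H=4, L=8, D=16) that is stored as
# (B, L, H, D) in memory, i.e. a transposed view with these strides.
query_strides = (512, 16, 64, 1)

# Output with the same shape gets the same strides as Query.
same_shape = matching_strides(query_strides, (2, 4, 8, 16))

# Output whose last dimension differs (for example a value head dim of 32)
# keeps the same (B, L, H, D) memory ordering.
wider = matching_strides(query_strides, (2, 4, 8, 32))

print(same_shape)
print(wider)
```

With a contiguous allocation the second output would instead have strides `(1024, 256, 32, 1)`, which is the mismatch with CPU/CUDA that the commit message describes.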
Pull Request resolved: pytorch#154340
Approved by: https://github.com/Skylion007, https://github.com/EikanWang, https://github.com/guangyey
1 parent 1e20745 commit 0f10df7
File tree
4 files changed, +68 -4 lines changed
- aten/src/ATen/native/mkldnn/xpu
- detail
- test