Commit 60a5050

xw285cornell authored and pytorchmergebot committed
[AMD] SDPA internal changes (pytorch#144320)
Summary: All the internal changes needed to enable flash attention w/ SDPA in fbcode.

Test Plan:

```
TORCH_ROCM_FA_PREFER_CK=1 buck run -m rocm621 mode/opt-amd-gpu scripts/xdwang/example:sdpa
```

| Batch Size | Sequence Length | Heads | Head Dim | Flash Time (µs) | Math Time (µs) | xformers Time (µs) | Flash TFlops | Math TFlops | xformers TFlops | Speedup (Flash/Math) | Speedup (xformers/Math) | xformers trace_url | Flash trace_url |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 4096 | 32 | 64 | 455.552 | 7748.76 | 513.449 | 301.698 | 17.7369 | 267.678 | 17.0096 | 15.0916 | | |
| 1 | 4096 | 16 | 128 | 329.971 | 4741.11 | 386.049 | 416.519 | 28.9888 | 356.014 | 14.3683 | 12.2811 | | |
| 1 | 8192 | 32 | 64 | 1455.76 | 31869.6 | 1665.49 | 377.642 | 17.2501 | 330.087 | 21.8921 | 19.1353 | | |
| 1 | 8192 | 16 | 128 | 1265.77 | 18972.8 | 1479.48 | 434.325 | 28.976 | 371.588 | 14.9891 | 12.824 | | |
| 1 | 16384 | 32 | 64 | 5732.99 | 121861 | 6816.77 | 383.573 | 18.0453 | 322.59 | 21.2562 | 17.8767 | | |
| 1 | 16384 | 16 | 128 | 4749.69 | 73776.4 | 5404.03 | 462.982 | 29.8066 | 406.923 | 15.5329 | 13.6521 | | |

| Batch Size | Sequence Length | Heads | Head Dim | Flash Time (µs) | Math Time (µs) | xformers Time (µs) | Flash TFlops | Math TFlops | xformers TFlops | Speedup (Flash/Math) | Speedup (xformers/Math) | xformers trace_url | Flash trace_url |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 4096 | 32 | 64 | 1615.41 | 8342.67 | 1822.72 | 212.7 | 41.1855 | 188.508 | 5.16443 | 4.57705 | | |
| 1 | 4096 | 16 | 128 | 1357.97 | 5943.53 | 1432.34 | 253.022 | 57.8104 | 239.886 | 4.37676 | 4.14953 | | |
| 1 | 8192 | 32 | 64 | 5556.5 | 31726.7 | 6502.17 | 247.348 | 43.3197 | 211.374 | 5.70984 | 4.8794 | | |
| 1 | 8192 | 16 | 128 | 5186 | 22529.4 | 5590.36 | 265.019 | 61.0044 | 245.85 | 4.34427 | 4.03004 | | |
| 1 | 16384 | 32 | 64 | 22527.7 | 130413 | 26527.6 | 244.035 | 42.155 | 207.239 | 5.789 | 4.91613 | | |
| 1 | 16384 | 16 | 128 | 18347.9 | 87553.2 | 20358 | 299.628 | 62.791 | 270.044 | 4.77184 | 4.30068 | | |

Reviewed By: leitian, feikou, yoyoyocmu, sijiac

Differential Revision: D67262726

Pull Request resolved: pytorch#144320

Approved by: https://github.com/jianyuh, https://github.com/eqy, https://github.com/leitian
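
For reference, a minimal C++ (libtorch) sketch of the path the test plan exercises, not code from this PR. The shapes match the first benchmark row above (batch 1, 32 heads, seqlen 4096, head_dim 64, fp16); it assumes a ROCm build of libtorch with TORCH_ROCM_FA_PREFER_CK=1 set in the environment, as in the test plan, so the CK flash path is preferred.

```cpp
// Illustrative sketch only: call SDPA with the benchmark's shapes.
// On ROCm builds the device type is still torch::kCUDA (HIP masquerades as CUDA).
#include <torch/torch.h>
#include <iostream>

int main() {
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto q = torch::randn({1, 32, 4096, 64}, opts);  // batch x heads x seqlen x head_dim
  auto k = torch::randn({1, 32, 4096, 64}, opts);
  auto v = torch::randn({1, 32, 4096, 64}, opts);

  // Dispatches to the flash/CK kernels touched by this commit when they are
  // available for the current GPU arch; otherwise another SDPA backend is used.
  auto out = at::scaled_dot_product_attention(q, k, v);
  std::cout << out.sizes() << std::endl;
  return 0;
}
```
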
1 parent 7d9f26d commit 60a5050

File tree

7 files changed: +640 −484 lines changed

aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip

Lines changed: 6 additions & 6 deletions
@@ -126,7 +126,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
         int window_size_left,
         int window_size_right,
         const bool return_softmax,
-        std::optional<at::Generator> gen_) {
+        const std::optional<at::Generator>& gen_) {
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
   check_gpu_arch(stream);
 
@@ -254,7 +254,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
         int window_size_left,
         int window_size_right,
         const bool return_softmax,
-        std::optional<at::Generator> gen_) {
+        const std::optional<at::Generator>& gen_) {
   TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
   const bool paged_KV = block_table_.has_value();
   TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
@@ -418,8 +418,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
         int window_size_left,
         int window_size_right,
         const bool deterministic,
-        const at::Tensor philox_seed,
-        const at::Tensor philox_offset) {
+        const at::Tensor& philox_seed,
+        const at::Tensor& philox_offset) {
   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
@@ -574,8 +574,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
         int window_size_left,
         int window_size_right,
         const bool deterministic,
-        const at::Tensor philox_seed,
-        const at::Tensor philox_offset)
+        const at::Tensor& philox_seed,
+        const at::Tensor& philox_offset)
 {
   TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
 
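
The signature changes above swap pass-by-value for pass-by-const-reference. As a hedged aside (hypothetical function names, not the PR's code): copying an at::Tensor argument bumps its atomic refcount on every call, and copying a std::optional<at::Generator> copies the wrapped generator handle, whereas a const reference simply aliases the caller's object and leaves existing call sites unchanged.

```cpp
// Illustrative sketch only; launch_by_value/launch_by_ref are hypothetical names.
#include <ATen/ATen.h>
#include <optional>

// Pass-by-value: each call copy-constructs the Tensor handle (an atomic
// refcount increment/decrement) and copies the optional<Generator>.
void launch_by_value(at::Tensor philox_seed,
                     std::optional<at::Generator> gen_) {
  (void)philox_seed;
  (void)gen_;
}

// Pass-by-const-reference (the pattern adopted in mha_all_aot.hip):
// no copies are made; lvalue callers are unaffected.
void launch_by_ref(const at::Tensor& philox_seed,
                   const std::optional<at::Generator>& gen_) {
  (void)philox_seed;
  (void)gen_;
}
```
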

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip

Lines changed: 7 additions & 0 deletions
@@ -383,6 +383,13 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
 #if (defined(__gfx90a__) || defined(__gfx942__))
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
+#else
+        TORCH_CHECK(
+            false,
+            "CK Flash Attention is not compiled with the right GPU arch. "
+            "Either remove USE_CK_FLASH_ATTENTION or compile with the "
+            "GPU arch that CK attention supports."
+        )
 #endif
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
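
On the behavior of the new #else branch: TORCH_CHECK(false, ...) throws a c10::Error (surfaced as a RuntimeError from Python), so a build with USE_CK_FLASH_ATTENTION that was not compiled for a supported GPU arch now fails loudly at runtime instead of silently skipping the kernel. A standalone sketch of that failure mode, not taken from the PR:

```cpp
// Standalone sketch: TORCH_CHECK(false, ...) raises c10::Error, the same
// exception the #else branch above produces when it is reached at runtime.
#include <c10/util/Exception.h>
#include <iostream>

int main() {
  try {
    TORCH_CHECK(false,
                "CK Flash Attention is not compiled with the right GPU arch.");
  } catch (const c10::Error& e) {
    // e.what() contains the message plus file/line context.
    std::cout << "caught c10::Error: " << e.what() << std::endl;
  }
  return 0;
}
```
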

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip

Lines changed: 8 additions & 0 deletions
@@ -142,6 +142,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                 has_dropout_randval,
                 drop_seed_offset};
 }
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
            const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
@@ -342,6 +343,13 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
 #if (defined(__gfx90a__) || defined(__gfx942__))
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
+#else
+        TORCH_CHECK(
+            false,
+            "CK Flash Attention is not compiled with the right GPU arch. "
+            "Either remove USE_CK_FLASH_ATTENTION or compile with the "
+            "GPU arch that CK attention supports."
+        )
 #endif
     }
     else {

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip

Lines changed: 7 additions & 0 deletions
@@ -412,6 +412,13 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
 #if (defined(__gfx90a__) || defined(__gfx942__))
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
+#else
+        TORCH_CHECK(
+            false,
+            "CK Flash Attention is not compiled with the right GPU arch. "
+            "Either remove USE_CK_FLASH_ATTENTION or compile with the "
+            "GPU arch that CK attention supports."
+        )
 #endif
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip

Lines changed: 7 additions & 0 deletions
@@ -341,6 +341,13 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
 #if (defined(__gfx90a__) || defined(__gfx942__))
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
+#else
+        TORCH_CHECK(
+            false,
+            "CK Flash Attention is not compiled with the right GPU arch. "
+            "Either remove USE_CK_FLASH_ATTENTION or compile with the "
+            "GPU arch that CK attention supports."
+        )
 #endif
     }
     else {
