Commit 6d21ae2

LucasWilkinson, mgorny, tmm1, tridao, and vasqu authored
Upstream Sync - up to: 27f501d (#56)
* Support ROCM builds from source distribution, and improve error handling (Dao-AILab#1446)
  * Always update both submodules to include them in sdist: update both submodules, irrespective of whether a CUDA or a ROCM build is being done, to ensure that the necessary files from both are present in sdist. Otherwise, attempting a ROCM build from sdist fails because of missing `composable_kernel` sources.
  * Include `*.py` files from composable_kernel in sdist: include the `*.py` files from `csrc` in sdist, to ensure that the `generate.py` script is present.
  * Replace the `os.system()` calls in `setup.py` with `subprocess.run()`.
  * Add error checking to the `subprocess.run()` calls in `setup.py`, so that `setup.py` fails immediately if one of the commands fails. Otherwise, the failures result only in messages to stderr that could be missed, leading to more confusing errors later in the build process.
  * Call git in `setup.py` only when working in a git repository: call git commands only when the `.git` directory is present, indicating a git checkout; otherwise, just assert that the needed files are there. With this, building from a source distribution no longer attempts to call git in an incorrect directory.
* [Build] Update version of setuptools used to generate core package (Dao-AILab#1460)
* Don't compile for CUDA 11; compile for official PyTorch 2.6.0
* Bump to v2.7.4
* Drop PyTorch 2.1
* [FA3] Compile with nvcc 12.8 instead of 12.3
* Fix comment in assert
* [CE] Assert logit_scale > 0
* Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
* Fix shape_O in epilogue params when kHeadDimV != kHeadDim
* Remove old combine.h
* Fix loading paged V when kHeadDimV != kHeadDim
* Fix shape_V for storing new KV when kHeadDimV != kHeadDim
* Implement the case of LargeHeadDimV
* Rename Mma0->MmaQK, Mma1->MmaPV; use Cluster only if hdimV >= 192
* Pass _1 or _0 to cute::aligned_struct
* Fix compilation for FP8 when kHeadDimV != kHeadDim
* Support Qv
* Test varlen_q=True by default for kvcache
* Fix num_splits heuristic being called before get_pack_gqa
* Fix num_splits heuristic again when PackGQA
* Tile fwd_combine kernel along headdim; don't need kBlockM > 128
* Use bf16 instead of fp16 in benchmark_gemm.py
* Update Cutlass to 3.7
* Use nvcc 12.6 but ptxas 12.8
* cicc uses the same version as ptxas
* Split hdimdiff into a separate translation unit
* Update benchmark script
* Update Cutlass to 3.8
* Adjust tile size for hdim 64
* Adjust ninja build file
* Rename collective_mainloop -> mainloop; move tile_scheduler variable
* Move functions getting number of m/n blocks to a separate file
* Update Cutlass 3.8 to fix error w/ cudaGetDriverEntryPointByVersion
* Fix FP8 test
* Make seqused optional on top-level interface (Dao-AILab#1497)
* Temporarily change package name of FA3 to allow FA2 & FA3 install
* Update benchmark_split_kv.py to work w/ new API
* Add tp_degree to benchmark_split_kv
* Fix divide-by-0 in causal tile_scheduler for large seqlen
* Use split for super long sequences that don't fit into L2
* Make rotary test optional in FA3
* Enable MLA flag in FA3 (rope=64, latent=512) (Dao-AILab#1504); updated HasQv in flash_fwd_launch_template.h
* Add simple script to benchmark MLA decode
* Add dynamic splits
* Update to Cutlass 3.8.0 tag
* Adjust seqlen_q in MLA decode benchmark script
* Fix loop in prepare_scheduler.cu (h/t Jay Shah); only affects the case where batch size > 256
* fix: add "typename" prior to dependent type name (Dao-AILab#1517). This project uses C++17, which still has this requirement.
* Add FLOPS to MLA decode benchmark
* Change margin in prepare_scheduler.cu from 20% to 10%
* Fix CUDA 12.1 build (Dao-AILab#1511)
* Don't use IntraWGOverlap for hdim 64,512
* Remove sink token; it wasn't working anyway
* fix: prompt index to type longlong to avoid numerical overflow (Dao-AILab#1500)
* Add option for WG1 to use RS MMA but WG2 to use SS MMA
* Add kwargs to _write_ninja_file for compatibility with new torch
* Move writing P to smem into a separate function
* Fix causal scheduler not considering hdim_v != hdim
* Always split fwd_combine_kernel on batch
* For each batch, if num_splits=1, write to O instead of O_partial
* Enable TMA when page size is a multiple of kBlockN
* Update ptxas to 12.8.93 (i.e. 12.8.1)
* Use tile size 192 x 128 for hdim 64 causal
* Update benchmark_mla_decode.py
* Benchmark MHA, GQA, MQA, MLA in the same script
* Benchmark FlashMLA if it's available
* Run all 4 attn variants in benchmark
* Move scheduler.get_next_work to before the epilogue
* Enable Cluster for hdim128 back
* Move tOrO init in mainloop
* Adjust heuristic for get_pagedkv_tma
* Enable PDL
* Simplify prepare_varlen_num_blocks_kernel; restrict to batch <= 992
* Fix: num_splits_dynamic_ptr needs to be set before get_num_splits
* Loop on num_splits instead of parameterizing it in kvcache test
* Add option to precompute scheduler metadata
* Update MLA decode benchmark to use get_scheduler_metadata
* Fix FP8 test to quantize KV cache for reference impl as well
* Dynamic autotune configs for devices with warp size != 32 (Dao-AILab#1534): generate a list of autotune configs based on device warp size to avoid a Triton error if maximum threads per block is exceeded.
* Update binding

Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

Co-authored-by: Michał Górny <mgorny@gentoo.org>
Co-authored-by: Aman Karmani <aman@tmm1.net>
Co-authored-by: Tri Dao <tridpq@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Ted Zadouri <tedzadouri@gmail.com>
Co-authored-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
Co-authored-by: xin-w8023 <43900898+xin-w8023@users.noreply.github.com>
Co-authored-by: schung-amd <Steven.Chung@amd.com>
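One of the build changes above replaces bare `os.system()` calls in `setup.py` with checked `subprocess.run()` calls. A minimal sketch of that pattern (the helper name `run_checked` is mine, not from the commit):

```python
import subprocess
import sys

def run_checked(cmd, cwd=None):
    # check=True raises CalledProcessError on a non-zero exit status,
    # so the build aborts immediately instead of the failure only
    # appearing on stderr (the problem with bare os.system()).
    subprocess.run(cmd, cwd=cwd, check=True)

# A failing command now stops the build loudly:
try:
    run_checked([sys.executable, "-c", "import sys; sys.exit(1)"])
except subprocess.CalledProcessError as e:
    print(f"command failed with exit code {e.returncode}")
# prints: command failed with exit code 1
```

In `setup.py` the same helper would wrap the git submodule commands, guarded by an `os.path.isdir(".git")` check as described in the commit message.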
1 parent dc9d410 commit 6d21ae2

37 files changed: +1568 −664 lines

CMakeLists.txt

Lines changed: 5 additions & 2 deletions

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.26)
 
-project(vllm_flash_attn LANGUAGES CXX)
+project(vllm_flash_attn LANGUAGES CXX CUDA)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF)
 
@@ -213,7 +213,9 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     SRCS "${FA3_GEN_SRCS}"
     CUDA_ARCHS "${FA3_ARCHS}")
   set_gencode_flags_for_srcs(
-    SRCS "hopper/flash_fwd_combine.cu"
+    SRCS
+      hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
     CUDA_ARCHS "${FA3_ARCHS}")
 endif()

@@ -223,6 +225,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     LANGUAGE ${VLLM_GPU_LANG}
     SOURCES
       hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
       hopper/flash_api.cpp
       hopper/flash_api_torch_lib.cpp
       ${FA3_GEN_SRCS}

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
         // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
         + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
     const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
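The change above widens the literal from `128` to `128ll` so the whole offset expression is computed in 64 bits: in C++, `128 * bidb` and the surrounding products are evaluated as 32-bit `int` before being assigned to the 64-bit `index_t`, so a large batch index times large strides can wrap. A Python sketch that emulates the 32-bit wraparound (the sizes are illustrative, not taken from the kernel):

```python
def i32(x):
    # Emulate C int32 wraparound semantics.
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

# Illustrative sizes: long sequence, many heads, large batch index.
m_block_kBlockM = 500_000   # stands in for (m_block_max - 1) * kBlockM
bidb = 500                  # batch index
h_times_d = 32 * 128        # stands in for params.h * params.d_rounded

# 32-bit arithmetic (old `128 * bidb`): the intermediate product wraps negative.
off32 = i32(i32(m_block_kBlockM + i32(128 * bidb)) * h_times_d)

# 64-bit arithmetic (new `128ll * bidb` promotes the whole expression).
off64 = (m_block_kBlockM + 128 * bidb) * h_times_d

print(off32, off64)  # prints: -1984823296 2310144000
```

A negative offset like `off32` would index far outside the `dq_accum` buffer, which is why the commit message calls this out as a numerical-overflow fix.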

csrc/flash_attn/src/flash_bwd_preprocess_kernel.h

Lines changed: 2 additions & 2 deletions

@@ -79,7 +79,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
     const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

@@ -205,7 +205,7 @@ inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
 
     Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 4 additions & 4 deletions

@@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
     // if (cute::thread0()) { print(tOrP); }
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     // if (cute::thread0()) { print(scores); }

@@ -424,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 }

@@ -942,7 +942,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
 
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);

@@ -1002,7 +1002,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
 
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 }

flash_attn/ops/triton/layer_norm.py

Lines changed: 15 additions & 16 deletions

@@ -15,6 +15,19 @@
 import triton
 import triton.language as tl
 
+def triton_autotune_configs():
+    # Return configs with a valid warp count for the current device
+    configs=[]
+    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
+    max_threads_per_block=1024
+    # Default to warp size 32 if not defined by device
+    warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
+    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
+    warp_count=1
+    while warp_count*warp_size <= max_threads_per_block:
+        configs.append(triton.Config({}, num_warps=warp_count))
+        warp_count*=2
+    return configs
 
 def layer_norm_ref(
     x,

@@ -126,14 +139,7 @@ def rms_norm_ref(
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=triton_autotune_configs(),
     key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})

@@ -393,14 +399,7 @@ def _layer_norm_fwd(
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=triton_autotune_configs(),
     key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
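The new `triton_autotune_configs()` above replaces the hard-coded warp counts 1–32 with counts derived from the device's warp size. The selection logic can be sketched without Triton or a GPU (the helper name `warp_counts` is mine, for illustration):

```python
def warp_counts(warp_size, max_threads_per_block=1024):
    # Powers of two whose total thread count stays within the block limit;
    # each entry would become triton.Config({}, num_warps=n).
    counts = []
    n = 1
    while n * warp_size <= max_threads_per_block:
        counts.append(n)
        n *= 2
    return counts

print(warp_counts(32))  # NVIDIA-style warp size -> [1, 2, 4, 8, 16, 32]
print(warp_counts(64))  # e.g. AMD wavefront size -> [1, 2, 4, 8, 16]
```

On a warp-size-64 device the old fixed list included `num_warps=32`, i.e. 2048 threads per block, which exceeds the 1024-thread limit and triggers the Triton error the commit fixes.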

hopper/benchmark_attn.py

Lines changed: 18 additions & 10 deletions

@@ -1,6 +1,7 @@
 from collections import namedtuple
 from functools import partial
 import math
+import os
 from typing import NamedTuple
 import torch
 import torch.nn as nn

@@ -34,6 +35,8 @@
     triton_attention = None
 triton_attention = None
 
+DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
+
 
 def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
     # # Warmup

@@ -53,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
     # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
     # # return time_f[1].mean
     # return time_f[1]
-    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
+    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3)
 
 
 def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):

@@ -250,21 +253,24 @@ def run(*args, **kwargs):
     # for headdim in [64, 96, 128, 192, 256]:
     for headdim in [128]:
         nheads = dim // headdim
+        # nheads = 128
         # headdim = 64
         # batch_size = 64
         # seqlen = 512
         # nheads = 8
         # headdim = 128
         nheads_kv = nheads
         # nheads_kv = nheads // 4
+        # nheads_kv = 1
         headdim_v = headdim
-        # headdim_v = 128
+        # headdim_v = 512
+        has_qv = headdim == 64 and headdim_v == 512
+        # has_qv = False
 
         for batch_size, seqlen in bs_seqlen_vals:
             num_splits = 0
             window_size = (-1, -1)
             # window_size = (seqlen // 2 - 1, 0)
-            sink_token_length = 0
             pack_gqa = None
             # seqlen_q = 64
             seqlen_q = seqlen

@@ -276,6 +282,7 @@ def run(*args, **kwargs):
             q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
             v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
             v_fa3 = v if not V_colmajor else v_colmajor
+            qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None
             # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
             # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
             # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)

@@ -303,7 +310,7 @@ def run(*args, **kwargs):
             for causal in [False, True]:
             # for causal in [True]:
                 print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
-                nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size)
+                nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
                 if cudnn is not None:
                 # if False:
                     if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:

@@ -351,17 +358,17 @@ def run(*args, **kwargs):
 
                 time.sleep(1)
                 if not varlen:
-                    # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
-                    m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
+                    # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
+                    m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                     # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
                 else:
                     m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                     # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
                 time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
-                if dtype != torch.float8_e4m3fn and headdim == headdim_v:
+                if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:
                     time.sleep(1)
                     if not varlen:
-                        _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
+                        _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                     repeats=repeats, verbose=False, desc='Fav3')
                     else:
                         _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,

@@ -387,7 +394,7 @@ def run(*args, **kwargs):
                 print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                 print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
                 print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
-                if dtype != torch.float8_e4m3fn and headdim == headdim_v:
+                if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:
                     print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
                 # benchmark_forward(torch.square, k)
                 # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')

@@ -397,7 +404,8 @@ def run(*args, **kwargs):
     # import pickle
     # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
     # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp:
-    # with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp:
+    # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp:
+    # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp:
     # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:
     # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp:
     # pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
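The benchmark's `nFLOPS` value drives the printed TFLOPS numbers. The script's own `flops()` takes window size and other arguments into account; a hedged sketch of just the standard full-attention estimate it is based on (function name and the halved-causal approximation are mine, not the script's):

```python
def attn_flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False):
    # Q @ K^T costs 2*headdim multiply-add FLOPs per (q, k) pair;
    # P @ V costs 2*headdim_v per pair.
    total = 2 * batch * nheads * seqlen_q * seqlen_k * (headdim + headdim_v)
    # Causal masking touches roughly half the (q, k) pairs.
    return total // 2 if causal else total

print(attn_flops(1, 16, 4096, 4096, 128, 128))  # prints: 137438953472
```

This also suggests why the diff passes `headdim + headdim_v` as the QK head dimension when `has_qv`: the extra Qv product adds a second score-like GEMM over the value head dimension.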
