
Commit 95898ba

Authored by LucasWilkinson, mgorny, tmm1, and tridao
Upstream Sync (#45)
* Support ROCM builds from source distribution, and improve error handling (Dao-AILab#1446)
  * Always update both submodules to include them in sdist.
    Always update both submodules, irrespective of whether a CUDA or a ROCM build is being done, to ensure that the necessary files from both are present in sdist. Otherwise, attempting a ROCM build from sdist fails because of missing `composable_kernel` sources.
  * Include `*.py` files from composable_kernel in sdist.
    Include the `*.py` files from `csrc` in sdist, to ensure that the `generate.py` script is present.
  * Replace the `os.system()` calls in `setup.py` with `subprocess.run()`.
  * Add error checking to `subprocess.run()` calls in `setup.py`.
    Add error checking to ensure that `setup.py` fails immediately if one of the commands fails. Otherwise, the failures result only in messages to stderr that could be missed, and could lead to more confusing errors later in the build process.
  * Call git in `setup.py` only when working in a git repository.
    Call git commands in `setup.py` only when the `.git` directory is present, indicating that we are working in a git checkout. Otherwise, just assert that the needed files are there. With this, building from a source distribution no longer attempts to call git in an incorrect directory.
* [Build] Update version of setuptools used to generate core package (Dao-AILab#1460)
* Don't compile for CUDA 11, compile for official pytorch 2.6.0
* Bump to v2.7.4
* Drop Pytorch 2.1
* [FA3] Compile with nvcc 12.8 instead of 12.3
* Fix comment in assert
* [CE] Assert logit_scale > 0
* Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
* Fix shape_O in epilogue params when kHeadDimV != kHeadDim
* Remove old combine.h
* Fix loading paged V when kHeadDimV != kHeadDim
* Fix shape_V for storing new KV when kHeadDimV != kHeadDim
* Implement the case of LargeHeadDimV
* Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192
* Pass _1 or _0 to cute::aligned_struct
* Fix compilation for FP8 when kHeadDimV != kHeadDim
* Support Qv
* Test varlen_q=True by default for kvcache
* Fix num_splits heuristic being called before get_pack_gqa
* Fix num_splits heuristic again when PackGQA
* Tile fwd_combine kernel along headdim, don't need kBlockM > 128
* Use bf16 instead of fp16 in benchmark_gemm.py
* Update Cutlass to 3.7
* Use nvcc 12.6 but ptxas 12.8
* cicc uses the same version as ptxas
* Split hdimdiff into a separate translation unit
* Update benchmark script
* Update Cutlass to 3.8
* Adjust tile size for hdim 64
* Adjust ninja build file
* build head diff + fix build errors

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Michał Górny <mgorny@gentoo.org>
Co-authored-by: Aman Karmani <aman@tmm1.net>
Co-authored-by: Tri Dao <tridpq@gmail.com>
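The `setup.py` hardening described in the first group of bullets (checked `subprocess.run()` instead of `os.system()`, and calling git only inside a git checkout) can be sketched as follows. This is an illustrative sketch, not the repository's actual code: the helper names and the asserted file path are assumptions.

```python
import subprocess
from pathlib import Path


def run_checked(cmd: list[str]) -> None:
    # check=True makes the build fail immediately if the command fails,
    # instead of only printing to stderr and surfacing a confusing error
    # later in the build process (the problem with bare os.system()).
    subprocess.run(cmd, check=True)


def maybe_update_submodules(repo_root: str) -> None:
    if Path(repo_root, ".git").exists():
        # Working in a git checkout: update both submodules regardless of
        # CUDA vs ROCM, so a subsequent sdist contains files for both builds.
        run_checked(["git", "submodule", "update", "--init", "csrc/cutlass"])
        run_checked(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
    else:
        # Building from an sdist: git is unavailable or would run in the
        # wrong directory, so just assert the needed files were shipped.
        # (Path below is a hypothetical stand-in for the real checks.)
        assert Path(repo_root, "csrc", "composable_kernel").exists(), \
            "composable_kernel sources missing from source distribution"
```

A failed command now raises `subprocess.CalledProcessError` at the point of failure rather than letting the build limp on.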
1 parent 720c948 commit 95898ba

File tree

340 files changed: +2326 −1432 lines


.github/workflows/publish.yml

Lines changed: 9 additions & 12 deletions

```diff
@@ -44,21 +44,16 @@ jobs:
         # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
         os: [ubuntu-20.04]
         python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-        torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
-        cuda-version: ['11.8.0', '12.3.2']
+        torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
+        cuda-version: ['12.4.1']
         # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
         # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
         # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
         # when building without C++11 ABI and using it on nvcr images.
         cxx11_abi: ['FALSE', 'TRUE']
         exclude:
           # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
-          # Pytorch < 2.2 does not support Python 3.12
-          - torch-version: '2.1.2'
-            python-version: '3.12'
           # Pytorch < 2.5 does not support Python 3.13
-          - torch-version: '2.1.2'
-            python-version: '3.13'
           - torch-version: '2.2.2'
             python-version: '3.13'
           - torch-version: '2.3.1'
@@ -113,7 +108,7 @@ jobs:
        run: |
          pip install --upgrade pip
          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
-         pip install setuptools==68.0.0
+         pip install setuptools==75.8.0
          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
          pip install typing-extensions==4.12.2
@@ -122,8 +117,8 @@ jobs:
          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # This code is ugly, maybe there's a better way to do this.
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-           minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
-           maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+           minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+           maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
@@ -149,7 +144,7 @@ jobs:
          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails so I'm using a newer version of setuptools
-         pip install setuptools==68.0.0
+         pip install setuptools==75.8.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -203,7 +198,9 @@ jobs:

      - name: Install dependencies
        run: |
-         pip install ninja packaging setuptools wheel twine
+         pip install ninja packaging wheel twine
+         # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
+         pip install setuptools==75.8.0
          # We don't want to download anything CUDA-related here
          pip install torch --index-url https://download.pytorch.org/whl/cpu
```
CMakeLists.txt

Lines changed: 6 additions & 0 deletions

```diff
@@ -176,12 +176,18 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
   # BF16 source files
   file(GLOB FA3_BF16_GEN_SRCS
     "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
+  file(GLOB FA3_BF16_GEN_SRCS_
+    "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
+  list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
   file(GLOB FA3_BF16_GEN_SRCS_
     "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
   list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

   # FP16 source files
   file(GLOB FA3_FP16_GEN_SRCS
     "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
+  file(GLOB FA3_FP16_GEN_SRCS_
+    "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
+  list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
   file(GLOB FA3_FP16_GEN_SRCS_
     "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
   list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
```

MANIFEST.in

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,6 +3,7 @@ recursive-include csrc *.h
 recursive-include csrc *.cuh
 recursive-include csrc *.cpp
 recursive-include csrc *.hpp
+recursive-include csrc *.py

 recursive-include vllm_flash_attn *.cu
 recursive-include vllm_flash_attn *.h
```

README.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -44,7 +44,7 @@ Currently released:

 Requirements: H100 / H800 GPU, CUDA >= 12.3.

-For now, we highly recommend CUDA 12.3 for best performance.
+We highly recommend CUDA 12.8 for best performance.

 To install:
 ```sh
@@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func()
 ## Installation and features
 **Requirements:**
 - CUDA toolkit or ROCm toolkit
-- PyTorch 1.12 and above.
+- PyTorch 2.2 and above.
 - `packaging` Python package (`pip install packaging`)
 - `ninja` Python package (`pip install ninja`) *
 - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
@@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation

 ### NVIDIA CUDA Support
 **Requirements:**
-- CUDA 11.7 and above.
+- CUDA 12.0 and above.

 We recommend the
 [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
````

benchmarks/benchmark_gemm.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs

 torch.manual_seed(0)
 repeats = 30
-dtype = torch.float16
+dtype = torch.bfloat16
 device = 'cuda'
 verbose = False
 m, n = 8192, 8192
```

csrc/cutlass

Submodule cutlass updated 2220 files

flash_attn/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-__version__ = "2.7.3"
+__version__ = "2.7.4.post1"

 from flash_attn.flash_attn_interface import (
     flash_attn_func,
```

flash_attn/ops/triton/cross_entropy.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -166,6 +166,7 @@ def forward(
         if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
             labels = F.pad(labels, (0, 1))[..., :-1]
             assert labels.data_ptr() % 16 == 0
+        assert logit_scale > 0.0
         n_rows, n_cols = logits.shape
         assert labels.shape == (n_rows,)
         world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
```
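The new `assert logit_scale > 0.0` guards against a sign error: the loss is computed on scaled logits, and a non-positive scale would silently invert or flatten the model's preferences. A minimal pure-Python sketch of a scaled cross-entropy loss with the same guard (illustrative only, not the Triton kernel above):

```python
import math


def scaled_cross_entropy(logits: list[float], label: int, logit_scale: float = 1.0) -> float:
    # The loss is taken over logit_scale * logits, so the scale must be
    # positive for larger logits to mean higher probability.
    assert logit_scale > 0.0
    scaled = [logit_scale * x for x in logits]
    # Standard cross entropy: logsumexp(scaled) - scaled[label].
    log_sum_exp = math.log(sum(math.exp(x) for x in scaled))
    return log_sum_exp - scaled[label]
```

With uniform logits the loss is `log(num_classes)`, and the correct class always has the smaller loss, which breaks if the scale goes negative.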

hopper/benchmark_attn.py

Lines changed: 19 additions & 41 deletions

```diff
@@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
     return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)


-def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)):
+def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
     if causal:
         avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
     else:
@@ -67,7 +67,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=
         col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
         col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1))
         avg_seqlen = (col_right - col_left + 1).float().mean().item()
-    return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2
+    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


 def convert_to_cudnn_type(torch_type):
@@ -242,21 +242,6 @@ def run(*args, **kwargs):
     time_f = {}
     time_b = {}

-    # tflops_matmul = {}
-    # m, n = 8192, 8192
-    # for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
-    #     a = torch.randn(m, k, device=device, dtype=dtype)
-    #     b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
-    #     nFLOPS_matmul = 2 * m * n * k
-    #     m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS')
-    #     print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
-    #     tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12
-    #     # import pickle
-    #     # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
-    #     # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp:
-    #     #     pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL)
-    # exit(0)
-
     # for headdim in [64, 128, 256]:
     # for headdim in [64, 96, 128, 192]:
     # for headdim in [64, 96, 128, 192, 256]:
@@ -272,9 +257,11 @@ def run(*args, **kwargs):
     # headdim = 128
     nheads_kv = nheads
     # nheads_kv = nheads // 4
+    headdim_v = headdim
+    # headdim_v = 128

     for batch_size, seqlen in bs_seqlen_vals:
-        num_splits = 1
+        num_splits = 0
         window_size = (-1, -1)
         # window_size = (seqlen // 2 - 1, 0)
         sink_token_length = 0
@@ -285,20 +272,16 @@ def run(*args, **kwargs):
         # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
         q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
         k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
-        v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
+        v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
         q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
         v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
         v_fa3 = v if not V_colmajor else v_colmajor
         # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
         # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
-        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
-        g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
-        o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
+        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)
+        g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
+        o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
         stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
-        a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen)
-        b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2)
-        # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype)
-        # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2)
         if varlen:
             q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]]
             cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
@@ -318,16 +301,16 @@ def run(*args, **kwargs):
             page_table = None

         for causal in [False, True]:
-        # for causal in [False]:
+        # for causal in [True]:
             print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
-            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size)
+            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size)
             if cudnn is not None:
             # if False:
-                if headdim <= 256 and dtype != torch.float8_e4m3fn:
+                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                     cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
                     cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
             # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
-            if dtype != torch.float8_e4m3fn:
+            if dtype != torch.float8_e4m3fn and headdim == headdim_v:
             # if False:
                 if not varlen:
                     m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
@@ -343,7 +326,7 @@ def run(*args, **kwargs):
                         repeats=repeats, verbose=False, desc='Fav2')
                     time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
                 # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
-            if headdim <= 256 and dtype != torch.float8_e4m3fn:
+            if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                 if triton_attention is not None:
                     qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]
                     time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
@@ -356,7 +339,7 @@ def run(*args, **kwargs):
             # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True)
             if cudnn is not None:
             # if False:
-                if headdim <= 256 and dtype != torch.float8_e4m3fn:
+                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                     time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                     m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                     time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
@@ -375,12 +358,7 @@ def run(*args, **kwargs):
                 m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
             # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
             time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
-            # time.sleep(1)
-            # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False)
-            # nFLOPS_matmul = nFLOPS
-            # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1]
-            # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS')
-            if dtype != torch.float8_e4m3fn:
+            if dtype != torch.float8_e4m3fn and headdim == headdim_v:
                 time.sleep(1)
                 if not varlen:
                     _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
@@ -396,11 +374,11 @@ def run(*args, **kwargs):
             # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
             # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')

-            if dtype != torch.float8_e4m3fn:
+            if dtype != torch.float8_e4m3fn and headdim == headdim_v:
             # if False:
                 print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
                 print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
-            if headdim <= 256 and dtype != torch.float8_e4m3fn:
+            if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                 if triton_attention is not None:
                     print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')
             # if causal:
@@ -409,7 +387,7 @@ def run(*args, **kwargs):
                 print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                 print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
             print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
-            if dtype != torch.float8_e4m3fn:
+            if dtype != torch.float8_e4m3fn and headdim == headdim_v:
                 print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
             # benchmark_forward(torch.square, k)
             # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
```
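The `flops` change in this benchmark reflects that with `HeadDim_V != HeadDim_QK` the two attention matmuls cost differently: `Q @ K^T` does `2 * seqlen_q * seqlen_k * headdim` FLOPs per head while `P @ V` does `2 * seqlen_q * seqlen_k * headdim_v`, so the old `headdim * 2` factor generalizes to `(headdim + headdim_v)`. A standalone restatement of the non-causal, full-attention case (an assumed simplification of the benchmark's function, which also handles causal and windowed masks):

```python
def attn_fwd_flops(batch: int, nheads: int, seqlen_q: int, seqlen_k: int,
                   headdim: int, headdim_v: int) -> int:
    # Q @ K^T: 2 * seqlen_q * seqlen_k * headdim multiply-adds per head.
    # P @ V:   2 * seqlen_q * seqlen_k * headdim_v multiply-adds per head.
    return batch * nheads * 2 * seqlen_q * seqlen_k * (headdim + headdim_v)
```

When `headdim == headdim_v` this reduces to the old `... * headdim * 2` formula, so reported TFLOPS are unchanged for the standard case.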
