Changes from all commits (472 commits)
81cdf4c
[Cute] Don't need i64_to_f32x2 anymore
tridao Aug 12, 2025
c4be578
Remove old xentropy kernel
tridao Aug 12, 2025
3edef7c
Remove old fused softmax kernel from apex/Megatron
tridao Aug 12, 2025
2715c53
Remove old attn decode kernel from FasterTransformer
tridao Aug 12, 2025
f28841d
Remove old rotary kernel
tridao Aug 12, 2025
a1c2e22
[Cute] Implement page table with TMA for fwd_sm100
tridao Aug 12, 2025
581b68d
[Cute] Remove trailing bracket (#1809)
jduprat Aug 13, 2025
3c51f15
[Cute] Make sure R2P happen
tridao Aug 13, 2025
d2e3fc3
feat: add support for pytorch2.8 (#1801)
NanoCode012 Aug 13, 2025
69b33b5
[Cute] Implement PackGQA with TMA for fwd_sm100
tridao Aug 14, 2025
060c918
Bump to v2.8.3
tridao Aug 14, 2025
cd9383f
[BugFix] Fix flash_attn_with_kvcache with scalar cache_seqlen (#1795)
stepinto Aug 15, 2025
b31ae1e
[Cute] Port fwd_combine kernel from C++ to cute-dsl
tridao Aug 17, 2025
591dc7e
[Cute] Simplify tile scheduler storing params
tridao Aug 17, 2025
f8b4f15
[Cute] Implement sink for fwd_sm90
tridao Aug 17, 2025
e1407db
[Cute] Implement PackGQA with TMA for fwd_sm90
tridao Aug 17, 2025
0e60e39
[Cute] Use R2P for masking in fwd_sm90
tridao Aug 17, 2025
199401d
Add sorting and head swizzle to varlen scheduler (#1823)
jayhshah Aug 22, 2025
632fe2a
Fixes incorrect variable reference in comment (#1775)
LoserCheems Aug 24, 2025
832d544
Update the initialization of dk/dv_semaphore (#1839)
y-sq Aug 25, 2025
478841a
Update tile_scheduler.hpp (#1841)
ghadiaravi13 Aug 26, 2025
6f2b052
ci: Move build job to workflow template (#1835)
ko3n1g Aug 27, 2025
b247655
ci: Build via workflow template (#1844)
ko3n1g Aug 27, 2025
d0ed097
ci: Switch to workflow_dispatch (#1847)
ko3n1g Aug 29, 2025
203b9b3
[`FA3`] Allow returning LSE via kwarg (#1851)
vasqu Aug 29, 2025
27b64c7
[BugFix] fix flash_fwd.FlashAttentionForwardSm80 bugs (#1856)
mingyangHao Sep 2, 2025
6387433
[FIX] Allow m_block_size == 192 and mma_pv_is_rs == False in Sm90 CuT…
reubenconducts Sep 2, 2025
afc97c6
make FA3 compatible with CUDA 13 Builds (#1860)
johnnynunez Sep 4, 2025
dfb6649
[BUILD] SBSA wheels + CUDA 13 Support (#1865)
johnnynunez Sep 5, 2025
e8c7344
benchmark: qualify all attention backends by methods list (#1881)
rajesh-s Sep 12, 2025
b3846b0
ABI stable fa3 (#1791)
mikaylagawarecki Sep 12, 2025
7bdb426
[NVIDIA] Enable Blackwell Family Specific (#1882)
johnnynunez Sep 12, 2025
e980f0f
fix typo in flops calculation for local attention (#1883)
henrylhtsang Sep 13, 2025
2cc6fd6
flash-attn-cute bwd sm90 (#1868)
tzadouri Sep 13, 2025
8ecf128
[Cute] Make testing utils standalone for cute (#1892)
drisspg Sep 17, 2025
589cc20
Bump pin for CuTeDSL (#1891)
drisspg Sep 17, 2025
5c1627a
Improve causal backward determinism perf with SPT schedule (#1893)
jayhshah Sep 17, 2025
1ceaa98
Upgrade to cutlass v4.2.1 (#1905)
johnnynunez Sep 23, 2025
3b24b08
switch to use cutlass.utils.get_smem_capacity_in_bytes instead of dep…
brandon-yujie-sun Sep 24, 2025
0165c96
Add Missing None Gradient in FA3 QKVPacked (#1908)
JackCharlesZhang Sep 24, 2025
add1756
C++11 fix warnings (#1904)
johnnynunez Sep 25, 2025
cc0a79b
[Cute] Write ex2 emulation in a more readable form
tridao Sep 27, 2025
5059fd5
[Cute] Simplify utils.py a bit
tridao Sep 27, 2025
c485eea
[Cute] Remove arith & vector import in utils.py
tridao Oct 1, 2025
cbd2490
[CuteDSL] Fix test (#1925)
drisspg Oct 7, 2025
5183de4
Refactors to enable FlexAttention (#1840)
drisspg Oct 8, 2025
a38d69d
[Cute] Fix softmax for cutlass-dsl==4.2.1
tridao Oct 11, 2025
437b35a
[Cute] Fix softmax for fwd_sm100
tridao Oct 12, 2025
ea03e06
[Cute,Bwd] Simplify bwd_preprocessing kernel
tridao Oct 12, 2025
fbdba01
[Cute,Fwd,Sm90] Simplify by passing around functions
tridao Oct 12, 2025
b528f4b
[Cute,Fwd,Sm90] Simplify score mode by passing around partial fn
tridao Oct 12, 2025
13f2077
[Cute] Optionally dump cubin and sass
tridao Oct 12, 2025
c172985
[Cute,Fwd,Sm90] Rename m_block_size->tile_m, n_block_size->tile_n
tridao Oct 12, 2025
9eee089
[Cute,Bwd,Sm90] Format file w ruff
tridao Oct 12, 2025
42e4e3e
[Cute,Bwd,Sm90] Fix bwd dK & dV, more async
tridao Oct 13, 2025
093b935
[Cute,Bwd,Sm90] Use cp.async.bulk instead of TMA for LSE & dPsum
tridao Oct 13, 2025
9be4a62
[Cute,Bwd,Sm90] Use 1 barrier for loading both K & V
tridao Oct 13, 2025
5576480
[Cute,Bwd,Sm90] Don't clear dK & dV, use zero_init mma flag instead
tridao Oct 13, 2025
5a5a65b
[Cute,Bwd,Sm90] Use TMA to store dK & dV
tridao Oct 13, 2025
66fd2a4
[Cute,Bwd,Sm90] Load K together w Q & LSE in the first iteration
tridao Oct 13, 2025
35384ec
[Cute,Sm90] Move gemm helper functions to hopper_helpers.py
tridao Oct 13, 2025
7c0e373
Swap masking to not use R2P
imbr92 Oct 13, 2025
60eb1ea
Pre-indent to make commit diffs readable
imbr92 Oct 13, 2025
25f5d09
Adding varlen support + tests
imbr92 Oct 13, 2025
b4e5896
Remove self refs in softmax for loop (#1924)
kevin-tong-augment Oct 13, 2025
13afe0d
[Cute,Bwd,Sm90] Make postprocessing kernel work
tridao Oct 13, 2025
d2c8a6c
[Cute] Run ruff format on bwd files
tridao Oct 13, 2025
ee3a533
[CI] Add pre-commit GH action
tridao Oct 13, 2025
93e433b
[Cute,Bwd,Sm90] Try dO_stage=1, PdS_stage=1
tridao Oct 14, 2025
57d0ce9
[Cute,Bwd,Sm90] Make causal work
tridao Oct 14, 2025
89b94f8
[Cute,Bwd,Sm90] Implement dQ_swapAB
tridao Oct 14, 2025
54d8aa6
[Cute,Bwd,Sm90] Implement SdP_swapAB
tridao Oct 14, 2025
72b793a
[AMD] Torch Compile Issues (#1756)
micmelesse Oct 14, 2025
5685ace
[Cute,Bwd,Sm90] Implement mma_dkv_is_rs
tridao Oct 14, 2025
a76e692
[Cute,Bwd,Sm90] Use block size 80x128
tridao Oct 14, 2025
6bc3d1f
[CUTE] Enable Pack GQA for score mods (#1937)
drisspg Oct 15, 2025
04adaf0
Add precommit list and then uncomment in chunks (#1941)
drisspg Oct 15, 2025
48ecd14
[ROCm] prepare CK sources for pytorch hipify v2 APIs (#1944)
jeffdaily Oct 18, 2025
cc843a2
[Cute] Add flake8 config file
tridao Oct 18, 2025
c712d43
[Cute,Fwd,Sm90] Load Q & K using the same mbarrier
tridao Oct 18, 2025
752c263
[Cute,Bwd,Sm90] Use the same producer states if Q_stage == dO_stage
tridao Oct 18, 2025
71ec343
[Cute,Bwd,Sm90] Split sdQaccum layout into 2 warp groups
tridao Oct 18, 2025
7a3a8fe
[Cute,Bwd,Sm90] Implement masking
tridao Oct 19, 2025
75fcbf2
[Cute,Fwd,Sm100] Parse swizzle from pointer, don't need to pass in
tridao Oct 19, 2025
b5e9a71
[Cute,Fwd,Sm100] Clean up
tridao Oct 19, 2025
b4fac7d
[Cute,Fwd,Sm100] Clean up mask
tridao Oct 19, 2025
9c14873
[Cute] Reformat blackwell_helpers.py, block_info.py
tridao Oct 19, 2025
aae355e
[Cute] Format mma_sm100_desc.py, seqlen_info.py
tridao Oct 19, 2025
83eb8d6
sm100 bwd add kernel and update postprocess mask and barriers (#1945)
tzadouri Oct 19, 2025
5fa6e8d
[Cute,Bwd,Sm100] Format flash_bwd_sm100.py and flash_bwd_postprocess
tridao Oct 19, 2025
498bfe6
[Cute,Bwd,Sm100] Rename var {m,n}_block_size->tile_{m,n}
tridao Oct 19, 2025
94f50b0
[Cute,Bwd,Sm100] Clean up a bit
tridao Oct 19, 2025
e925d10
add barrier module (#1946)
tzadouri Oct 19, 2025
d0d8adb
[Cute,Bwd,Sm100] Have a separate function to set up the mma
tridao Oct 19, 2025
796564d
[Cute,Bwd,Sm100] Load LSE with cpasync_bulk
tridao Oct 19, 2025
d0399b6
[Cute,Bwd,Sm100] Load dPsum with cpasync_bulk
tridao Oct 19, 2025
372f3e2
[Cute,Bwd,Sm100] Use copy_utils functions to load Q & dO
tridao Oct 19, 2025
c0c8c2d
[Cute,Bwd,Sm100] Load K & Q, V & dO in the first iteration
tridao Oct 19, 2025
7b17cd8
[Cute,Bwd,Sm100] Simplify mma by using functools.partial
tridao Oct 19, 2025
5c685ea
[Cute,Bwd,Sm100] Don't need q_dk_consumer_state
tridao Oct 19, 2025
8790c6e
[Cute,Bwd,Sm100] Simplify dQacc_reduce, don't need mbarrier
tridao Oct 20, 2025
7254904
[Cute,Bwd,Sm100] Iterate from m_block_min -> m_block_max
tridao Oct 20, 2025
2187695
[Cute,Bwd,Sm100] Try direct atomicadd rmem -> gmem
tridao Oct 20, 2025
12e1c04
[Cute,Bwd,Sm100] Combine pipeline_dK and pipeline_dV into one
tridao Oct 20, 2025
d101fa7
[Cute,Bwd,Sm100] All compute warps wait for lse_barrier
tridao Oct 20, 2025
82c9cbb
[Cute,Bwd,Sm100] sdQaccum doesn't need swizzle
tridao Oct 20, 2025
91f14ca
[Cute,Bwd,Sm100] Try gemm_ptx
tridao Oct 20, 2025
53c884b
[Cute,Bwd,Sm100] Clean up compute fn
tridao Oct 21, 2025
0f56550
[Cute,Bwd,Sm100] Combine pipeline_S and pipeline_P into 1
tridao Oct 21, 2025
22f7daa
[Cute,Bwd,Sm100] Don't shuffle LSE & dPsum, reduce state variables
tridao Oct 21, 2025
3cac07a
[Cute,Bwd,Sm100] Hardcode dS_stage = 1
tridao Oct 21, 2025
f29df7a
[Cute,Bwd,Sm100] Add option for delay tma store
tridao Oct 21, 2025
933b2c3
Fix hopper cuda 13 build (#1949)
kevmo314 Oct 21, 2025
a098f98
[CuteDSL] Fix hash function for cute.jit decorator (#1953)
drisspg Oct 21, 2025
143b0ba
Block Sparsity and Flex Attention mask mod support (#1942)
reubenconducts Oct 21, 2025
16c7f0f
cutlass v4.3.0 (#1952)
johnnynunez Oct 21, 2025
9dbed03
[Cute,Bwd,Sm100] Use CopyBulkG2SOp copy op instead of calling ptx
tridao Oct 21, 2025
1b8e1e6
[Cute,Bwd,Sm100] More cleanup
tridao Oct 22, 2025
e4d25a4
[CuTe DSL] Update "buffers" name to "aux_tensors"; fix flex bugs (#1961)
reubenconducts Oct 24, 2025
3effce8
Fix FA3 segfault with custom CUDA streams in ABI stable build (#1957)
kevmo314 Oct 24, 2025
9450df6
[Cute,Fwd,Sm100] Fix interface w score mod to get it to run
tridao Oct 24, 2025
7ef1a6f
[Cute,Sm100] In gemm ptx, add to base smem_address instead
tridao Oct 24, 2025
b3f437f
[Cute,Bwd,Sm100] Make postprocessing work, add interface
tridao Oct 25, 2025
6eb7c80
[Cute,Bwd,Sm100] Simplify layouts in compute_loop
tridao Oct 25, 2025
93a0afe
[Cute,Bwd,Sm100] Causal mask
tridao Oct 25, 2025
662cf9c
[Cute,Bwd,Sm100] Enable bwd tests
tridao Oct 25, 2025
79b9030
[Cute,Bwd] Enable bwd benchmarks
tridao Oct 25, 2025
510fe92
[Cute] Add store_shared_remote_fp32x4 util function
tridao Oct 26, 2025
b634499
[Cute,Bwd,Sm100] Tune registers
tridao Oct 26, 2025
e873ad0
[Cute,Sm100] acc_tmem_addr is Int32 instead of constexpr
tridao Oct 26, 2025
2c7177d
[Cute,Bwd,Sm100] Reduce sync
tridao Oct 26, 2025
6c56a0c
[Cute] Change utils.view_transpose back
tridao Oct 26, 2025
285bf12
[Cute,Bwd,Sm100] Remove delay_tma_store option
tridao Oct 26, 2025
c59ecd8
[Cute,Bwd,Sm100] Implement cluster
tridao Oct 26, 2025
25e6d94
[Cute] Copy benchmark util functions to cute directory
tridao Oct 27, 2025
53d3a99
[Cute,Bwd,Sm100] Use pipeline class for LSE and dPsum
tridao Oct 28, 2025
a5d545d
[Cute,Bwd,Sm100] Remove stage from sK, sV, tP, sdS
tridao Oct 28, 2025
b3f1b6a
[Cute,Bwd,Sm100] Fix wrong LSE and dPsum indexing in load
tridao Oct 28, 2025
67e8865
[Cute] Blocks tweaks (#1964)
drisspg Oct 28, 2025
7f7a497
[Cute,Bwd,Sm100] Use TS MMA for dK
tridao Oct 28, 2025
b613d9e
[Cute,Blocksparse] Group block sparse input torch tensors
tridao Oct 28, 2025
11336b7
[Cute,Bwd,Sm100] Separate mma_S and mma_dP
tridao Oct 29, 2025
419bdb7
[Cute,Bwd,Sm100] Try LPTBwdScheduler
tridao Oct 29, 2025
de1584b
[Cute,Bwd,Sm100] Try separating warps loading Q and dO
tridao Oct 29, 2025
0256114
BlockSparse Tweaks (#1970)
drisspg Oct 31, 2025
6c9eef9
[Cute] Fix main (#1982)
drisspg Nov 3, 2025
e724e25
[Cute,Fwd,Sm100] Implement SplitKV (#1940)
timmy-feng Nov 5, 2025
ad70a00
[Cute] Extract block-sparse utilities from SM80/90 (#1984)
drisspg Nov 5, 2025
c8abdd4
Enable python-3.10+ (#1998)
drisspg Nov 9, 2025
2ef346b
[Cute, Bwd, Sm100] Add GQA support (#2004)
jayhshah Nov 12, 2025
1338006
[Cute,Fwd,Sm100] fix major regression with split kv (#2006)
jayhshah Nov 12, 2025
16d78bb
[CuTe DSL] Block sparsity computation kernel (#1983)
reubenconducts Nov 12, 2025
fbf24f6
[NVIDIA] bump github actions (#1996)
johnnynunez Nov 13, 2025
5d2cd3b
[Cute,Fwd,Sm100] Support paged attention (#1999)
timmy-feng Nov 14, 2025
c7697bb
Add torch.compile support to flash attention 3
guilhermeleobas Jul 16, 2025
e1944ba
Don't return mutated variables in mha_bwd
guilhermeleobas Jul 24, 2025
a760ca3
Change fake_check flag to be opt-in; Remove build.sh and remove if-el…
guilhermeleobas Jul 25, 2025
24cc2b2
Remove print statements and update exception message
guilhermeleobas Jul 30, 2025
5e114d5
Fix flash_attn_backward_fake
guilhermeleobas Aug 6, 2025
734bc43
Add `safe_aot_autograd_check`
guilhermeleobas Aug 7, 2025
fde4bc0
Update namespace to flash_attn_3
guilhermeleobas Aug 19, 2025
ab79ae2
Add `flash_attn_forward.register_autograd`
guilhermeleobas Aug 22, 2025
6250fbe
Fix bug in `flash_attn_backward_fake`
guilhermeleobas Aug 22, 2025
1e3539e
Add support and tests for torch.export and aoti_compile_and_package
guilhermeleobas Sep 2, 2025
f174bd6
format code
guilhermeleobas Sep 3, 2025
6fe1c8c
update flash_api_stable.cpp
guilhermeleobas Sep 19, 2025
b555ac7
Fix flash_api_stable.cpp build
guilhermeleobas Oct 13, 2025
0aa4fa1
Only run schema_check if dtype is not float8_e4m3fn
guilhermeleobas Oct 13, 2025
47d7137
Correctly compute kBlockM for sm88/86/80
guilhermeleobas Oct 13, 2025
49fb775
Fix bug in boxed_mha_bwd
guilhermeleobas Oct 13, 2025
65dd580
don't run autograd_check when num_splits > 0
guilhermeleobas Nov 12, 2025
b4555bf
[Cute] Add block-sparsity support to SM100 (#1985)
drisspg Nov 18, 2025
43375aa
[Cute,Sm100,Fwd] use correction warps for epi when not using TMA (#2014)
jayhshah Nov 19, 2025
3fcde4b
Raise TypeError if out is specified when compiling _flash_attn_forward
guilhermeleobas Nov 21, 2025
052015a
add fastdivmod for oob reads in mask_mods (#2020)
drisspg Nov 21, 2025
d063b33
don't pass mask_fn to softmax_step generically (#2026)
jayhshah Nov 22, 2025
a986d01
swap order of decorators (#2029)
anakinxc Nov 24, 2025
20cda05
[Cute,Bwd,Sm100] enable deterministic mode for sm100 bwd and fix race…
jayhshah Nov 25, 2025
9194297
[NFC] Trivial fix to silence linter (#1928)
jduprat Nov 25, 2025
5cc6fa4
Add LICENSE and AUTHORS to flash_attn/cute (#2032)
jduprat Nov 25, 2025
63b66f2
[Cute] Add authors
tridao Nov 25, 2025
92ca9da
[Cute,Fwd] enable mask mod without blocksparsity (#2031)
reubenconducts Nov 25, 2025
672381f
Bump pin (#2025)
drisspg Nov 25, 2025
91ba87d
ruff all the smaller files (#2040)
drisspg Dec 2, 2025
de6a6ad
[Flash] Fix head dim 64 bwd (#2035)
drisspg Dec 2, 2025
26ba559
Add headdim64 tests (#2041)
drisspg Dec 2, 2025
59df2f9
Merge pull request #1769 from guilhermeleobas/guilhermeleobas/fa3-com…
v0i0 Dec 4, 2025
56fdf3e
[Cute,Bwd,Sm100] Add local for sm100 bwd (#2046)
jayhshah Dec 6, 2025
0d1ad61
Add hash attr to shortcut expensive check (#2048)
drisspg Dec 7, 2025
6328432
[AMD ROCm] Update to latest composable_kernel to improve performance …
rocking5566 Dec 7, 2025
c783ab2
fixing cute bwd func def (#2056)
liangel-02 Dec 9, 2025
bc0e4ac
Fix use-after-free in FA3 deterministic mode. The pytorch caching all…
skarupke Dec 12, 2025
e240e0f
[CUTE] Allow grads to be preallocated (#2065)
drisspg Dec 15, 2025
fd8d5eb
[Cute,Fwd] Extend score_mod to variable sequence length (#2043)
reubenconducts Dec 15, 2025
179f793
[CUTE] Seeing if tvvm reduces cpu overhead (#2042)
drisspg Dec 15, 2025
0a5339f
[FIRST] Fix softcap scoremod kwargs typo. (#2072)
LeoZDong Dec 16, 2025
ac9b5f1
basics working (#2070)
drisspg Dec 16, 2025
eacbc56
Blocksparse impl (#2085)
drisspg Dec 18, 2025
bba578d
Fix IMA in fwd on m boundary (#2091)
drisspg Dec 20, 2025
ceb4110
Update to dsl 3.4.3 (#2092)
drisspg Dec 22, 2025
5663adf
README for AMD ROCm (#2068)
seungrokj Dec 23, 2025
58fe37f
fix shuffle sync for pack gqa epilogue (#2097)
jayhshah Dec 24, 2025
11b32fd
improve paged cpasync
v0i0 Dec 24, 2025
d234051
Enable Thor (#2108)
johnnynunez Dec 29, 2025
4fd123e
[Cute] Add quack as dependency
tridao Dec 31, 2025
f3423a8
[Cute,Fwd,Sm90] Change PipelineTMAAsync subclass to signal per warp
tridao Jan 1, 2026
9b6dbac
Add pack-gqa support for blocksparse impl w/ broadcasted H dim (#2098)
drisspg Jan 4, 2026
f98d345
[Cute,Fwd] improved block sparsity (#2100)
reubenconducts Jan 5, 2026
bb2efb3
[Cute] Fix minor lint issue in shuffle_sync
tridao Jan 5, 2026
f472175
Misc tests that should be xfailed for now (#2127)
drisspg Jan 5, 2026
3e87e42
Update cutlass to fix undefined symbol: cuDriverGetVersion. (#2142)
HydraQYH Jan 7, 2026
3c8ca4e
[Cute,Fwd,Sm100] Support `q_stage=1` for inference (#1993)
timmy-feng Jan 8, 2026
6dd7e74
[Cute] Fix two tests that were failing (#2149)
henrylhtsang Jan 8, 2026
c15ffe3
cleanup
v0i0 Jan 8, 2026
ed6a82f
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
jayhshah Jan 9, 2026
27a3b54
block-sparse backward SM90 (#2136)
drisspg Jan 10, 2026
844b10f
score-mod backward SM90 (#2137)
drisspg Jan 10, 2026
e317aa4
[Cute] Clarify and fix subtle cachekey bug (#2143)
drisspg Jan 10, 2026
26d4ee9
[CUTE][SM100] Fix backward gqa on sm100 post mask-mod semantic change…
drisspg Jan 10, 2026
8eff546
[CUTE][SM90]Enable pack-gqa with broadcasted maskmods (#2145)
drisspg Jan 10, 2026
5d4c953
[CUTE][SM90] GQA backward non deterministic (#2158)
drisspg Jan 10, 2026
ea8f735
[Cute,Bwd,Sm100] fix seqused in varlen bwd (#2167)
jayhshah Jan 10, 2026
ef7343b
[CUTE] Bump cutedsl to 4.3.5 (#2170)
drisspg Jan 12, 2026
dbf08eb
Merge pull request #2156 from v0i0/v0i0/improve-paged-ldgsts
v0i0 Jan 12, 2026
4cb272e
[Cute,Flex] Add option to create and cache __cute_hash__ (#2171)
reubenconducts Jan 12, 2026
4894657
[Cute][Flex] Remove no longer needed contig (#2172)
drisspg Jan 12, 2026
13696f2
[Cute] update row_max before safe overwrite for online_softmax (#2174)
jayhshah Jan 13, 2026
506441a
[Cute][Flex] add back in contig (#2177)
drisspg Jan 15, 2026
68649fb
[Cute][Flex]Add pack-gqa divmod (#2180)
drisspg Jan 15, 2026
88067b0
baseline local flops
henrylhtsang Jan 15, 2026
fffabc3
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
timmy-feng Jan 15, 2026
a512bd8
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
2020964
remove benchmark result, undo changes to benchmark
henrylhtsang Jan 15, 2026
7108d1c
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
e4ec1ad
switch from xor to mask_right & ~ mask_left
henrylhtsang Jan 16, 2026
ac88858
flip in_bound to out_bound
henrylhtsang Jan 16, 2026
e34d840
remove zero logic for right_s and left_s
henrylhtsang Jan 16, 2026
08e6518
remove 24 clamp
henrylhtsang Jan 16, 2026
94f0348
doc
henrylhtsang Jan 16, 2026
e94012a
lint
henrylhtsang Jan 16, 2026
2e6ae05
added back clamp to avoid "OverflowError: Python int too large to con…
henrylhtsang Jan 16, 2026
137ad8e
add comment
henrylhtsang Jan 16, 2026
2d6b146
Merge pull request #2185 from henrylhtsang/test_local_r2p
v0i0 Jan 17, 2026
a0f9f41
[Cute][Flex] Fix expanded tensor bug (#2189)
drisspg Jan 17, 2026
04e6ee1
[Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
KareemMusleh Jan 20, 2026
f15ccf5
reduce chance of build oom (#2079)
Qubitium Jan 21, 2026
2cf8a1f
Merge remote-tracking branch 'upstream/main' into sync/upstream-main-…
LucasWilkinson Jan 22, 2026
35756f5
restore changes
LucasWilkinson Jan 22, 2026
997fc13
fix compile error
LucasWilkinson Jan 22, 2026
d3320d4
fix
LucasWilkinson Jan 23, 2026
225 changes: 225 additions & 0 deletions .github/workflows/_build.yml
@@ -0,0 +1,225 @@
name: ~Build wheel template

on:
  workflow_call:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "The C++11 ABI to use for the build"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "The release to upload the wheel to"
        required: false
        type: string

defaults:
  run:
    shell: bash -x -e -u -o pipefail {0}

jobs:
  build-wheel:
    runs-on: ${{ inputs.runs-on }}
    name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          ref: ${{ inputs.release-version }}
          submodules: recursive

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ inputs.python-version }}

      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
          echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
          echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
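The awk one-liners above just split dotted version strings into the pieces later steps need. A minimal Python equivalent (the function name is illustrative, not part of the workflow):

```python
def matrix_vars(cuda: str, torch: str, python: str) -> dict:
    """Reproduce the awk one-liners in the step above:
    '12.9.1' -> MATRIX_CUDA_VERSION '129', WHEEL_CUDA_VERSION '12';
    '2.8.0'  -> MATRIX_TORCH_VERSION '2.8';
    '3.13'   -> MATRIX_PYTHON_VERSION '313'."""
    c, t, p = cuda.split("."), torch.split("."), python.split(".")
    return {
        "MATRIX_CUDA_VERSION": c[0] + c[1],         # major+minor, no dot
        "MATRIX_TORCH_VERSION": t[0] + "." + t[1],  # major.minor
        "WHEEL_CUDA_VERSION": c[0],                 # major only
        "MATRIX_PYTHON_VERSION": p[0] + p[1],       # major+minor, no dot
    }

print(matrix_vars("12.9.1", "2.8.0", "3.13")["MATRIX_CUDA_VERSION"])  # 129
```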

      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        if: runner.os == 'Linux'
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ inputs.cuda-version }}
        if: ${{ inputs.cuda-version != 'cpu' }}
        uses: Jimver/cuda-toolkit@v0.2.29
        id: cuda-toolkit
        with:
          cuda: ${{ inputs.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: "network"
          sub-packages: '["nvcc"]'

      - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}
        run: |
          pip install --upgrade pip
          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get the error
          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
          pip install typing-extensions==4.12.2
          # Figure out which CUDA build of pytorch to download:
          # e.g. the system CUDA version may be 11.7, but for torch==1.12 we need the cu116 wheel.
          # See https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
            maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
          )
          # Detect whether we're on ARM
          if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
            PLAT=linux_aarch64
          else
            PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64
          fi
          echo "PLAT=$PLAT" >> $GITHUB_ENV
          if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
            # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904
            pip install jinja2
            TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl
            TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl
            pip install --no-cache-dir --pre "${TRITON_URL}"
            pip install --no-cache-dir --pre "${TORCH_URL}"
          else
            pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print(cpp_extension.CUDA_HOME)"
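The inline `python -c` above decides which PyTorch wheel index (`cu1xx`) to install from. The same logic as a standalone sketch (function name illustrative; the min/max tables are copied verbatim from the workflow):

```python
def torch_cuda_tag(torch_minor: str, system_cuda: int) -> int:
    """Pick the cuXYZ wheel tag for a torch release, given the system CUDA
    version as an integer like 118 or 129 (MATRIX_CUDA_VERSION above)."""
    # Oldest and newest CUDA builds published for each torch release.
    minv = {"2.5": 118, "2.6": 118, "2.7": 118, "2.8": 126, "2.9": 126}
    maxv = {"2.5": 124, "2.6": 126, "2.7": 128, "2.8": 129, "2.9": 130}
    # CUDA 11.x systems fall back to the oldest build; CUDA 12+ takes the newest.
    return minv[torch_minor] if system_cuda < 120 else maxv[torch_minor]

print(torch_cuda_tag("2.8", 129))  # 129 -> install from .../whl/cu129
```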

      - name: Restore build cache
        uses: actions/cache/restore@v4
        with:
          path: build.tar
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          restore-keys: |
            build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-

      - name: Unpack build cache
        run: |
          echo ::group::Adjust timestamps
          sudo find / -exec touch -t 197001010000 {} + || true
          echo ::endgroup::

          if [ -f build.tar ]; then
            find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +
            tar -xpvf build.tar -C .
          else
            echo "No build.tar found, skipping"
          fi

          ls -al ./
          ls -al build/ || true
          ls -al csrc/ || true

      - name: Build wheel
        id: build_wheel
        run: |
          # We need setuptools >= 49.6.0, otherwise we can't compile the extension when the
          # system CUDA version is 11.7 and the pytorch CUDA version is 11.6:
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails, so we pin a newer version of setuptools.
          pip install setuptools==75.8.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS, otherwise the github runner goes OOM:
          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
          export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2)
          export NVCC_THREADS=2
          export FLASH_ATTENTION_FORCE_BUILD="TRUE"
          export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}

          # 5h timeout, since GH allows at most 6h and we want some buffer
          EXIT_CODE=0
          timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?

          if [ $EXIT_CODE -eq 0 ]; then
            tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
            wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
            ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
            echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          fi

          # Store the exit code in the step output for later steps
          echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT"

          # Propagate the exit code (124 means timeout killed the build)
          exit $EXIT_CODE

      - name: Pack build directory after timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        run: |
          ls -al ./
          tar -cvf build.tar . --atime-preserve=replace

      - name: Save build cache after timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        uses: actions/cache/save@v4
        with:
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          path: build.tar

      - name: Log built wheels
        run: |
          ls dist

      - name: Get release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        with:
          tag_name: ${{ inputs.release-version }}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload release asset
        id: upload_release_asset
        if: inputs.upload-to-release
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.wheel_name }}
          asset_content_type: application/*
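The `sed "s/-/+$tmpname-/2"` in the Build wheel step splices a local build tag into the wheel filename right after the version component. A minimal Python equivalent (helper name illustrative; it assumes the distribution name contains no hyphens, which holds for `flash_attn`):

```python
def tag_wheel_name(name: str, tag: str) -> str:
    """Insert a local build tag after the version component of a wheel
    filename: {dist}-{version}-{python}-{abi}-{platform}.whl."""
    parts = name.split("-")
    parts[1] = parts[1] + "+" + tag  # version -> version+tag
    return "-".join(parts)

print(tag_wheel_name(
    "flash_attn-2.8.3-cp311-cp311-linux_x86_64.whl",
    "cu12torch2.8cxx11abiTRUE",
))
# flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
```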
47 changes: 47 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,47 @@
name: Build wheels

on:
  workflow_dispatch:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
        default: ubuntu-22.04
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "Enable torch flag C++11 ABI (TRUE/FALSE)"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "The release to upload the wheel to"
        required: false
        type: string

jobs:
  build-wheels:
    uses: ./.github/workflows/_build.yml
    with:
      runs-on: ${{ inputs.runs-on }}
      python-version: ${{ inputs.python-version }}
      cuda-version: ${{ inputs.cuda-version }}
      torch-version: ${{ inputs.torch-version }}
      cxx11_abi: ${{ inputs.cxx11_abi }}
      upload-to-release: ${{ inputs.upload-to-release }}
      release-version: ${{ inputs.release-version }}
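This workflow runs only on manual dispatch; one way to trigger it is the GitHub REST endpoint `POST /repos/{owner}/{repo}/actions/workflows/build.yml/dispatches`. A sketch that only constructs the JSON request body (all `workflow_dispatch` inputs travel as strings; the sample values and helper name are illustrative):

```python
import json

def dispatch_body(ref: str, inputs: dict) -> str:
    """Build the JSON body for a workflow_dispatch REST call; the keys of
    `inputs` must match the workflow's declared input names."""
    return json.dumps({"ref": ref, "inputs": inputs})

body = dispatch_body("main", {
    "runs-on": "ubuntu-22.04",
    "python-version": "3.13",
    "cuda-version": "12.9.1",
    "torch-version": "2.8.0",
    "cxx11_abi": "TRUE",
})
print(body)
```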
33 changes: 33 additions & 0 deletions .github/workflows/pre-commit.yaml
@@ -0,0 +1,33 @@
name: Lint

on:
  pull_request:
    paths:
      - 'flash_attn/cute/flash_bwd_sm90.py'
      - 'flash_attn/cute/flash_bwd_preprocess.py'
      - 'flash_attn/cute/flash_bwd_postprocess.py'
      - 'flash_attn/cute/softmax.py'
      - '.pre-commit-config.yaml'
  push:
    branches:
      - main
    paths:
      - 'flash_attn/cute/flash_bwd_sm90.py'
      - 'flash_attn/cute/flash_bwd_preprocess.py'
      - 'flash_attn/cute/flash_bwd_postprocess.py'
      - 'flash_attn/cute/softmax.py'
      - '.pre-commit-config.yaml'

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5

      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: '3.11'

      - name: Run pre-commit
        uses: pre-commit/action@v3.0.1