Commit 1523334

Add ROCm/rocWMMA support for RDNA3 (gfx1151) with AMD Windows setup guide
1 parent bfd980b commit 1523334

16 files changed: +2909 additions, −191 deletions

.gitignore

Lines changed: 79 additions & 8 deletions

```diff
@@ -1,9 +1,80 @@
-__pycache__
-spas_sage_attn.egg-info
-*.pkl
-/dist
-/build
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
 *.so
-.DS_Store
-inst*.cu
-/unit_test
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+*.manifest
+*.spec
+
+# pip
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# IDE
+.idea/
+.vscode/
+.cursor/
+*.swp
+*.swo
+
+# ROCm cloned libraries
+/third_party/
+
+# HIP generated files
+*.hip
+
+# Build artifacts
+*.o
+*.obj
+
+# Instantiation generated files
+csrc/qattn/instantiations_sm80/*.cu
+csrc/qattn/instantiations_sm89/*.cu
+csrc/qattn/instantiations_sm90/*.cu
```
README_AMD_WINDOWS.md

Lines changed: 182 additions & 0 deletions
# SpargeAttn - AMD ROCm on Windows Setup Guide

This guide explains how to build and run SpargeAttn on Windows with AMD GPUs using ROCm.

> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues.

## Supported Hardware

SpargeAttn on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1103, gfx1151).

## Prerequisites

- Windows 10/11
- Python 3.11, 3.12, or 3.13
- Visual Studio 2022 with C++ build tools
- AMD Adrenalin driver (latest recommended)

## Installation

### 1. Install ROCm and PyTorch from TheRock

Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture.

#### Create a Virtual Environment

```powershell
python -m venv venv
.\venv\Scripts\Activate.ps1
```

#### Install ROCm SDK and PyTorch

For **gfx1151** (AMD Strix Halo iGPU):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre rocm-sdk[devel]
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision
```

For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre rocm-sdk[devel]
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision
```

For **gfx120X** (RX 9060, RX 9070):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre rocm-sdk[devel]
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision
```

#### Initialize ROCm SDK

```powershell
rocm-sdk init
```

### 2. Set Environment Variables

Open a PowerShell terminal and run:

```powershell
# Activate Visual Studio environment
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }

# Activate the virtual environment
.\venv\Scripts\Activate.ps1

# Set ROCm paths using rocm-sdk
$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"

# Set compiler and build settings
$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"

# Enable experimental features
$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE"
$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1"
```

### 3. Build and Install SpargeAttn

```powershell
cd <path_to_spargeattn>
pip install --no-build-isolation -v .
```
## Testing

### Quick Correctness Test

Run this script to verify SpargeAttn is working correctly by comparing its output against PyTorch SDPA:

```python
import torch
import torch.nn.functional as F
from spas_sage_attn.core import spas_sage_attn_meansim_cuda

device = torch.device('cuda')

# Create random test tensors (use float16 for ROCm compatibility)
q = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)
k = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)
v = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)

# Compute reference output using PyTorch SDPA
with torch.no_grad():
    sdpa = F.scaled_dot_product_attention(q.float(), k.float(), v.float()).to(torch.float16)

# Compute SpargeAttn output (with 100% sparsity = dense attention)
sparge = spas_sage_attn_meansim_cuda(
    q, k, v,
    is_causal=False,
    smooth_k=False,
    simthreshd1=0.0,  # no similarity threshold (keep all blocks)
    cdfthreshd=1.0,   # 100% sparsity
    pvthreshd=0,
    tensor_layout='HND'
)

# Compare outputs using cosine similarity
cos = F.cosine_similarity(
    sdpa.flatten().float().unsqueeze(0),
    sparge.flatten().float().unsqueeze(0)
)
print(f'Cosine similarity: {cos.item():.6f}')  # should be ~0.9999
```

Save this as `test_spargeattn.py` and run:

```powershell
python test_spargeattn.py
```

Expected output:
```
Cosine similarity: 0.999900
```

A cosine similarity above 0.999 indicates the kernel is working correctly.
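The acceptance check above is just a normalized dot product over the two flattened outputs. For reference, a dependency-free sketch of what `F.cosine_similarity` computes here (the short vectors are made-up stand-ins for the flattened SDPA and SpargeAttn outputs):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two flat vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Identical vectors score exactly 1.0; small elementwise noise
# (fp16-style rounding) keeps the score just below 1.0.
ref = [0.5, -1.25, 2.0, 0.125]
out = [0.5001, -1.2499, 2.0002, 0.1249]
print(f"{cosine_similarity(ref, out):.6f}")
```

A score of 1.0 means the outputs point in exactly the same direction; kernel-level rounding differences between SDPA and SpargeAttn only shave off the last few decimal places, which is why anything above 0.999 is a pass.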
## Performance Notes

At L=4096, D=128, bf16 vs PyTorch SDPA (with aotriton):

| Sparsity | Time       | Speedup vs SDPA |
|----------|------------|-----------------|
| 100%     | 33.0 ms    | 0.18x           |
| 50%      | 13.7 ms    | 0.43x           |
| 25%      | 7.4 ms     | 0.79x           |
| **10%**  | **3.2 ms** | **1.81x**       |
| 5%       | 1.8 ms     | 3.26x           |
| 2%       | 1.0 ms     | 6.07x           |

**Break-even point**: ~20-25% sparsity. Below that, SpargeAttn is faster than dense SDPA.
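The table is internally consistent: multiplying each row's time by its speedup recovers the same dense-SDPA baseline (about 5.9 ms), and interpolating between the 25% and 10% rows puts the 1.0x crossing in the low 20s, matching the stated break-even range. A quick check, with the numbers copied from the table:

```python
# (sparsity %, SpargeAttn time in ms, speedup vs SDPA), from the table above
rows = [(100, 33.0, 0.18), (50, 13.7, 0.43), (25, 7.4, 0.79),
        (10, 3.2, 1.81), (5, 1.8, 3.26), (2, 1.0, 6.07)]

# Each row implies the same dense-SDPA baseline: time * speedup
baselines = [t * s for _, t, s in rows]
print("implied SDPA baseline (ms):", [round(b, 2) for b in baselines])

# Linear interpolation between the 25% and 10% rows for the 1.0x crossing
(s_hi, _, x_hi), (s_lo, _, x_lo) = rows[2], rows[3]
break_even = s_hi - (s_hi - s_lo) * (1.0 - x_hi) / (x_lo - x_hi)
print(f"estimated break-even sparsity: {break_even:.1f}%")  # ~21.9%
```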
## Known Issues

1. **No FP8 support on RDNA3** - rocWMMA on gfx11xx doesn't support FP8, so FP16/BF16 is used for V.

2. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during the first run. These are harmless.

## Troubleshooting

### "LoadLibrary failed" or "cannot find amdhip64.dll"

Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages.

### "LINK : fatal error LNK1104: cannot open file 'python312.lib'"

Ensure the Visual Studio environment is activated before building:
```powershell
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }
```

### "PermissionError" when compiling Triton kernels

This is a known Windows issue with temp file handling. Make sure you're using the latest triton-windows package with the AMD Windows support patches; the PR is currently WIP: https://github.com/woct0rdho/triton-windows/pull/179
csrc/fused/rocm/dispatch_utils.h

Lines changed: 112 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2024 by SageAttention team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <torch/extension.h>
#include <cstdint>
#include <sstream>
#include <stdexcept>

#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)                  \
  if (head_dim == 64) {                                             \
    constexpr int HEAD_DIM = 64;                                    \
    __VA_ARGS__                                                     \
  } else if (head_dim == 128) {                                     \
    constexpr int HEAD_DIM = 128;                                   \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported head dim: " << int(head_dim);           \
    throw std::invalid_argument(err_msg.str());                     \
  }

#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...)                  \
  if (is_causal == 1) {                                             \
    constexpr bool IS_CAUSAL = true;                                \
    __VA_ARGS__                                                     \
  } else if (is_causal == 0) {                                      \
    constexpr bool IS_CAUSAL = false;                               \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported causal mode: " << int(is_causal);       \
    throw std::invalid_argument(err_msg.str());                     \
  }

#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...)   \
  if (qk_quant_gran == 2) {                                         \
    constexpr int QK_QUANT_GRAN = 2;                                \
    __VA_ARGS__                                                     \
  } else if (qk_quant_gran == 3) {                                  \
    constexpr int QK_QUANT_GRAN = 3;                                \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \
    throw std::invalid_argument(err_msg.str());                     \
  }

#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...)            \
  if (return_lse == 1) {                                            \
    constexpr bool RETURN_LSE = true;                               \
    __VA_ARGS__                                                     \
  } else if (return_lse == 0) {                                     \
    constexpr bool RETURN_LSE = false;                              \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported return_lse mode: " << int(return_lse);  \
    throw std::invalid_argument(err_msg.str());                     \
  }

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)             \
  if (pytorch_dtype == at::ScalarType::Half) {                                       \
    using c_type = half;                                                             \
    __VA_ARGS__                                                                      \
  } else if (pytorch_dtype == at::ScalarType::BFloat16) {                            \
    using c_type = hip_bfloat16;                                                     \
    __VA_ARGS__                                                                      \
  } else {                                                                           \
    std::ostringstream oss;                                                          \
    oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
    TORCH_CHECK(false, oss.str());                                                   \
  }

#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...)            \
  if (block_size == 64) {                                           \
    constexpr int BLOCK_SIZE = 64;                                  \
    __VA_ARGS__                                                     \
  } else if (block_size == 128) {                                   \
    constexpr int BLOCK_SIZE = 128;                                 \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported block_size " << int(block_size);        \
    throw std::invalid_argument(err_msg.str());                     \
  }

#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...)  \
  if (warp_block_size == 16) {                                           \
    constexpr int WARP_BLOCK_SIZE = 16;                                  \
    __VA_ARGS__                                                          \
  } else if (warp_block_size == 32) {                                    \
    constexpr int WARP_BLOCK_SIZE = 32;                                  \
    __VA_ARGS__                                                          \
  } else {                                                               \
    std::ostringstream err_msg;                                          \
    err_msg << "Unsupported warp_block_size " << int(warp_block_size);   \
    throw std::invalid_argument(err_msg.str());                          \
  }
```
