
Commit ab2834e

[HGEMM] Update toy-hgemm library 0.1.0 (#149)
* Update README.md (×2)
* Update hgemm_cublas.cu
* Update hgemm_mma_stage_tn_cute.cu
* Update hgemm.py
* Update makefile
* Update hgemm.cc
* Update README.md (×12)
* Update hgemm.py
* Update setup.py
1 parent 7d01ce1 commit ab2834e

8 files changed: 80 additions, 40 deletions

README.md (6 additions, 6 deletions)

```diff
@@ -21,12 +21,12 @@
 <div id="hgemm-sgemm"></div>
 
 <div align='left'>
- <img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="150px" width="265px">
- <img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="150px" width="265px">
- <img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="150px" width="265px">
+ <img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="150px" width="267px">
+ <img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="150px" width="267px">
+ <img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="150px" width="267px">
 </div>
 
-Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `95%~99%` of its (`orange`🟠) performance. Please check [toy-hgemm library🔥🔥](./kernels/hgemm) for more details.
+Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `99%~100%+` of its (`orange`🟠) performance. Please check [toy-hgemm library🔥🔥](./kernels/hgemm) for more details.
 
 |CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
 |:---:|:---:|:---:|:---:|
@@ -35,9 +35,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |✔️|✔️|✔️|✔️|
 |Copy Async|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages|
 |✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|Collective Store(Warp Shuffle)|
+|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle(CuTe)|
 |✔️|✔️|✔️|✔️|
-|Row Major(NN)|Col Major(TN)|SGEMM TF32|SMEM Swizzle(CuTe)|
+|Collective Store(Warp Shfl)|Row Major(NN)|Col Major(TN)|SGEMM F32/TF32|
 |✔️|✔️|✔️|✔️|
 
 ## ©️Citations🎉🎉
```
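
The tables above refer to the WMMA (m16n16k16) and MMA (m16n8k16) Tensor Core paths used by the HGEMM kernels. For orientation, a minimal single-warp WMMA m16n16k16 tile multiply is sketched below; it is illustrative only and not one of the repo's kernels (the kernel name, half accumulator, and tile layout are assumptions).

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Illustrative single-warp WMMA m16n16k16 tile: C = A * B (not a repo kernel).
// Requires a Tensor Core GPU and -arch >= sm_70 when compiling.
__global__ void wmma_m16n16k16_demo(const half* A, const half* B, half* C) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;

  wmma::fill_fragment(c_frag, __float2half(0.0f));
  wmma::load_matrix_sync(a_frag, A, 16);   // leading dimension = 16
  wmma::load_matrix_sync(b_frag, B, 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}

int main() {
  half *dA, *dB, *dC;
  cudaMalloc(&dA, 16 * 16 * sizeof(half));
  cudaMalloc(&dB, 16 * 16 * sizeof(half));
  cudaMalloc(&dC, 16 * 16 * sizeof(half));
  // Inputs are left uninitialized; this only exercises the API.
  wmma_m16n16k16_demo<<<1, 32>>>(dA, dB, dC);  // one warp
  cudaDeviceSynchronize();
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}
```

The real kernels tile many such fragments per warp and per block (Tile MMA / Tile Warp in the table) and overlap global-to-shared copies with compute via the multi-stage pipelines listed above.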

kernels/hgemm/README.md (38 additions, 10 deletions)

````diff
@@ -1,20 +1,21 @@
-## 🔥🔥Toy-HGEMM Library: Achieve the performance of cuBLAS
+# 🔥🔥Toy-HGEMM Library: Achieve the performance of cuBLAS
 
-|CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
+|CUDA Cores|Sliced K(Loop over K)|Tile Block(BMxBN)|Tile Thread(t 8x8)|
 |:---:|:---:|:---:|:---:|
 |✔️|✔️|✔️|✔️|
 |WMMA(m16n16k16)|MMA(m16n8k16)|Pack LDST(128 bits)|SMEM Padding|
 |✔️|✔️|✔️|✔️|
-|Copy Async|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages|
+|Copy Async(cp.async.cg/ca)|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages(2/3/4/5)|
 |✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|Collective Store(Warp Shfl)|
+|Register Double Buffers|Block Swizzle(Zigzag N)|Warp Swizzle(Zigzag N)|SMEM Swizzle(CUTLASS/CuTe)|
 |✔️|✔️|✔️|✔️|
-|Row Major(NN)|Col Major(TN)|SGEMM TF32|SMEM Swizzle(CuTe)|
+|Collective Store(Warp Shuffle & Reg Reuse)|Row Major(NN)|Col Major(TN)|SGEMM FP32/TF32|
 |✔️|✔️|✔️|✔️|
 
-
 ## 📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉
 
+<div id="kernels"></div>
+
 ```C++
 void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
@@ -49,10 +50,23 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Te
 void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
-void hgemm_mma_stages_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 ```
 
-## 📖 Installation
+## 📖 Contents
+
+- [📖 Installation](#install)
+- [📖 Testing](#test)
+- [📖 NVIDIA L20 Benchmarks](#perf-l20)
+- [📖 NVIDIA RTX 4090 Benchmarks](#perf-4090)
+- [📖 NVIDIA RTX 3080 Laptop Benchmarks](#perf-3080)
+- [📖 Performance Optimization Notes](#opt-docs)
+- [📖 References](#ref)
+
+## 📖 Installation
+
+<div id="install"></div>
+
 The HGEMM implemented in this repo can be installed and used as a Python library (optional)
 ```bash
 git submodule update --init --recursive --force # update cutlass, required
@@ -61,6 +75,8 @@ python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip un
 
 ## 📖 Testing
 
+<div id="test"></div>
+
 **CUTLASS**: update the CUTLASS dependency
 ```bash
 git submodule update --init --recursive --force
@@ -125,9 +141,11 @@ M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, A
 
 ## 📖 Current Performance
 
+<div id="perf-l20"></div>
+
 ### NVIDIA L20
 
-The best implementation so far reaches roughly 99% of cuBLAS performance overall on the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS). The WMMA API reaches about 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); the MMA API reaches 115 TFLOPS and exceeds cuBLAS in some cases. The CuTe version of HGEMM is roughly on par with cuBLAS, exceeding it in some cases, at 116-117 TFLOPS. Bank conflicts are currently handled via SMEM Padding and SMEM Swizzle: for the NN layout, SMEM Padding mitigates bank conflicts; for the TN layout, cutlass/cute SMEM Swizzle eliminates them.
+The best implementation so far reaches roughly `99%~100%+` of cuBLAS performance overall on the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS). The WMMA API reaches about `95%~98%` of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); the MMA API reaches 115 TFLOPS and exceeds cuBLAS in some cases. The CuTe version of HGEMM implements Block Swizzle (L2 cache friendly) and SMEM Swizzle (bank-conflict free) and performs best: it reaches 116-117 TFLOPS on large matrix multiplies, roughly `98%~100%+` of cuBLAS, and beats cuBLAS in many cases. Bank conflicts are currently handled via SMEM Padding and SMEM Swizzle: for the NN layout, SMEM Padding mitigates bank conflicts; for the TN layout, CUTLASS/CuTe SMEM Swizzle eliminates them.
 
 <div id="NV-L20"></div>
 
@@ -148,6 +166,9 @@ python3 hgemm.py --cute-tn --mma --plot
 ```
 
 ### NVIDIA GeForce RTX 4090
+
+<div id="perf-4090"></div>
+
 On the NVIDIA RTX 4090 (FP16 Tensor Cores throughput: 330 TFLOPS), WMMA (m16n16k16) performs better than MMA (m16n8k16). For most MNK shapes, this repo's implementations reach 95%~99% of cuBLAS performance and exceed cuBLAS in some cases. Within this repo, on the RTX 4090, WMMA performs better for large matrix multiplies (MNK >= 8192) and MMA performs better for small ones.
 
 <!---
@@ -164,6 +185,8 @@ python3 hgemm.py --cute-tn --mma --wmma-all --plot
 
 ### NVIDIA GeForce RTX 3080 Laptop
 
+<div id="perf-3080"></div>
+
 Tested on an NVIDIA GeForce RTX 3080 Laptop with mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle, most cases match or even exceed cuBLAS; testing was done under Windows WSL2 + RTX 3080 Laptop.
 
 <!--
@@ -179,6 +202,9 @@ python3 hgemm.py --wmma-all --plot
 
 ## 📖 Performance Optimization Notes
 
+<div id="opt-docs"></div>
+
+
 ### PyTorch HGEMM Profile
 
 On the Ada architecture, when PyTorch 2.4 uses matmul for FP16 it calls:
@@ -282,7 +308,9 @@ TODO
 
 </details>
 
-## References
+## 📖 References
+
+<div id="ref"></div>
 
 - [CUDA Programming Concepts (1): What is a bank conflict?](https://zhuanlan.zhihu.com/p/659142274)
 - [Resolving bank conflicts](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md)
````
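
The L20 notes in the diff above distinguish SMEM Padding, which mitigates bank conflicts in the NN-layout kernels, from CUTLASS/CuTe SMEM Swizzle, which eliminates them in the TN-layout CuTe kernel. The standalone sketch below illustrates only the padding idea and is not taken from the repo's kernels; the tile sizes, pad width, and access pattern are assumptions chosen to make the effect visible.

```cuda
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

// Illustration of SMEM padding only; not one of the repo's kernels.
// PAD = 8 halfs (16 bytes) widens each shared-memory row so that a warp
// walking down a column no longer funnels into the same few banks.
constexpr int BM = 32, BK = 32, PAD = 8;

__global__ void smem_padding_demo(const half* __restrict__ in,
                                  half* __restrict__ out) {
  __shared__ half s_tile[BM][BK + PAD];      // "+ PAD" is the whole trick

  int row = threadIdx.y;                     // 0..BM-1
  int col = threadIdx.x;                     // 0..BK-1
  s_tile[row][col] = in[row * BK + col];     // coalesced, conflict-free store
  __syncthreads();

  // Transposed read: threadIdx.x varies within a warp, so this walks a column
  // of s_tile. With no pad that is a 16-way bank conflict for half elements;
  // PAD = 8 cuts it to 4-way (a pad of 2 would remove it entirely here).
  out[row * BK + col] = s_tile[col][row];
}

int main() {
  half *d_in, *d_out;
  cudaMalloc(&d_in,  BM * BK * sizeof(half));
  cudaMalloc(&d_out, BM * BK * sizeof(half));
  smem_padding_demo<<<1, dim3(BK, BM)>>>(d_in, d_out);
  cudaDeviceSynchronize();
  printf("smem_padding_demo: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_in); cudaFree(d_out);
  return 0;
}
```

Padding trades a little shared memory for fewer conflicts; the CuTe path instead permutes the column index inside each row with an XOR pattern (cute::Swizzle), which removes the conflicts without wasting shared memory.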

kernels/hgemm/cublas/hgemm_cublas.cu (8 additions, 0 deletions)

```diff
@@ -43,6 +43,10 @@ void cublas_tensor_op_nn(half *A, half *B, half *C, size_t M, size_t N, size_t
   static half alpha = 1.0;
   static half beta = 0.0;
 
+  if (g_handle == nullptr) {
+    init_cublas_handle();
+  }
+
   cublasGemmEx(g_handle,
                CUBLAS_OP_N,
                CUBLAS_OP_N,
@@ -62,6 +66,10 @@ void cublas_tensor_op_tn(half *A, half *B, half *C, size_t M, size_t N, size_t
   static half alpha = 1.0;
   static half beta = 0.0;
 
+  if (g_handle == nullptr) {
+    init_cublas_handle();
+  }
+
   cublasGemmEx(g_handle,
                CUBLAS_OP_T,
                CUBLAS_OP_N,
```
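
The added guards make the cuBLAS wrappers lazily create the global handle if a caller has not initialised it yet (the benchmark script now manages the handle per run, see hgemm.py below). The bodies of `init_cublas_handle()` / `destroy_cublas_handle()` are not part of this diff; a plausible sketch of the pattern they imply, plus a row-major NN `cublasGemmEx` call in the same style, is shown below. The handle configuration, math mode, and exact GEMM argument order are assumptions, not the repo's code.

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cstdio>

static cublasHandle_t g_handle = nullptr;   // shared by all cuBLAS wrappers

// Create the global handle once; safe to call repeatedly.
void init_cublas_handle() {
  if (g_handle == nullptr) {
    if (cublasCreate(&g_handle) != CUBLAS_STATUS_SUCCESS) {
      fprintf(stderr, "cublasCreate failed\n");
      g_handle = nullptr;
      return;
    }
    // Allow Tensor Core math paths (assumption; the repo may configure this differently).
    cublasSetMathMode(g_handle, CUBLAS_TENSOR_OP_MATH);
  }
}

// Release the handle once benchmarking is finished.
void destroy_cublas_handle() {
  if (g_handle != nullptr) {
    cublasDestroy(g_handle);
    g_handle = nullptr;
  }
}

// Row-major C[M,N] = A[M,K] * B[K,N] in FP16 through the column-major cuBLAS
// API: one common trick is to compute C^T = B^T * A^T by passing B first.
// Whether the repo's wrapper uses exactly this argument order is not visible
// in the hunk above.
void cublas_hgemm_nn_sketch(const half* A, const half* B, half* C,
                            int M, int N, int K) {
  if (g_handle == nullptr) init_cublas_handle();
  static half alpha = 1.0, beta = 0.0;
  cublasGemmEx(g_handle, CUBLAS_OP_N, CUBLAS_OP_N,
               N, M, K, &alpha,
               B, CUDA_R_16F, N,
               A, CUDA_R_16F, K,
               &beta,
               C, CUDA_R_16F, N,
               CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```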

kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu (2 additions, 2 deletions)

```diff
@@ -461,8 +461,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
   );
 
 
-// Multi stages CuTe HGEMM with smem and block swizzle.
-void hgemm_mma_stages_tn_cute(
+// Multi stages CuTe HGEMM with SMEM Swizzle and Block Swizzle.
+void hgemm_mma_stages_block_swizzle_tn_cute(
   torch::Tensor a, torch::Tensor b, torch::Tensor c,
   int stages, bool swizzle, int swizzle_stride) {
   CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
```
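
The rename spells out that this CuTe kernel combines SMEM Swizzle with threadblock (block) swizzle for L2 locality. The kernel's actual index mapping is not visible in this hunk; the host-side sketch below shows one generic strip-mined block swizzle of the kind the `swizzle_stride` parameter suggests (function and variable names are invented for illustration). The idea is that blocks sweep every row tile of a narrow strip of N before moving to the next strip, so that strip's B tiles stay hot in L2.

```cuda
#include <cstdio>

// Generic strip-mined threadblock swizzle (illustration only; the repo's CuTe
// kernel may use a different mapping). grid_n is assumed to be divisible by
// strip_n in this sketch.
void swizzled_tile_id(int linear_bid, int grid_m, int strip_n,
                      int* tile_m, int* tile_n) {
  int tiles_per_strip = strip_n * grid_m;        // all tiles in one strip of N
  int strip    = linear_bid / tiles_per_strip;   // which strip of N
  int in_strip = linear_bid % tiles_per_strip;
  *tile_m = in_strip % grid_m;                   // walk down M first...
  *tile_n = strip * strip_n + in_strip / grid_m; // ...then across the strip
}

int main() {
  const int grid_m = 4, grid_n = 8, strip_n = 2; // 4x8 tiles, strips 2 tiles wide
  for (int bid = 0; bid < grid_m * grid_n; ++bid) {
    int tm, tn;
    swizzled_tile_id(bid, grid_m, strip_n, &tm, &tn);
    printf("block %2d -> tile (m=%d, n=%d)\n", bid, tm, tn);
  }
  return 0;
}
```

In the benchmark script, `swizzle_stride` is expressed in columns of N (see `make_block_swizzle_stride` in hgemm.py below), so a strip width in tiles would presumably be swizzle_stride divided by the block tile width BN.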

kernels/hgemm/hgemm.py (18 additions, 17 deletions)

```diff
@@ -58,8 +58,7 @@ def get_args():
 pretty_print_line()
 
 
-hgemm = try_load_hgemm_library(force_build=args.force_build,
-                               verbose=args.verbose)
+hgemm = try_load_hgemm_library(force_build=args.force_build, verbose=args.verbose)
 
 MAX_TFLOPS = -1
 STATIS_INFO: dict[str, list[float]] = {}
@@ -69,14 +68,12 @@ def get_args():
 CUBLAS_TN_TOTAL_TFLOPS = 0
 
 
-def make_block_swizzle_stride(N: int, K: int):
+def make_block_swizzle_stride(N: int, K: int, swizzle_factor: float = None):
     # make swizzle stride as N/8,N/4,N/2 and multiples of 256
-    if args.swizzle_factor is None:
+    if swizzle_factor is None:
         swizzle_factor = 0.5 if N <= 4096 else 0.25
         if all((N >= 14848, K > 8192, N % 8 == 0)):
             swizzle_factor = 0.125
-    else:
-        swizzle_factor = args.swizzle_factor
 
     swizzle_stride = int(N * swizzle_factor)
     swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1
@@ -100,7 +97,7 @@ def run_benchmark(perf_func: callable,
     K = a.size(1)
     N = b.size(1) # TN still has shape [K,N]
     if swizzle:
-        swizzle_stride = make_block_swizzle_stride(N, K)
+        swizzle_stride = make_block_swizzle_stride(N, K, args.swizzle_factor)
         swizzle = swizzle if swizzle_stride >= 256 else False
     else:
         swizzle_stride = 1 # means no thread block swizzle
@@ -110,6 +107,10 @@ def run_benchmark(perf_func: callable,
 
     if out is not None:
         out.fill_(0)
+
+    if "cublas" in tag:
+        hgemm.init_cublas_handle()
+
     if out is not None:
         for i in range(warmup):
             if stages > 1:
@@ -177,6 +178,9 @@ def run_benchmark(perf_func: callable,
         CUBLAS_TOTAL_TFLOPS += TFLOPS
 
     torch.cuda.synchronize()
+    if "cublas" in tag:
+        hgemm.destroy_cublas_handle()
+
     del out_flat
     out_flat = None
     gc.collect()
@@ -262,6 +266,7 @@ def skip_it(tag: str) -> bool:
         save_path = f"{args.save_dir}/{device_name}_{args.save_tag}.png"
     else:
         save_path = f"{args.save_dir}/{device_name}.png"
+    os.makedirs(args.save_dir, exist_ok=True)
     plt.savefig(save_path, dpi=300)
     pretty_print_line(f"plot hgemm TFLOPS done, saved as {save_path}")
 
@@ -383,24 +388,20 @@ def get_mnk(sep: int = args.SEP):
     run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage3+dsmem+swizzle<block>)", c, stages=3, swizzle=True)
     run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage2+dsmem+swizzle<block>)", c, stages=2, swizzle=True)
     if args.enable_cute_tn:
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem>)", c, stages=4)
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem>)", c, stages=3)
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem>)", c, stages=2)
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem+block>)", c, stages=4, swizzle=True)
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem+block>)", c, stages=3, swizzle=True)
-        run_benchmark(hgemm.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem+block>)", c, stages=2, swizzle=True)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem>)", c, stages=4)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem>)", c, stages=3)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem>)", c, stages=2)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem+block>)", c, stages=4, swizzle=True)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem+block>)", c, stages=3, swizzle=True)
+        run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem+block>)", c, stages=2, swizzle=True)
     # TN layout: cublas
     if not args.disable_cublas_tn and any((args.enable_mma_tn, args.enable_cute_tn)):
-        hgemm.init_cublas_handle()
         run_benchmark(hgemm.hgemm_cublas_tensor_op_tn, a, b_col_major, "tn(cublas)", c)
-        hgemm.destroy_cublas_handle()
     # NN layout: cublas/torch
     if (not args.disable_cublas) and any((
         args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
         args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
-        hgemm.init_cublas_handle()
         run_benchmark(hgemm.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
-        hgemm.destroy_cublas_handle()
     if args.enable_torch:
         run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
     torch.cuda.synchronize()
```
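
The refactor above turns make_block_swizzle_stride into a pure function of (N, K, swizzle_factor) instead of reading args inside it, and moves the cuBLAS handle init/destroy into run_benchmark so every cuBLAS-tagged run gets a fresh handle. For readers skimming the hunks, the stride heuristic condenses to the C++ restatement below (a sketch mirroring the Python, not code from the repo).

```cuda
#include <cstdio>

// C++ restatement of make_block_swizzle_stride from hgemm.py (sketch only).
// A negative swizzle_factor stands in for Python's None.
int make_block_swizzle_stride(int N, int K, double swizzle_factor = -1.0) {
  if (swizzle_factor < 0.0) {
    swizzle_factor = (N <= 4096) ? 0.5 : 0.25;   // stride ~ N/2 or N/4 by default
    if (N >= 14848 && K > 8192 && N % 8 == 0)
      swizzle_factor = 0.125;                    // ~ N/8 for very large shapes
  }
  int stride = static_cast<int>(N * swizzle_factor);
  return (stride >= 256) ? stride : 1;           // 1 means "no block swizzle"
}

int main() {
  const int Ns[] = {2048, 4096, 8192, 16384};
  for (int N : Ns) {
    int K = N;
    printf("N=%5d K=%5d -> swizzle_stride=%d\n", N, K,
           make_block_swizzle_stride(N, K));
  }
  return 0;
}
```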

kernels/hgemm/makefile (4 additions, 3 deletions)

```diff
@@ -1,6 +1,7 @@
 INCLUDE_DIRS=-I ./utils -I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include
+ARCHS=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_89,code=sm_89
 default:
-	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin -O2 -arch=sm_89 -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
-	nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin -O2 -arch=sm_89 -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
-	nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.bin -O2 -arch=sm_89 -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin -O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+	nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin -O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+	nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.bin -O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
 
```

kernels/hgemm/pybind/hgemm.cc (2 additions, 2 deletions)

```diff
@@ -48,7 +48,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch:
 // from hgemm_mma_stage_tn.cu
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 // from hgemm_mma_stage_tn_cute.cu
-void hgemm_mma_stages_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -96,6 +96,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // TN: A row major MxK, B col major NxK, C row major MxN
   TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn)
   // TN: cute hgemm with smem & block swizzle
-  TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_stages_tn_cute)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_stages_block_swizzle_tn_cute)
 }
 
```

kernels/hgemm/setup.py (2 additions, 0 deletions)

```diff
@@ -17,6 +17,8 @@
 generator_flag = []
 cc_flag = []
 cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+cc_flag.append("-gencode")
 cc_flag.append("arch=compute_89,code=sm_89")
 
 
```
