Commit 14fa9e7

[FA2] flash-attn-mma fully tiling-qkv🎉 (#211)
* Create flash_attn_mma_tiling_qkv_swizzle_q.cu
* Create flash_attn_mma_tiling_qkv_swizzle_qk.cu
* Create flash_attn_mma_tiling_qkv_swizzle_qkv.cu
* Create flash_attn_mma_share_qkv_smooth_qkv.cu
* Update README.md
* Update flash_attn_mma_tiling_qkv.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent 7e22265 commit 14fa9e7

9 files changed: +982 −26 lines

README.md

Lines changed: 28 additions & 15 deletions
@@ -52,19 +52,19 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|
-|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QKV Fine-grained Tiling**|
 |✔️|✔️|✔️|✔️|

-Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) method can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
+Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + QKV Fully Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

-|Algorithm| (B,H,N,D) | 3080 Laptop | L20 | RTX 4090 |
+|Algorithm| (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
 |:---:|:---:|:---:|:---:|:---:|
 |FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
 |split-q+share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
 |FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
 |split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
 |SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
-|split-q+tiling-qk+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**127 TFLOPS**|
+|split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
 |Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

 The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA warps, is slower than the `Split Q` method, which splits only Q across MMA warps while all warps access the full K/V.
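
To make the Split KV / Split Q distinction concrete, here is a toy sketch (hypothetical helper, not code from this repo; assumes a Br x Bc score tile split across `num_warps` warps) of which slice of the tile each warp owns under the two schemes:

```C++
// Toy sketch (hypothetical helper): how a Br x Bc score tile S = Q @ K^T is
// partitioned across warps under the two schemes compared above.
// Split KV: each warp owns a column slice of S, so it needs its own K/V slice and
// the row-wise online-softmax max/sum typically has to be reduced across warps.
// Split Q: each warp owns complete rows of S, all warps read the same K/V tile,
// and no inter-warp reduction is needed for the softmax statistics.
__device__ __forceinline__ void warp_tile_range(
    int warp_id, int num_warps, int Br, int Bc, bool split_q,
    int& row_begin, int& row_end, int& col_begin, int& col_end) {
  if (split_q) {                                   // e.g. Br=64, 4 warps -> 16 rows per warp
    int rows_per_warp = Br / num_warps;
    row_begin = warp_id * rows_per_warp; row_end = row_begin + rows_per_warp;
    col_begin = 0;                       col_end = Bc;   // every warp sweeps all K/V columns
  } else {                                         // Split KV
    int cols_per_warp = Bc / num_warps;
    row_begin = 0;                       row_end = Br;   // every warp touches all Q rows
    col_begin = warp_id * cols_per_warp; col_end = col_begin + cols_per_warp;
  }
}
```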
@@ -123,13 +123,26 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
 <div id="mma-tiling-qk"></div>

 ```C++
-// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
+// Fine-grained tiling at the MMA level for Q@K^T results in a constant SRAM usage of
 // 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
 // an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
-// extend D (head dimension) up to 1024. Stay tuned for updates ~
+// extend D (head dimension) up to 1024.
 __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
 ```
+
+- 📚 Split Q + QKV Fully Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
+
+<div id="mma-tiling-qkv"></div>
+
+```C++
+// Fine-grained tiling at the MMA level for both Q@K^T and P@V results in a constant SRAM usage of
+// Br * 16 or Bc * 16 for Q, K and V, leading to an overall SRAM complexity of O(Br * 16). Consequently,
+// this approach allows us to run faster than SDPA with or without MMA Acc F32.
+__global__ void // Q, K, V, O -> [B, H, N, D]
+flash_attn_mma_stages_split_q_tiling_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
+```
+
 ## ©️Citations🎉🎉

 ```BibTeX
@@ -334,12 +347,19 @@ The kernels listed here will guide you through a step-by-step progression, rangi

 |📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
 |:---|:---|:---|:---|:---|
+| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/basic/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/basic/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ? [flash_attn_mma_stages...tiling_qkv*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...tiling_qkv*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_qkv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...tiling_qk{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...tiling_qkv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma...shared_kv{f32}{rr}*](./kernels/flash-attn/mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma...shared_qkv{f32}{rr}*](./kernels/flash-attn/mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma...shared_kv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma...shared_kv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma...shared_kv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
@@ -348,17 +368,10 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [flash_attn_mma...shared_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma...tiling_qk_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma...tiling_qk_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...tiling_qk_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma...tiling_qk_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma...tiling_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma...tiling_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma...tiling_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...shared_kv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...shared_kv{f32}{rr}*](./kernels/flash-attn/mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...shared_qkv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...shared_qkv{f32}{rr}*](./kernels/flash-attn/mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...tiling_qk{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma...tiling_qkv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32_rr.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️|

 **rr**: reduces register usage (for `d>128`); **f32**: the MMA accumulator is FP32, otherwise FP16. The softmax accumulator dtype is always FP32 for high precision; **swizzle**: currently, smem swizzle is only supported for MMA.
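
The SRAM complexity notes quoted in this diff can be made concrete with a back-of-envelope calculation. The sketch below is a rough model of the complexity classes stated in the kernel comments (assumed tile sizes Br = Bc = 64, kMmaAtomK = 16, half-precision tiles), not the kernels' exact allocations; it shows how the per-block shared-memory footprint of the three schemes scales with head dimension d:

```C++
// Rough smem model (bytes per thread block), assuming half (2-byte) tiles and
// Br = Bc = 64, kMmaAtomK = 16. Illustrative only; the real kernels differ in detail.
#include <cstdio>

constexpr int Br = 64, Bc = 64, kMmaAtomK = 16;
constexpr size_t kHalf = 2;  // sizeof(half)

// FA2-style: full Q/K/V/O tiles resident in smem -> O(4 * Br * d).
constexpr size_t smem_fa2(int d) { return 4ull * Br * d * kHalf; }

// QK fine-grained tiling: 64 * kMmaAtomK slices for Q and K, full V tile -> O(kMmaAtomK * d).
constexpr size_t smem_tiling_qk(int d) {
  return (2ull * 64 * kMmaAtomK + 1ull * kMmaAtomK * d) * kHalf;
}

// QKV fully fine-grained tiling: only Br x 16 (Q) and Bc x 16 (K, V) slices -> O(Br * 16),
// i.e. constant in d.
constexpr size_t smem_tiling_qkv() { return (1ull * Br * 16 + 2ull * Bc * 16) * kHalf; }

int main() {
  const int ds[] = {64, 128, 256, 512, 1024};
  printf("%6s %12s %12s %12s\n", "d", "FA2-style", "tiling-qk", "tiling-qkv");
  for (int d : ds) {
    printf("%6d %10zu B %10zu B %10zu B\n",
           d, smem_fa2(d), smem_tiling_qk(d), smem_tiling_qkv());
  }
  return 0;  // tiling-qkv stays flat as d grows, which is why D can scale to 512 and beyond.
}
```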

kernels/flash-attn/README.md

Lines changed: 28 additions & 10 deletions
@@ -12,7 +12,7 @@
 |**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
 |✔️|✔️|✔️|✔️|

-This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (👇Benchmark)
+This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

 |Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
 |:---:|:---:|:---:|:---:|:---:|
@@ -21,10 +21,10 @@ This repository's implementation of FlashAttention is intended solely for learni
 |FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
 |split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
 |SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
-|split-q+tiling-qk+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**127 TFLOPS**|
+|split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
 |Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

-For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on NVIDIA L20, the [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) method can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION).
+For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + QKV Fully Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~

 ## 📖 Contents

@@ -34,6 +34,7 @@ For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#
 - [📚 Shared KV SMEM](#mma-share-kv)
 - [📚 Fully Shared QKV SMEM](#mma-share-qkv)
 - [📚 QK Fine-grained Tiling](#mma-tiling-qk)
+- [📚 QKV Fully Fine-grained Tiling](#mma-tiling-qkv)
 - [📖 Prerequisites](#prerequisites)
 - [📖 Installation](#install)
 - [📖 Performance](#perf)
@@ -91,6 +92,18 @@ __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
 ```

+- 📚 Split Q + QKV Fully Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
+
+<div id="mma-tiling-qkv"></div>
+
+```C++
+// Fine-grained tiling at the MMA level for both Q@K^T and P@V results in a constant SRAM usage of
+// Br * 16 or Bc * 16 for Q, K and V, leading to an overall SRAM complexity of O(Br * 16). Consequently,
+// this approach allows us to run faster than SDPA with or without MMA Acc F32, e.g., for d >= 512.
+__global__ void // Q, K, V, O -> [B, H, N, D]
+flash_attn_mma_stages_split_q_tiling_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
+```
+
 ## 📖 Prerequisites
 <div id="prerequisites"></div>

@@ -165,12 +178,17 @@ python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D
 ----------------------------------------------------------------------------------------------------------------------------------
 ```

-- Example: B=1, H=48, N=8192, `D=512` (NVIDIA L20), FA2 not supported, `QK Tiling` Faster than SDPA~🎉🎉
+- Example: B=1, H=48, N=16384, `D=512` (NVIDIA L20), FA2 not supported, `QKV Tiling` faster than SDPA~🎉🎉
 ```bash
-python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10
------------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
-mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
-mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
-(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
-----------------------------------------------------------------------------------------------------------------------------------
+---------------------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10----------------------------------------------------
+mma(split-q+tiling-qk+stage1): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:374.5436ms, TFLOPS:70.63 (+0.00%)
+mma(split-q+tiling-qk+stage2): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:320.5431ms, TFLOPS:82.52 (+16.85%)
+mma(split-q+tiling-qk+swizzle-q+stage1): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:370.0427ms, TFLOPS:71.48
+mma(split-q+tiling-qk+swizzle-q+stage2): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:318.7205ms, TFLOPS:83.00 (+0.57%)
+mma(split-q+tiling-qk+swizzle-qk+stage1): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:374.6879ms, TFLOPS:70.60
+mma(split-q+tiling-qk+swizzle-qk+stage2): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:321.8044ms, TFLOPS:82.20
+mma(split-q+tiling-qkv+stage1): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:383.5075ms, TFLOPS:68.97
+mma(split-q+tiling-qkv+stage2): ['-0.00386429 ', '0.00828552 ', '0.01831055 '], time:290.3107ms, TFLOPS:91.12 (+9.79%)
+(sdpa): ['-0.00387764 ', '0.00831604 ', '0.01831055 '], time:452.0751ms, TFLOPS:58.51
+------------------------------------------------------------------------------------------------------------------------------------------------------
 ```
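
To make the fine-grained tiling idea behind these numbers concrete, here is a minimal toy kernel (hypothetical names; it uses plain wmma rather than this repo's raw MMA PTX, and assumes one warp per block, Br = Bc = 16, and d a multiple of 16). Only a 16x16 slice of Q and of K is staged in shared memory per step, so the shared-memory footprint stays constant no matter how large d grows:

```C++
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Toy sketch: one warp computes a single 16x16 tile of S = Q @ K^T by marching
// over the head dimension d in kMmaAtomK = 16 wide chunks. Per step, only a
// 16x16 slice of Q and a 16x16 slice of K live in shared memory, so the smem
// footprint is constant in d -- the core idea of the fine-grained tiling kernels.
// Launch with <<<1, 32>>>; Q and K are row-major [16, d], S is row-major [16, 16].
__global__ void qk_tile_16x16_toy(const half* Q, const half* K, float* S, int d) {
  __shared__ __align__(32) half q_smem[16 * 16];  // Br x kMmaAtomK slice of Q
  __shared__ __align__(32) half k_smem[16 * 16];  // Bc x kMmaAtomK slice of K

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
  // Loading K's row-major slice as col_major yields the K^T operand for the MMA.
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> s_frag;
  wmma::fill_fragment(s_frag, 0.0f);

  for (int k0 = 0; k0 < d; k0 += 16) {            // assumes d % 16 == 0
    // Stage the current 16-wide slices of Q and K (global leading dimension = d).
    for (int i = threadIdx.x; i < 16 * 16; i += blockDim.x) {
      int r = i / 16, c = i % 16;
      q_smem[i] = Q[r * d + k0 + c];
      k_smem[i] = K[r * d + k0 + c];
    }
    __syncthreads();
    wmma::load_matrix_sync(q_frag, q_smem, 16);
    wmma::load_matrix_sync(k_frag, k_smem, 16);
    wmma::mma_sync(s_frag, q_frag, k_frag, s_frag);  // S += Q_slice @ K_slice^T
    __syncthreads();
  }
  wmma::store_matrix_sync(S, s_frag, 16, wmma::mem_row_major);
}
```

The real kernels extend this pattern to P@V, multiple warps (split-Q), multi-stage K/V prefetching, and smem swizzling, but the constant shared-memory argument is the same.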
