Skip to content

Commit 7289e65

Browse files
authored
fix flash-attn comments (#354)
* Update README.md
* Update flash_attn_mma_share_kv.cu
* fix flash-attn comments
1 parent 76d2867 commit 7289e65

30 files changed (+118 / -120 lines changed)

README.md

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77
<div align='center'>
88
<img src=https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg >
99
<img src=https://img.shields.io/badge/Language-CUDA-brightgreen.svg >
10-
<img src=https://img.shields.io/github/watchers/xlite-dev/LeetCUDA?color=9cc >
11-
<img src=https://img.shields.io/github/forks/xlite-dev/LeetCUDA.svg?style=social >
12-
<img src=https://img.shields.io/github/stars/xlite-dev/LeetCUDA.svg?style=social >
13-
<img src=https://img.shields.io/badge/Release-v3.0.12-brightgreen.svg >
10+
<img src=https://img.shields.io/github/forks/xlite-dev/LeetCUDA.svg?style=dark >
11+
<img src=https://img.shields.io/github/stars/xlite-dev/LeetCUDA.svg?style=dark >
1412
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
1513
</div>
1614
</div>

kernels/flash-attn/mma/basic/flash_attn_mma_share_kv.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -173,16 +173,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
173173
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
174174
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
175175

176-
// --------------------- Registers/SMEM for thread block
177-
// ------------------------- block m_old, l_old, store in lane, use float to
176+
// Registers/SMEM for thread block
177+
// block m_old, l_old, store in lane, use float to
178178
// keep precision.
179179
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
180180
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
181181
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
182182
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
183183

184-
// ---------------------- Registers for S=Q@K^T/O=P@V
185-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
184+
// Registers for S=Q@K^T/O=P@V
185+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
186186
// and O=P[Br,Bc]@V[Bc,d]=[Br,d]. Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4],
187187
// e.g R_Q[4][1][4] 16 regs. By the way, we have to reduce R_Z to 0 regs and
188188
// reuse R_Q for collective store. Then we can load Q from smem only once and

kernels/flash-attn/mma/basic/flash_attn_mma_share_kv_F32F16F16F32.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -174,16 +174,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
174174
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
175175
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
176176

177-
// --------------------- Registers/SMEM for thread block
178-
// ------------------------- block m_old, l_old, store in lane, use float to
177+
// Registers/SMEM for thread block
178+
// block m_old, l_old, store in lane, use float to
179179
// keep precision.
180180
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
181181
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
182182
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
183183
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
184184

185-
// ---------------------- Registers for S=Q@K^T/O=P@V
186-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
185+
// Registers for S=Q@K^T/O=P@V
186+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
187187
// and O=P[Br,Bc]@V[Bc,d]=[Br,d]. Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4],
188188
// e.g R_Q[4][1][4] 16 regs. By the way, we have to reduce R_Z to 0 regs and
189189
// reuse R_Q for collective store. Then we can load Q from smem only once and

kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,16 +166,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
166166
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
167167
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
168168

169-
// --------------------- Registers/SMEM for thread block
170-
// ------------------------- block m_old, l_old, store in lane, use float to
169+
// Registers/SMEM for thread block
170+
// block m_old, l_old, store in lane, use float to
171171
// keep precision.
172172
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
173173
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
174174
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
175175
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
176176

177-
// ---------------------- Registers for S=Q@K^T/O=P@V
178-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
177+
// Registers for S=Q@K^T/O=P@V
178+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
179179
// and O=P[Br,Bc]@V[Bc,d]=[Br,d]. Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4],
180180
// e.g R_Q[4][1][4] 16 regs. By the way, we have to reduce R_Z to 0 regs and
181181
// reuse R_Q for collective store. Then we can load Q from smem only once and

kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv_F32F16F16F32.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -167,16 +167,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
167167
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
168168
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
169169

170-
// --------------------- Registers/SMEM for thread block
171-
// ------------------------- block m_old, l_old, store in lane, use float to
170+
// Registers/SMEM for thread block
171+
// block m_old, l_old, store in lane, use float to
172172
// keep precision.
173173
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
174174
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
175175
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
176176
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
177177

178-
// ---------------------- Registers for S=Q@K^T/O=P@V
179-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
178+
// Registers for S=Q@K^T/O=P@V
179+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
180180
// and O=P[Br,Bc]@V[Bc,d]=[Br,d]. Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4],
181181
// e.g R_Q[4][1][4] 16 regs. By the way, we have to reduce R_Z to 0 regs and
182182
// reuse R_Q for collective store. Then we can load Q from smem only once and

kernels/flash-attn/mma/basic/flash_attn_mma_split_kv.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -133,8 +133,8 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
133133
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
134134
uint32_t smem_S_base_ptr = __cvta_generic_to_shared(S_tile_smem);
135135

136-
// --------------------- Registers/SMEM for thread block
137-
// ------------------------- block m_old, l_old, store in lane, use float to
136+
// Registers/SMEM for thread block
137+
// block m_old, l_old, store in lane, use float to
138138
// keep precision.
139139
float lane_block_row_max_old[kWarpTileSeqLenQ][2];
140140
float lane_block_row_sum_old[kWarpTileSeqLenQ][2];
@@ -146,8 +146,8 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
146146
__shared__ float block_row_max_new_smem[Br][kMmaTileSeqLenK + 1];
147147
__shared__ float block_row_sum_new_smem[Br][kMmaTileSeqLenK + 1];
148148

149-
// ---------------------- Registers for S=Q@K^T/O=P@V
150-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
149+
// Registers for S=Q@K^T/O=P@V
150+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
151151
// and O=P[Br,Bc]@V[Bc,d]=[Br,d].
152152
uint32_t R_Q[kWarpTileSeqLenQ][4];
153153
uint32_t R_K[kWarpTileSeqLenK][2];

kernels/flash-attn/mma/basic/flash_attn_mma_split_q.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -154,16 +154,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
154154
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
155155
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
156156

157-
// --------------------- Registers/SMEM for thread block
158-
// ------------------------- block m_old, l_old, store in lane, use float to
157+
// Registers/SMEM for thread block
158+
// block m_old, l_old, store in lane, use float to
159159
// keep precision.
160160
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
161161
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
162162
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
163163
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
164164

165-
// ---------------------- Registers for S=Q@K^T/O=P@V
166-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
165+
// Registers for S=Q@K^T/O=P@V
166+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
167167
// and O=P[Br,Bc]@V[Bc,d]=[Br,d].
168168
uint32_t R_Q[kWarpTileSeqLenQ][4]; // [1][4]
169169
uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2]

kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -187,16 +187,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
187187
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
188188
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
189189

190-
// --------------------- Registers/SMEM for thread block
191-
// ------------------------- block m_old, l_old, store in lane, use float to
190+
// Registers/SMEM for thread block
191+
// block m_old, l_old, store in lane, use float to
192192
// keep precision.
193193
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
194194
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
195195
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
196196
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
197197

198-
// ---------------------- Registers for S=Q@K^T/O=P@V
199-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
198+
// Registers for S=Q@K^T/O=P@V
199+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
200200
// and O=P[Br,Bc]@V[Bc,d]=[Br,d].
201201
uint32_t R_Q[kWarpTileSeqLenQ][4]; // [1][4]
202202
uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2]

kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk_F32F16F16F32.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -188,16 +188,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
188188
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
189189
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
190190

191-
// --------------------- Registers/SMEM for thread block
192-
// ------------------------- block m_old, l_old, store in lane, use float to
191+
// Registers/SMEM for thread block
192+
// block m_old, l_old, store in lane, use float to
193193
// keep precision.
194194
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
195195
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
196196
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
197197
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
198198

199-
// ---------------------- Registers for S=Q@K^T/O=P@V
200-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
199+
// Registers for S=Q@K^T/O=P@V
200+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
201201
// and O=P[Br,Bc]@V[Bc,d]=[Br,d].
202202
uint32_t R_Q[kWarpTileSeqLenQ][4]; // [1][4]
203203
uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2]

kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qkv.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -170,16 +170,16 @@ __global__ void __launch_bounds__(WARP_SIZE *kMmaTileSeqLenQ *kMmaTileSeqLenK)
170170
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
171171
uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
172172

173-
// --------------------- Registers/SMEM for thread block
174-
// ------------------------- block m_old, l_old, store in lane, use float to
173+
// Registers/SMEM for thread block
174+
// block m_old, l_old, store in lane, use float to
175175
// keep precision.
176176
float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
177177
float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
178178
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_max_old, -INFINITY);
179179
fill_2D_regs<float, kWarpTileSeqLenQ, 2>(lane_block_row_sum_old, 0.0f);
180180

181-
// ---------------------- Registers for S=Q@K^T/O=P@V
182-
// ---------------------------- registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
181+
// Registers for S=Q@K^T/O=P@V
182+
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc]
183183
// and O=P[Br,Bc]@V[Bc,d]=[Br,d].
184184
uint32_t R_Q[kWarpTileSeqLenQ][4]; // [1][4]
185185
uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2]

0 commit comments

Comments
 (0)