Skip to content

Commit 33adaa7

Browse files
committed
post: Efficient Attention
FlashAttentio Forward
1 parent cae6316 commit 33adaa7

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

_posts/DeepLearning/Kernel Fusion/2025-03-07-fused.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,23 @@ FlashAttention의 forward는 이와 같습니다. algorithm의 line별로 설명
243243

244244
8. $$ Query,O, \ell , m $$의 block을 Cache에 load합니다.
245245

246+
9. $$Query,Key $$를 cache에서 register로 바로 load하여 $$ \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c} $$ 를 계산합니다.
247+
248+
10. $$ \tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}, \quad \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} , \tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r} $$ 를 계산합니다. $$ \tilde{m}_{ij} $$ 는 Online softmax에서의 $$ m_i $$ ,Online self-attention 에서의 $$ m^* $$와 동일한 역할을 합니다. 이번에는 $$ \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c} $$ 이므로 $$ \tilde{m}_{ij} \in \mathbb{R}^{B_r} $$가 됩니다. $$ \tilde{\ell}_{ij} $$ 는 각 Block의 exponential의 합을 의미합니다. $$ \tilde{\mathbf{P}}_{ij} \in \mathbb{R}^{B_r \times B_c} $$ 이므로 $$ B_c $$ 차원을 reduction하는 방향으로 max를 계산하므로 $$ \tilde{\ell}_{ij} \in \mathbb{R}^{B_r} $$가 됩니다.$$ \tilde{\ell}_{ij} $$는 Online softmax에서의 $$d_S$$ 와 같은 역할을 합니다.
249+
250+
11.
251+
246252
### Backward
247253

248254

249255
## Conclusion
256+
$$
257+
\tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}, \quad \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} , \tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}
258+
$$
259+
250260

261+
$$
262+
\tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}
263+
$$
251264

252-
$$ \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c} $$
253265
## References

0 commit comments

Comments
 (0)