Skip to content

Commit 93fb1db

Browse files
committed
post :Efficient Attention
FlashAttention Forward
1 parent 33adaa7 commit 93fb1db

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

_posts/DeepLearning/Kernel Fusion/2025-03-07-fused.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,11 @@ FlashAttention의 forward는 이와 같습니다. algorithm의 line별로 설명
247247

248248
10. $$ \tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}, \quad \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} , \tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r} $$ 를 계산합니다. $$ \tilde{m}_{ij} $$ 는 Online softmax에서의 $$ m_i $$ ,Online self-attention 에서의 $$ m^* $$와 동일한 역할을 합니다. 이번에는 $$ \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c} $$ 이므로 $$ \tilde{m}_{ij} \in \mathbb{R}^{B_r} $$가 됩니다. $$ \tilde{\ell}_{ij} $$ 는 각 Block의 exponential의 합을 의미합니다. $$ \tilde{\mathbf{P}}_{ij} \in \mathbb{R}^{B_r \times B_c} $$ 이므로 $$ B_c $$ 차원을 reduction하는 방향으로 max를 계산하므로 $$ \tilde{\ell}_{ij} \in \mathbb{R}^{B_r} $$가 됩니다.$$ \tilde{\ell}_{ij} $$는 Online softmax에서의 $$d_S$$ 와 같은 역할을 합니다.
249249

250-
11.
250+
11. Online self-attention처럼 max,분모를 새로 update해야 합니다. $$ m_i^{new} $$ 는 새로운 max 입니다 . $$ m_i^{\mathrm{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r} $$ 는 Online softmax의 $$m^* = \max(m^*,s_i)$$ 와 동등한 역할을 합니다. $$ \ell_i^{\mathrm{new}} $$ 새로운 분모의 역할을 하게 됩니다. $$ \ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} \, \ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \, \tilde{\ell}_{ij} \in \mathbb{R}^{B_r} $$ 는 Online softmax에서 $$d_S \leftarrow d_{S-1} \, e^{\,m_{S-1} - m_S} + e^{\,x_S - m_S} $$ 와 동등한 역할을 합니다.
251+
252+
12. $$ \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} $$ 는 기존 $$ O $$에 새로운 분모를 곱해주는 과정입니다. $$ \mathrm{diag}(\ell_i) $$ 는 기존의 곱해진 분모를 없애주기 위한 용도입니다. $$ e^{m_i - m_i^{\mathrm{new}}} \, \mathbf{O}_i $$ 는 기존의 $$ O $$는 이전의 max로 safe softmax를 적용했기에 기존의 max인 $$ m_i $$ 를 더해서 없애주고 새로운 max인 $$ m_i^{\mathrm{new}} $$ 를 빼주는 과정입니다. $$ \tilde{\mathbf{P}}_{ij} $$ 는 block의 max인 $$ \tilde{m}_{ij} $$ 를 빼주면서 safe softmax를 적용했습니다. 따라서, 새로운 max를 빼주기 위해서 $$ e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} $$ 를 곱해줍니다. 이를 통해 , $$ O $$를 update 합니다.
253+
254+
13. $$ \ell_i $$ , $$m_i$$ 를 새로운 max,분모로 update를 해줍니다.
251255

252256
### Backward
253257

@@ -260,6 +264,34 @@ $$
260264

261265
$$
262266
\tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}
267+
$$
268+
269+
$$ m_i^{new} $$
270+
271+
$$ \ell $$
272+
273+
274+
$$
275+
\ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} \, \ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \, \tilde{\ell}_{ij} \in \mathbb{R}^{B_r}
276+
$$
277+
278+
263279
$$
280+
\mathbf{O}_i \leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} \left( \mathrm{diag}(\ell_i) \, e^{m_i - m_i^{\mathrm{new}}} \, \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \, \tilde{\mathbf{P}}_{ij} \, \mathbf{V}_j \right)
281+
$$
282+
283+
284+
$$ \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} $$
285+
286+
287+
$$ m_i^{\mathrm{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r} $$
288+
289+
290+
$$ \mathrm{diag}(\ell_i) $$
291+
292+
$$ e^{m_i - m_i^{\mathrm{new}}} \, \mathbf{O}_i $$
293+
294+
$$ e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} $$
264295

296+
$$ \ell_i $$
265297
## References

0 commit comments

Comments
 (0)