Skip to content

Commit 5625c3a

Browse files
authored
feat: native attn support ring-attn (#643)
* feat: native attn support ring-attn * feat: native attn support ring-attn
1 parent 8c331b8 commit 5625c3a

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

src/cache_dit/parallelism/attention/_attention_dispatch.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,6 @@ def _native_attention_forward_op(
164164
_save_ctx: bool = True,
165165
_parallel_config: Optional["ParallelConfig"] = None,
166166
):
167-
# Native attention does not return_lse
168-
if return_lse:
169-
raise ValueError("Native attention does not support return_lse=True")
170-
171167
# used for backward pass
172168
if _save_ctx:
173169
ctx.save_for_backward(query, key, value)
@@ -177,6 +173,24 @@ def _native_attention_forward_op(
177173
ctx.scale = scale
178174
ctx.enable_gqa = enable_gqa
179175

176+
if return_lse:
177+
# Use native flash attention to get lse if return_lse is True
178+
if attn_mask is not None:
179+
raise ValueError(
180+
"`attn_mask` is not yet supported for native flash attention with lse."
181+
)
182+
out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
183+
query.transpose(1, 2),
184+
key.transpose(1, 2),
185+
value.transpose(1, 2),
186+
dropout_p=dropout_p,
187+
is_causal=is_causal,
188+
scale=scale,
189+
)[:2]
190+
out = out.transpose(1, 2)
191+
lse = lse.transpose(1, 2)
192+
return out, lse
193+
180194
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
181195
out = torch.nn.functional.scaled_dot_product_attention(
182196
query=query,

0 commit comments

Comments
 (0)