
Commit 130fd8d

[core] use kernels to support _flash_3_hub attention backend (huggingface#12236)
* feat: try loading fa3 using kernels when available.
* up
* change to Hub.
* up
* up
* up
* switch env var.
* up
* up
* up
* up
* up
* up
1 parent bcd4d77 commit 130fd8d

File tree

3 files changed: +88 -1 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 64 additions & 1 deletion
@@ -26,6 +26,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -35,7 +36,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS


 _REQUIRED_FLASH_VERSION = "2.6.3"
@@ -67,6 +68,17 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

+if DIFFUSERS_ENABLE_HUB_KERNELS:
+    if not is_kernels_available():
+        raise ImportError(
+            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+        )
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+else:
+    flash_attn_3_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -153,6 +165,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -351,6 +365,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

+    # TODO: add support Hub variant of FA3 varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -657,6 +682,44 @@ def _flash_attention_3(
     return (out, lse) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out = flash_attn_3_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+        return_attn_probs=return_attn_probs,
+    )
+    # When `return_attn_probs` is True, the above returns a tuple of
+    # actual outputs and lse.
+    return (out[0], out[1]) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
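
For orientation, here is a hedged usage sketch of the new backend (not part of this diff). It assumes the `attention_backend` context manager and `dispatch_attention_fn` helper that live elsewhere in `attention_dispatch.py`, plus the flash-style `(batch, seq_len, num_heads, head_dim)` tensor layout; the exact selection API may differ.

```python
import os

# DIFFUSERS_ENABLE_HUB_KERNELS is read once at import time (see constants.py below),
# so it must be set before diffusers is imported.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

# Flash-style layout assumed here: (batch, seq_len, num_heads, head_dim).
q = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)

# Route attention through the Hub-provided FA3 kernel instead of a locally built flash-attn 3.
with attention_backend("_flash_3_hub"):
    out = dispatch_attention_fn(q, k, v)
print(out.shape)
```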

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
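
The new flag is parsed the same way as the other boolean toggles in this file: the raw value is upper-cased and checked against `ENV_VARS_TRUE_VALUES`. A minimal sketch, assuming the usual Hugging Face truthy set for `ENV_VARS_TRUE_VALUES` (its definition is not shown in this diff):

```python
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed definition, mirrors other HF utils


def hub_kernels_enabled() -> bool:
    # Case-insensitive check: "yes", "YES", "1", "true", ... all enable the Hub kernels path.
    return os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES


os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"
assert hub_kernels_enabled()
```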
src/diffusers/utils/kernels_utils.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            # TODO: temporary revision for now. Remove when merged upstream into `main`.
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
+            return flash_attn_3_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
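
`_get_fa3_from_hub` returns the kernel module whose `flash_attn_func` backs the `_flash_3_hub` dispatcher above. For reference, a hedged sketch of fetching and calling the Hub kernel directly with the `kernels` library; the keyword names mirror the wrapper above, and the tuple handling is an assumption about what different kernel revisions return:

```python
import torch
from kernels import get_kernel

# Downloads a prebuilt FA3 kernel matching the local hardware/PyTorch build.
fa3 = get_kernel("kernels-community/flash-attn3")

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)

out = fa3.flash_attn_func(q=q, k=k, v=v, causal=False)
# Some revisions return (out, lse) rather than a single tensor.
out = out[0] if isinstance(out, tuple) else out
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```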
