diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 727a27d0f48..dd3040108c8 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -74,6 +74,8 @@ FLASH_ATTN_VERSION = None +from fastdeploy.model_executor.utils import try_import + def init_flash_attn_version(): """ @@ -85,7 +87,29 @@ def init_flash_attn_version(): if sm_version >= 100: try: paddle.enable_compat(scope={"cutlass"}) - from flash_mask.cute.interface import flashmask_attention as fa4 + try: + old_api = try_import(["paddlefleet.ops"]) + if old_api is not None: + from paddlefleet.ops import is_flash_mask_available + + if is_flash_mask_available(): + from paddlefleet.ops.flash_mask.cute.interface import ( + flashmask_attention as fa4, + ) + else: + raise ModuleNotFoundError("flash_mask not available.") + else: + from paddlefleet_ops import is_flash_mask_available + + if is_flash_mask_available(): + from paddlefleet_ops.flash_mask.cute.interface import ( + flashmask_attention as fa4, + ) + else: + raise ModuleNotFoundError("flash_mask not available.") + + except (ImportError, ModuleNotFoundError): + logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") global flashmask_attention_v4 flashmask_attention_v4 = fa4 diff --git a/requirements.txt b/requirements.txt index b2941d4aa00..a2f2811727b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,5 +47,4 @@ aistudio_sdk p2pstore py-cpuinfo flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl -flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 diff --git a/tests/layers/test_flash_attn_func.py b/tests/layers/test_flash_attn_func.py index bb76d8837a8..4bd26f49393 100644 --- 
class TestInitFlashAttnVersion(unittest.TestCase):
    """Tests for the init_flash_attn_version FA4 import branch (sm>=100).

    Each test plants fake ``paddlefleet.ops`` (old API) and/or
    ``paddlefleet_ops`` (new API) module trees in ``sys.modules`` (or blocks
    them), runs ``init_flash_attn_version`` with the platform patched to look
    like a CUDA device reporting sm100, and checks the resulting
    ``FLASH_ATTN_VERSION`` / ``flashmask_attention_v4`` module globals.
    """

    _MODULE_NAMES = (
        "paddlefleet",
        "paddlefleet.ops",
        "paddlefleet.ops.flash_mask",
        "paddlefleet.ops.flash_mask.cute",
        "paddlefleet.ops.flash_mask.cute.interface",
        "paddlefleet_ops",
        "paddlefleet_ops.flash_mask",
        "paddlefleet_ops.flash_mask.cute",
        "paddlefleet_ops.flash_mask.cute.interface",
    )

    # Sentinel for "this name had no entry in sys.modules". A plain None cannot
    # be used because None is itself a meaningful sys.modules value (it blocks
    # the import), and such an entry must be restored, not dropped.
    _ABSENT = object()

    def setUp(self):
        """Snapshot the backend globals and start from a clean module table."""
        self._saved_version = flash_attn_backend.FLASH_ATTN_VERSION
        self._saved_v4 = flash_attn_backend.flashmask_attention_v4
        # Remove every name we may fake, remembering exactly what was there.
        self._saved_modules = {name: sys.modules.pop(name, self._ABSENT) for name in self._MODULE_NAMES}

    def tearDown(self):
        """Restore the backend globals and the original sys.modules entries."""
        flash_attn_backend.FLASH_ATTN_VERSION = self._saved_version
        flash_attn_backend.flashmask_attention_v4 = self._saved_v4
        for name, previous in self._saved_modules.items():
            if previous is self._ABSENT:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous

    def _block_old_api(self):
        """Force `paddlefleet.ops` import to fail regardless of what is installed."""
        # Setting sys.modules[name] = None makes importlib.import_module raise ImportError.
        sys.modules["paddlefleet"] = None
        sys.modules["paddlefleet.ops"] = None

    def _block_new_api(self):
        """Force `paddlefleet_ops` import to fail regardless of what is installed."""
        sys.modules["paddlefleet_ops"] = None

    @staticmethod
    def _install_fake_ops_tree(ops_name: str, is_available: bool, mock_name: str):
        """Register a fake ``<ops_name>.flash_mask.cute.interface`` module tree.

        Args:
            ops_name: Dotted name of the fake ops package (the tree root).
            is_available: Value returned by the fake ``is_flash_mask_available``.
            mock_name: Name given to the ``MagicMock`` standing in for
                ``flashmask_attention``.

        Returns:
            Tuple ``(ops_module, flashmask_attention_mock)``.
        """
        ops = types.ModuleType(ops_name)
        ops.__path__ = []  # a __path__ marks the module as a package
        ops.is_flash_mask_available = lambda: is_available
        tree = {ops_name: ops}

        for suffix in ("flash_mask", "flash_mask.cute"):
            package = types.ModuleType(f"{ops_name}.{suffix}")
            package.__path__ = []
            tree[package.__name__] = package

        interface = types.ModuleType(f"{ops_name}.flash_mask.cute.interface")
        interface.flashmask_attention = mock.MagicMock(name=mock_name)
        tree[interface.__name__] = interface

        sys.modules.update(tree)
        return ops, interface.flashmask_attention

    def _install_fake_paddlefleet_old_api(self, is_available: bool):
        """Inject fake `paddlefleet.ops` (old API) modules.

        Returns the mock installed as ``flashmask_attention``.
        """
        parent = types.ModuleType("paddlefleet")
        parent.__path__ = []
        ops, fa4 = self._install_fake_ops_tree("paddlefleet.ops", is_available, "fa4_old")
        parent.ops = ops
        sys.modules["paddlefleet"] = parent
        return fa4

    def _install_fake_paddlefleet_new_api(self, is_available: bool):
        """Inject fake `paddlefleet_ops` (new API) modules.

        Returns the mock installed as ``flashmask_attention``.
        """
        _, fa4 = self._install_fake_ops_tree("paddlefleet_ops", is_available, "fa4_new")
        return fa4

    def _run_init_on_sm100(self, *, tolerate_name_error: bool = False):
        """Reset the backend globals and run init_flash_attn_version as sm100 CUDA.

        Args:
            tolerate_name_error: When True, a ``NameError`` raised by
                ``init_flash_attn_version`` is swallowed so the caller can
                still assert on the resulting globals.
        """
        flash_attn_backend.FLASH_ATTN_VERSION = None
        flash_attn_backend.flashmask_attention_v4 = None

        with (
            mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
            mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
            mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
        ):
            try:
                flash_attn_backend.init_flash_attn_version()
            except NameError:
                # NOTE(review): the failure-path tests tolerate NameError, which
                # presumably comes from the backend assigning `fa4` after its
                # import fallback failed and left it unbound - confirm and
                # consider fixing the backend instead of tolerating it here.
                if not tolerate_name_error:
                    raise

    def test_fa4_old_api_import_success(self):
        """Old API (`paddlefleet.ops`) is preferred when available."""
        fake_fa4 = self._install_fake_paddlefleet_old_api(is_available=True)
        # Also install new API to verify the old API takes precedence.
        new_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)

        self._run_init_on_sm100()

        self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
        self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)
        self.assertIsNot(flash_attn_backend.flashmask_attention_v4, new_fa4)

    def test_fa4_old_api_flash_mask_unavailable(self):
        """Old API present but `is_flash_mask_available` is False."""
        self._install_fake_paddlefleet_old_api(is_available=False)
        self._block_new_api()

        self._run_init_on_sm100(tolerate_name_error=True)

        self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)

    def test_fa4_new_api_import_success(self):
        """Falls back to new API (`paddlefleet_ops`) when old API is missing."""
        fake_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)
        self._block_old_api()

        self._run_init_on_sm100()

        self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
        self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)

    def test_fa4_new_api_flash_mask_unavailable(self):
        """New API present but `is_flash_mask_available` is False."""
        self._install_fake_paddlefleet_new_api(is_available=False)
        self._block_old_api()

        self._run_init_on_sm100(tolerate_name_error=True)

        self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)

    def test_fa4_paddlefleet_import_error(self):
        """Neither old nor new API is importable."""
        self._block_old_api()
        self._block_new_api()

        self._run_init_on_sm100(tolerate_name_error=True)

        self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)