From e222e368a98261e3be1dd802f82d423d617e4027 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 19 May 2026 22:39:01 +0800 Subject: [PATCH 1/5] remove flash-mask and using fleet --- .../layers/attention/flash_attn_backend.py | 13 ++++++++++++- requirements.txt | 1 - 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 727a27d0f48..ea94b05d438 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -85,7 +85,18 @@ def init_flash_attn_version(): if sm_version >= 100: try: paddle.enable_compat(scope={"cutlass"}) - from flash_mask.cute.interface import flashmask_attention as fa4 + try: + from paddlefleet.ops import is_flash_mask_available + + if is_flash_mask_available(): + from paddlefleet.ops.flash_mask.cute.interface import ( + flashmask_attention as fa4, + ) + else: + raise ModuleNotFoundError("flash_mask not available.") + + except (ImportError, ModuleNotFoundError): + logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") global flashmask_attention_v4 flashmask_attention_v4 = fa4 diff --git a/requirements.txt b/requirements.txt index b2941d4aa00..a2f2811727b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,5 +47,4 @@ aistudio_sdk p2pstore py-cpuinfo flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl -flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 From 7d4bb25f09e072124767f2df841fabee45fc69e1 Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 20 May 2026 10:29:01 +0800 Subject: [PATCH 2/5] add test --- tests/layers/test_flash_attn_func.py | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests/layers/test_flash_attn_func.py b/tests/layers/test_flash_attn_func.py index bb76d8837a8..3d8fb75605c 100644 --- a/tests/layers/test_flash_attn_func.py +++ b/tests/layers/test_flash_attn_func.py @@ -14,10 +14,14 @@ from __future__ import annotations +import sys +import types import unittest +from unittest import mock import paddle +from fastdeploy.model_executor.layers.attention import flash_attn_backend from fastdeploy.model_executor.layers.attention.flash_attn_backend import ( flash_attn_func, ) @@ -205,5 +209,118 @@ def test_fa4(self): ) +class TestInitFlashAttnVersion(unittest.TestCase): + """Tests for the init_flash_attn_version FA4 import branch (sm>=100).""" + + def setUp(self): + # Save state to restore after each test. + self._saved_version = flash_attn_backend.FLASH_ATTN_VERSION + self._saved_v4 = flash_attn_backend.flashmask_attention_v4 + self._saved_modules = { + name: sys.modules.get(name) + for name in ( + "paddlefleet", + "paddlefleet.ops", + "paddlefleet.ops.flash_mask", + "paddlefleet.ops.flash_mask.cute", + "paddlefleet.ops.flash_mask.cute.interface", + ) + } + + def tearDown(self): + flash_attn_backend.FLASH_ATTN_VERSION = self._saved_version + flash_attn_backend.flashmask_attention_v4 = self._saved_v4 + for name, mod in self._saved_modules.items(): + if mod is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = mod + + def _install_fake_paddlefleet(self, is_available: bool): + """Inject fake paddlefleet modules so the inner imports succeed.""" + pkg = types.ModuleType("paddlefleet") + pkg.__path__ = [] + ops = types.ModuleType("paddlefleet.ops") + ops.__path__ = [] + ops.is_flash_mask_available = lambda: is_available + flash_mask = types.ModuleType("paddlefleet.ops.flash_mask") + flash_mask.__path__ = [] + cute = types.ModuleType("paddlefleet.ops.flash_mask.cute") + cute.__path__ = [] + interface = types.ModuleType("paddlefleet.ops.flash_mask.cute.interface") + interface.flashmask_attention = mock.MagicMock(name="fa4") + + sys.modules["paddlefleet"] = pkg + sys.modules["paddlefleet.ops"] = ops + sys.modules["paddlefleet.ops.flash_mask"] = flash_mask + sys.modules["paddlefleet.ops.flash_mask.cute"] = cute + sys.modules["paddlefleet.ops.flash_mask.cute.interface"] = interface + return interface.flashmask_attention + + def test_fa4_import_success(self): + """Covers lines 88, 89, 91, 92 (is_flash_mask_available True branch).""" + fake_fa4 = self._install_fake_paddlefleet(is_available=True) + flash_attn_backend.FLASH_ATTN_VERSION = None + flash_attn_backend.flashmask_attention_v4 = None + + with ( + mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True), + mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), + mock.patch.object(paddle, "enable_compat", create=True, return_value=None), + ): + flash_attn_backend.init_flash_attn_version() + + self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) + self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4) + + def test_fa4_flash_mask_unavailable(self): + """Covers lines 88, 89, 91, 96, 98, 99 (raise + except path).""" + self._install_fake_paddlefleet(is_available=False) + flash_attn_backend.FLASH_ATTN_VERSION = None + flash_attn_backend.flashmask_attention_v4 = None + + with ( + mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True), + mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), + mock.patch.object(paddle, "enable_compat", create=True, return_value=None), + ): + # The inner except swallows ModuleNotFoundError, but `fa4` is then + # unbound, so the outer block raises NameError (not ImportError), + # which propagates. Verify the inner except actually executed by + # checking that FA4 was not selected. + try: + flash_attn_backend.init_flash_attn_version() + except NameError: + pass + + self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) + + def test_fa4_paddlefleet_import_error(self): + """Covers lines 88, 89, 98, 99 (ImportError caught by inner except).""" + # Ensure paddlefleet import fails. + for name in ( + "paddlefleet", + "paddlefleet.ops", + "paddlefleet.ops.flash_mask", + "paddlefleet.ops.flash_mask.cute", + "paddlefleet.ops.flash_mask.cute.interface", + ): + sys.modules.pop(name, None) + flash_attn_backend.FLASH_ATTN_VERSION = None + flash_attn_backend.flashmask_attention_v4 = None + + with ( + mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True), + mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), + mock.patch.object(paddle, "enable_compat", create=True, return_value=None), + ): + try: + flash_attn_backend.init_flash_attn_version() + except NameError: + pass + + self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) + + if __name__ == "__main__": unittest.main() From 300f323971ae5b11c4f542c0bd5a50016968ff0e Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 20 May 2026 14:32:21 +0800 Subject: [PATCH 3/5] replace paddlefleet.ops to paddlefleet_ops --- .../model_executor/layers/attention/flash_attn_backend.py | 4 ++-- tests/layers/test_flash_attn_func.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index ea94b05d438..1b608cd4c2f 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -86,10 +86,10 @@ def init_flash_attn_version(): try: paddle.enable_compat(scope={"cutlass"}) try: - from paddlefleet.ops import is_flash_mask_available + from paddlefleet_ops import is_flash_mask_available if is_flash_mask_available(): - from paddlefleet.ops.flash_mask.cute.interface import ( + from paddlefleet_ops.flash_mask.cute.interface import ( flashmask_attention as fa4, ) else: diff --git a/tests/layers/test_flash_attn_func.py b/tests/layers/test_flash_attn_func.py index 3d8fb75605c..0a6c740223e 100644 --- a/tests/layers/test_flash_attn_func.py +++ b/tests/layers/test_flash_attn_func.py @@ -220,10 +220,10 @@ def setUp(self): name: sys.modules.get(name) for name in ( "paddlefleet", - "paddlefleet.ops", - "paddlefleet.ops.flash_mask", - "paddlefleet.ops.flash_mask.cute", - "paddlefleet.ops.flash_mask.cute.interface", + "paddlefleet_ops", + "paddlefleet_ops.flash_mask", + "paddlefleet_ops.flash_mask.cute", + "paddlefleet_ops.flash_mask.cute.interface", ) } From f013e10aff41b85cc1cdc5dc488d13a2a63f3700 Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 20 May 2026 14:42:21 +0800 Subject: [PATCH 4/5] replace paddlefleet.ops to paddlefleet_ops --- tests/layers/test_flash_attn_func.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/layers/test_flash_attn_func.py b/tests/layers/test_flash_attn_func.py index 0a6c740223e..6cad6e6dc96 100644 --- a/tests/layers/test_flash_attn_func.py +++ b/tests/layers/test_flash_attn_func.py @@ -240,21 +240,21 @@ def _install_fake_paddlefleet(self, is_available: bool): """Inject fake paddlefleet modules so the inner imports succeed.""" pkg = types.ModuleType("paddlefleet") pkg.__path__ = [] - ops = types.ModuleType("paddlefleet.ops") + ops = types.ModuleType("paddlefleet_ops") ops.__path__ = [] ops.is_flash_mask_available = lambda: is_available - flash_mask = types.ModuleType("paddlefleet.ops.flash_mask") + flash_mask = types.ModuleType("paddlefleet_ops.flash_mask") flash_mask.__path__ = [] - cute = types.ModuleType("paddlefleet.ops.flash_mask.cute") + cute = types.ModuleType("paddlefleet_ops.flash_mask.cute") cute.__path__ = [] - interface = types.ModuleType("paddlefleet.ops.flash_mask.cute.interface") + interface = types.ModuleType("paddlefleet_ops.flash_mask.cute.interface") interface.flashmask_attention = mock.MagicMock(name="fa4") sys.modules["paddlefleet"] = pkg - sys.modules["paddlefleet.ops"] = ops - sys.modules["paddlefleet.ops.flash_mask"] = flash_mask - sys.modules["paddlefleet.ops.flash_mask.cute"] = cute - sys.modules["paddlefleet.ops.flash_mask.cute.interface"] = interface + sys.modules["paddlefleet_ops"] = ops + sys.modules["paddlefleet_ops.flash_mask"] = flash_mask + sys.modules["paddlefleet_ops.flash_mask.cute"] = cute + sys.modules["paddlefleet_ops.flash_mask.cute.interface"] = interface return interface.flashmask_attention def test_fa4_import_success(self): @@ -300,10 +300,10 @@ def test_fa4_paddlefleet_import_error(self): # Ensure paddlefleet import fails. for name in ( "paddlefleet", - "paddlefleet.ops", - "paddlefleet.ops.flash_mask", - "paddlefleet.ops.flash_mask.cute", - "paddlefleet.ops.flash_mask.cute.interface", + "paddlefleet_ops", + "paddlefleet_ops.flash_mask", + "paddlefleet_ops.flash_mask.cute", + "paddlefleet_ops.flash_mask.cute.interface", ): sys.modules.pop(name, None) flash_attn_backend.FLASH_ATTN_VERSION = None From 53336240aebc9d5332510cae49f184cc6d55a78e Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 20 May 2026 15:57:35 +0800 Subject: [PATCH 5/5] compat old api --- .../layers/attention/flash_attn_backend.py | 27 +++- tests/layers/test_flash_attn_func.py | 132 +++++++++++++----- 2 files changed, 118 insertions(+), 41 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 1b608cd4c2f..dd3040108c8 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -74,6 +74,8 @@ FLASH_ATTN_VERSION = None +from fastdeploy.model_executor.utils import try_import + def init_flash_attn_version(): """ @@ -86,14 +88,25 @@ def init_flash_attn_version(): try: paddle.enable_compat(scope={"cutlass"}) try: - from paddlefleet_ops import is_flash_mask_available - - if is_flash_mask_available(): - from paddlefleet_ops.flash_mask.cute.interface import ( - flashmask_attention as fa4, - ) + old_api = try_import(["paddlefleet.ops"]) + if old_api is not None: + from paddlefleet.ops import is_flash_mask_available + + if is_flash_mask_available(): + from paddlefleet.ops.flash_mask.cute.interface import ( + flashmask_attention as fa4, + ) + else: + raise ModuleNotFoundError("flash_mask not available.") else: - raise ModuleNotFoundError("flash_mask not available.") + from paddlefleet_ops import is_flash_mask_available + + if is_flash_mask_available(): + from paddlefleet_ops.flash_mask.cute.interface import ( + flashmask_attention as fa4, + ) + else: + raise ModuleNotFoundError("flash_mask not available.") except (ImportError, ModuleNotFoundError): logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") diff --git a/tests/layers/test_flash_attn_func.py b/tests/layers/test_flash_attn_func.py index 6cad6e6dc96..4bd26f49393 100644 --- a/tests/layers/test_flash_attn_func.py +++ b/tests/layers/test_flash_attn_func.py @@ -212,20 +212,36 @@ def test_fa4(self): class TestInitFlashAttnVersion(unittest.TestCase): """Tests for the init_flash_attn_version FA4 import branch (sm>=100).""" + _MODULE_NAMES = ( + "paddlefleet", + "paddlefleet.ops", + "paddlefleet.ops.flash_mask", + "paddlefleet.ops.flash_mask.cute", + "paddlefleet.ops.flash_mask.cute.interface", + "paddlefleet_ops", + "paddlefleet_ops.flash_mask", + "paddlefleet_ops.flash_mask.cute", + "paddlefleet_ops.flash_mask.cute.interface", + ) + def setUp(self): # Save state to restore after each test. self._saved_version = flash_attn_backend.FLASH_ATTN_VERSION self._saved_v4 = flash_attn_backend.flashmask_attention_v4 - self._saved_modules = { - name: sys.modules.get(name) - for name in ( - "paddlefleet", - "paddlefleet_ops", - "paddlefleet_ops.flash_mask", - "paddlefleet_ops.flash_mask.cute", - "paddlefleet_ops.flash_mask.cute.interface", - ) - } + self._saved_modules = {name: sys.modules.get(name) for name in self._MODULE_NAMES} + # Make sure each test starts with a clean module state. + for name in self._MODULE_NAMES: + sys.modules.pop(name, None) + + def _block_old_api(self): + """Force `paddlefleet.ops` import to fail regardless of what is installed.""" + # Setting sys.modules[name] = None makes importlib.import_module raise ImportError. + sys.modules["paddlefleet"] = None + sys.modules["paddlefleet.ops"] = None + + def _block_new_api(self): + """Force `paddlefleet_ops` import to fail regardless of what is installed.""" + sys.modules["paddlefleet_ops"] = None def tearDown(self): flash_attn_backend.FLASH_ATTN_VERSION = self._saved_version @@ -236,10 +252,30 @@ def tearDown(self): else: sys.modules[name] = mod - def _install_fake_paddlefleet(self, is_available: bool): - """Inject fake paddlefleet modules so the inner imports succeed.""" + def _install_fake_paddlefleet_old_api(self, is_available: bool): + """Inject fake `paddlefleet.ops` (old API) modules.""" pkg = types.ModuleType("paddlefleet") pkg.__path__ = [] + ops = types.ModuleType("paddlefleet.ops") + ops.__path__ = [] + ops.is_flash_mask_available = lambda: is_available + pkg.ops = ops + flash_mask = types.ModuleType("paddlefleet.ops.flash_mask") + flash_mask.__path__ = [] + cute = types.ModuleType("paddlefleet.ops.flash_mask.cute") + cute.__path__ = [] + interface = types.ModuleType("paddlefleet.ops.flash_mask.cute.interface") + interface.flashmask_attention = mock.MagicMock(name="fa4_old") + + sys.modules["paddlefleet"] = pkg + sys.modules["paddlefleet.ops"] = ops + sys.modules["paddlefleet.ops.flash_mask"] = flash_mask + sys.modules["paddlefleet.ops.flash_mask.cute"] = cute + sys.modules["paddlefleet.ops.flash_mask.cute.interface"] = interface + return interface.flashmask_attention + + def _install_fake_paddlefleet_new_api(self, is_available: bool): + """Inject fake `paddlefleet_ops` (new API) modules.""" ops = types.ModuleType("paddlefleet_ops") ops.__path__ = [] ops.is_flash_mask_available = lambda: is_available @@ -248,18 +284,19 @@ def _install_fake_paddlefleet(self, is_available: bool): cute = types.ModuleType("paddlefleet_ops.flash_mask.cute") cute.__path__ = [] interface = types.ModuleType("paddlefleet_ops.flash_mask.cute.interface") - interface.flashmask_attention = mock.MagicMock(name="fa4") + interface.flashmask_attention = mock.MagicMock(name="fa4_new") - sys.modules["paddlefleet"] = pkg sys.modules["paddlefleet_ops"] = ops sys.modules["paddlefleet_ops.flash_mask"] = flash_mask sys.modules["paddlefleet_ops.flash_mask.cute"] = cute sys.modules["paddlefleet_ops.flash_mask.cute.interface"] = interface return interface.flashmask_attention - def test_fa4_import_success(self): - """Covers lines 88, 89, 91, 92 (is_flash_mask_available True branch).""" - fake_fa4 = self._install_fake_paddlefleet(is_available=True) + def test_fa4_old_api_import_success(self): + """Old API (`paddlefleet.ops`) is preferred when available.""" + fake_fa4 = self._install_fake_paddlefleet_old_api(is_available=True) + # Also install new API to verify the old API takes precedence. + new_fa4 = self._install_fake_paddlefleet_new_api(is_available=True) flash_attn_backend.FLASH_ATTN_VERSION = None flash_attn_backend.flashmask_attention_v4 = None @@ -272,10 +309,12 @@ def test_fa4_import_success(self): self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4) + self.assertIsNot(flash_attn_backend.flashmask_attention_v4, new_fa4) - def test_fa4_flash_mask_unavailable(self): - """Covers lines 88, 89, 91, 96, 98, 99 (raise + except path).""" - self._install_fake_paddlefleet(is_available=False) + def test_fa4_old_api_flash_mask_unavailable(self): + """Old API present but `is_flash_mask_available` is False.""" + self._install_fake_paddlefleet_old_api(is_available=False) + self._block_new_api() flash_attn_backend.FLASH_ATTN_VERSION = None flash_attn_backend.flashmask_attention_v4 = None @@ -284,10 +323,6 @@ def test_fa4_flash_mask_unavailable(self): mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), mock.patch.object(paddle, "enable_compat", create=True, return_value=None), ): - # The inner except swallows ModuleNotFoundError, but `fa4` is then - # unbound, so the outer block raises NameError (not ImportError), - # which propagates. Verify the inner except actually executed by - # checking that FA4 was not selected. try: flash_attn_backend.init_flash_attn_version() except NameError: @@ -295,17 +330,46 @@ def test_fa4_flash_mask_unavailable(self): self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) - def test_fa4_paddlefleet_import_error(self): - """Covers lines 88, 89, 98, 99 (ImportError caught by inner except).""" - # Ensure paddlefleet import fails. - for name in ( - "paddlefleet", - "paddlefleet_ops", - "paddlefleet_ops.flash_mask", - "paddlefleet_ops.flash_mask.cute", - "paddlefleet_ops.flash_mask.cute.interface", + def test_fa4_new_api_import_success(self): + """Falls back to new API (`paddlefleet_ops`) when old API is missing.""" + fake_fa4 = self._install_fake_paddlefleet_new_api(is_available=True) + self._block_old_api() + flash_attn_backend.FLASH_ATTN_VERSION = None + flash_attn_backend.flashmask_attention_v4 = None + + with ( + mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True), + mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), + mock.patch.object(paddle, "enable_compat", create=True, return_value=None), ): - sys.modules.pop(name, None) + flash_attn_backend.init_flash_attn_version() + + self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) + self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4) + + def test_fa4_new_api_flash_mask_unavailable(self): + """New API present but `is_flash_mask_available` is False.""" + self._install_fake_paddlefleet_new_api(is_available=False) + self._block_old_api() + flash_attn_backend.FLASH_ATTN_VERSION = None + flash_attn_backend.flashmask_attention_v4 = None + + with ( + mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True), + mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100), + mock.patch.object(paddle, "enable_compat", create=True, return_value=None), + ): + try: + flash_attn_backend.init_flash_attn_version() + except NameError: + pass + + self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4) + + def test_fa4_paddlefleet_import_error(self): + """Neither old nor new API is importable.""" + self._block_old_api() + self._block_new_api() flash_attn_backend.FLASH_ATTN_VERSION = None flash_attn_backend.flashmask_attention_v4 = None